随着深度学习模型的应用和推广,人们逐渐发现模型常常会利用数据中存在的虚假关联(Spurious Correlation)来获得较高的训练表现。但由于这类关联在测试数据上往往并不成立,因此这类模型的测试表现往往不尽如人意 [1]。其本质是由于传统的机器学习目标(Empirical Risk Minimization,ERM)假设了训练测试集的独立同分布特性,而在现实中该独立同分布假设成立的场景往往有限。

在很多现实场景中,训练数据的分布与测试数据分布通常表现出不一致性,即分布偏移(Distribution Shifts),旨在提升模型在该类场景下性能的问题通常被称为分布外泛化(Out-of-Distribution Generalization)问题。

关注学习数据中的相关性而非因果性的 ERM 等一类方法往往难以应对分布偏移。尽管近年涌现了诸多方法借助因果推断(Causal Inference)中的不变性原理(Invariance Principle)在分布外泛化问题上取得了一定的进展,但在图数据上的研究依然有限。这是因为图数据的分布外泛化比传统的欧式数据更加困难,给图机器学习带来了更多的挑战。

本文以图分类任务为例,对借助因果不变性原理的图分布外泛化进行了探究。

图片

近年来,借助因果不变性原理,人们在欧式数据的分布外泛化问题中取得了一定的成功,但对图数据的研究仍然有限。与欧式数据不同,图的复杂性对因果不变性原理的使用以及克服分布外泛化难题提出了独特的挑战。

为了应对该挑战,我们在本工作中将因果不变性融入到图机器学习中,并提出了因果启发的不变图学习框架,为解决图数据的分布外泛化问题提供了新的理论和方法。论文已在 NeurIPS 2022 发表并入选了会议的 Spotlight Presentation(比例约 5%),本工作由香港中文大学、香港浸会大学、腾讯 AI Lab 以及悉尼大学合作完成。

图片
论文标题:
Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs

论文链接:

https://openreview.net/forum?id=A6AFK_JwrIW

项目代码:

https://github.com/LFhase/CIGA

本文通过因果推断的角度,首次将因果不变性引入至多种图分布偏移下的图分布外泛化问题中,并提出了一个全新的具有理论保证的解决框架 CIGA。大量实验也充分验证了 CIGA 优秀的分布外泛化性能。

内容中包含的图片若涉及版权问题,请及时与我们联系删除