Stanford CS224W: Machine Learning with Graphs

By Siavosh Shadpey, Vasanti Wall-Persad as part of the Stanford CS 224W Final Project


代码下载:https://t.zsxq.com/Kuly2


请索引第36个项目



深度学习在药物发现中的应用


药物研发面临着规模上的根本挑战。对于任何一种致病蛋白,都存在数百万种潜在的候选药物(称为配体的小分子,它们可能与该蛋白结合并抑制其功能)。对每一种候选药物进行实验测试既耗时又昂贵,通常需要数年时间,而且每成功研发出一种药物就要花费数百万美元。问题在于:我们如何在进入实验室之前,有效地缩小这个庞大的搜索范围?


近年来,深度学习已成为完成这项任务的强大工具,神经网络能够学习预测哪些分子会与其蛋白质靶标牢固结合。虽然结合亲和力最终取决于复杂的量子力学相互作用和热力学稳定性,但分子的三维结构提供了关于这些潜在性质以及结合所需的精确几何匹配的重要信息。


三肽配体与寡肽结合蛋白结合的示意图


为什么选择图论?为什么选择图神经网络?


蛋白质和配体的分子结构天然适合用图来表示。每个原子都作为一个节点,并具有相关的化学性质(原子类型、电荷、杂化方式),而原子之间的连接(强共价键和较弱的非共价键)则构成边。图神经网络(GNN)可以直接处理这种天然的表示形式,学习理解定义分子结合相互作用的复杂拓扑和化学关系。


这种基于图的方法相比于其他分子表示方法(如序列表示(例如 SMILES 字符串)、分子指纹或距离矩阵表示)具有显著优势,因为 GNN 可以同时显式地对连接性、键信息和 3D 几何形状进行建模,而这些都与结合亲和力相关 [10]。


挑战:标记数据有限


尽管基础已相当充分,但要准确预测结合亲和力并达到治疗药物开发所需的精度仍然是一项艰巨的挑战[1]。主要瓶颈在于高质量标注数据的匮乏。实验测定结合亲和力需要耗费大量时间和资源,导致我们只能获得相对较小的训练数据集(仅有数万个样本,而计算机视觉或自然语言处理领域则拥有数百万个样本)。如果没有足够的标注数据,模型就难以泛化到训练所用的特定蛋白质-配体复合物之外的其他情况。


我们的解决方案:自主预训练


为了克服数据限制,我们采用了自监督学习。我们并非直接训练模型预测结合亲和力,而是首先通过一个无需标签的预训练任务,教会模型理解分子几何结构。


我们提出的创新方法结合了两个关键的架构组件:


  • 掩码自编码器(MAE):我们随机掩蔽分子的一部分,并训练我们的模型来重建隐藏原子的 3D 坐标,迫使它学习有意义的结构表示。

  • 等变图神经网络(EGNN):我们确保我们的模型尊重 3D 空间的基本对称性。


这种预训练策略使我们的EGNN编码器无需亲和力标签即可学习鲁棒且可泛化的几何特征。编码器必须充分理解分子几何结构,才能根据周围环境预测被掩蔽原子的位置。预训练完成后,我们对一个轻量级回归头进行微调,用于下游的结合亲和力预测任务。


在这篇文章中,我们将详细介绍我们如何构建这个 MAE-EGNN 框架、我们的设计选择和优化步骤,以及我们如何在行业标准 CASF-2016 结合亲和力预测基准测试中取得具有竞争力的性能。


数据集:PDBbind CleanSplit


我们使用PDBbind CleanSplit 数据集训练和评估我们的模型,PDBbind CleanSplit 是 Graber 等人 [1] 创建的标准 PDBbind 数据集的改进版本。PDBbind CleanSplit 经过精心设计,旨在解决原始数据集的两个关键问题:


  • 数据泄露:确保测试集和验证集中的复杂样本与训练样本真正不同。这迫使模型学习可泛化的模式,而不是记忆特定的样本。

  • 训练冗余:去除相似的复合体,从而降低模型过拟合训练数据的倾向。这会产生一个规模更小但更多样化的训练集,Graber 等人观察到,这实际上提高了模型在预留测试数据上的性能。


我们使用的 PDBbind CleanSplit 数据集包括训练集中的 13,168 个蛋白质-配体复合物、验证集中的 3,294 个蛋白质-配体复合物和 CASF-2016 基准测试集(测试集)中的 282 个蛋白质-配体复合物。


数据分析



结合亲和力分布:我们通过分布分析验证了PDBbind CleanSplit的质量。该数据集包含从弱到强的蛋白质-配体相互作用的结合亲和力(pKd),其值范围为0.03至0.94。三个分割结果显示出非常相似的亲和力分布,证实所有分割结果均取自相同的结合强度范围。

原子位置分布:空间分布也保持一致。由于测试集 (CASF-2016) 经过精心筛选,作为高质量基准数据集,其方差略低(27.6 Å 对比 41.7 Å)。这些结果验证了 CleanSplit 能够防止数据泄露,同时保持分割的代表性,并确保模型在学习过程中接触到各种几何结构,即使并非所有几何结构在测试时都得到同等程度的体现。


从蛋白质到图:输入表示


PDBBind CleanSplit 图构建过程。


此图概述了构建模型输入数据的图的过程。我们数据集中的每个样本都代表一个亲和力标记的蛋白质-配体复合物,该复合物以统一的相互作用图的形式呈现,包含以下组成部分:


  • 节点:配体原子(以原子类型、质量、部分电荷、杂化方式等原子属性为特征)和蛋白质口袋氨基酸(按残基类型进行独热编码)。配体和蛋白质的节点特征相互连接,每个节点最终包含 60 个特征。

  • 边:有两种类型:配体原子之间的共价键,以及将空间上邻近的配体原子与 5Å 以内的酸节点连接起来的非共价相互作用。所有边都有 16 个边属性。

  • 几何结构:所有配体原子和蛋白质中心碳原子的明确三维坐标


我们遵循 Graber 等人记录的预处理步骤,并针对我们的 MAE 预训练目标进行了三项修改:


  • 包含明确的 3D 原子坐标,这对于通过掩码位置重建来学习几何结构至关重要。

  • 省略预先计算的基于距离的边缘特征,使 EGNN 能够直接从坐标中学习几何关系。

  • 排除外部语言模型嵌入,以便将分析重点放在 EGNN 模型的功能上,并保持模型的轻量级。


这种表示方法既捕捉了结合位点的化学特性,又捕捉了其空间排列,为我们的模型提供了通过自监督重建学习几何结构所需的信息。


蛋白质三维结构的重要性


“结构决定功能。”


具有生物学背景的人可能对这句谚语并不陌生。对于不熟悉的人,我们来简单解释一下。蛋白质复杂的3D结构并非随机形成;它是无限多种可能形状中唯一一种能够达到最稳定、热力学能量最低状态的形状。未折叠的蛋白质总是会折叠回其天然形状,就像落入山谷的球总是会沉到谷底一样。研究蛋白质的形状可以揭示其许多特性和性质。例如,药物研究人员首先需要确定蛋白质的结合口袋,才能设计出能够与之结合的药物,就像设计一把能够完美打开锁的钥匙一样。

导致新冠肺炎疫情的SARS-CoV-2刺突糖蛋白的复杂结构。蛋白质处于天然状态,采用能量最低的构象。


鉴于此,在要求我们的模型预测药物(或配体)与蛋白质的结合强度(也称为蛋白质-配体亲和力预测)之前,我们首先的目标是教会模型理解蛋白质和配体的三维结构。我们认为,具备这种理解能力的模型能够做出更准确的亲和力预测。


掩码自编码器


但是,我们如何教会模型理解三维分子形状,尤其是在我们自身未必完全理解这些形状的情况下?这时,掩码自编码器(MAE)框架就派上了用场。我们选取一个分子,随机掩码(隐藏)其部分结构,并强制模型仅使用未被掩码的部分来重建被掩码的部分。为了使模型能够准确完成这项具有挑战性的任务,它必须隐式地学习描述分子结构的特征。这种方法的另一个重要优势在于,我们无需提供任何关于分子形状的专家领域知识:模型能够自主学习并判断哪些特征是必要的,哪些是不必要的。


更具体地说,我们在 MAE 框架内预训练一个编码器,方法是随机掩蔽分子中大部分原子,然后训练自编码器以自监督的方式重建这些原子的原始 3D 坐标。我们最小化的重建损失定义为掩蔽原子的实际坐标与预测坐标之间的均方距离。预训练的编码器随后将生成关于分子结构的通用、鲁棒且有意义的特征。我们的方法借鉴了其他成功的几何 MAE 的一些思想,例如用于 3D 计算机视觉的 PointMAE [2] 和用于生物化学的 ProteinMAE [3]。例如,ProteinMAE 网络将分子表面细分为重叠的区域,并使用视觉变换器 (ViT) 来重建掩蔽区域。然而,我们认为使用图神经网络 (GNN) 代替 ViT 可以:


  • 由于 ViT 的时间复杂度为二次方,因此效率更高;

  • 能够更好地捕捉分子结构,因为分子本身就是由原子和键构成的图来表示的。


此外,执行原子级掩蔽而不是基于区域的掩蔽是一项更具挑战性的任务,并且有可能使模型学习到有关分子结构的更细致的细节。

我们的MAE框架。分子中随机选取一部分原子进行掩蔽(灰色)。自编码器的任务是根据可见原子的三维坐标重建被掩蔽原子的三维坐标。


MAE实施


每个配体-蛋白质复合物最初都表示为一个图,其中原子是节点,原子键是边,正如前文“为什么使用图?为什么使用图神经网络?”部分所述。我们通过随机掩蔽一部分节点并移除它们的所有边来生成一个子图。然后,将这个包含可见节点的子图输入到图神经网络编码器中,编码器会为每个可见原子生成一个嵌入特征。接下来,将这些嵌入特征以及掩蔽节点的共享可学习标记一起输入到解码器中。解码器利用可见节点嵌入的上下文信息来转换掩蔽标记,并重建掩蔽原子的原始三维坐标。

from torch_geometric.utils import subgraph
#...
classMaskedGeometricAutoencoder(nn.Module):
"""
    Masked geometric autoencoder.

    Args:
        encoder (nn.Module): Encoder model.
        decoder (nn.Module): Decoder model.
        masking_ratio (float): Ratio of masking to apply on the input data.
    """

def__init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 masking_ratio: float
):
#...

defforward(self,
                x  : Tensor,
                pos: Tensor,
                edge_index: Adj,
                edge_attr: Tensor
) -> Tuple[Tensor, Tensor]:

# Randomly mask nodes.
        num_nodes = x.size(0)
        num_masked = int(self.masking_ratio * num_nodes)
        num_visible = num_nodes - num_masked
        node_indices_perm = torch.randperm(num_nodes, device=x.device)
        mask_indices, vis_indices = (node_indices_perm[:num_masked], 
                                     node_indices_perm[num_masked:])

# Subgraph of visible nodes. 
# Edges connected to masked nodes are removed.
        edge_index_vis, edge_attr_vis = subgraph(
            vis_indices, 
            edge_index, 
            edge_attr,
            num_nodes=num_nodes,
            relabel_nodes=True# relabels edges from 0 to num_visible-1.
            )

# Encode visible nodes. 
        x_vis = x[vis_indices]
        pos_vis = pos[vis_indices]
        x_vis, pos_vis = self.encoder(x_vis, pos_vis,
                                      edge_index_vis, edge_attr_vis)

# Unlike the visible nodes, which use the encoded features
# and positions, the masked nodes share the same learnable token
# and their positions are initialized randomly.
        z_masked = self.masked_node_token.repeat(num_masked, 1).to(x.device)
        pos_masked = torch.randn((num_masked, pos.size(1)), device=pos.device)

# Combine visible and masked nodes.
        z = torch.empty((num_nodes, x_vis.size(1)), device=x.device)
        z[vis_indices,:] = x_vis
        z[mask_indices,:] = z_masked
        pos[vis_indices,:] = pos_vis
        pos[mask_indices,:] = pos_masked

# Decode to reconstruct masked node positions.
        pos_reconstructed = self.decoder(z, pos, edge_index, edge_attr)
return pos_reconstructed[mask_indices,:], mask_indices


等变 GNN 层

剧毒的氰化氢化合物的化学式。


考虑上图所示的氰化氢化合物。我们可以用其三个原子的三维坐标来表示它。但是,我们应该如何定义坐标系呢?碳原子自然而然地成为了原点。那么,氢原子应该在其左侧,氮原子在其右侧吗?还是应该在其右侧,氮原子在其左侧?定义坐标系的方法有很多种,而且每一种都是有效的。然而,在数值计算中,坐标系的选择至关重要。例如,如下图所示,一个用于预测热容的简单训练的多层感知器(MLP)模型,对于同一分子在不同方向上的计算结果会有所不同。

一个简单的模型,如果不能仔细处理坐标,对于同一个分子的不同取向,就会输出不同的结果(例如热容)。


缓解此问题的一种可能方法是数据增强。这涉及使用相同分子在不同随机方向上的数据集训练模型,但保持标签不变,从而使模型能够学习对几何变换的鲁棒性。由于三维旋转空间的无限性,这种方法需要大量的数据增强,因此即使对于规模适中的数据集,也会很快变得难以扩展。此外,这种方法还迫使模型消耗一部分运算能力来学习这种不变性;而这些运算能力原本可以用来学习对训练任务更有意义的特征。


更优雅高效的解决方案是设计一种具有内置平移和旋转不变性(称为SE(3)不变性)的架构。Satorras等人[4]提出的E(n)等变GNN(EGNN)层正是如此。事实上,它满足了我们模型所依赖的两个关键性质:


  • 它以不变的方式更新节点特征,不受旋转、平移和反射的影响,这种不变性被称为E(3)不变性。因此,最终的图预测结果(例如热容)将保持一致,与分子的方向无关。

  • 它以 E(3) 等变的方式更新节点位置。这意味着,如果我们旋转输入分子并将其输入到 EGNN 层,输出坐标将旋转相同的角度。


EGNN层由邻居间消息传递步骤、节点位置更新步骤和节点特征更新步骤组成。其算子φ 是可微的、参数化的逐元素边、位置和特征函数,类似于MLP层。


将邻居之间的欧氏距离平方作为消息传递函数的输入是实现第一个特性的关键,而通过相对距离的加权和来更新节点位置则确保了节点位置的 E(3) 等变性。由于我们的模型需要在预训练期间重建坐标,因此这种等变性使得 EGNN 层成为我们工作中不可或缺的组成部分。与其他 GNN 模型类似,EGNN 也具有节点置换等变性。


举例来说,下面的代码表明 EGNN 模型是 E(3) 等变的,而 GraphSAGE 模型的位置感知扩展 [5] 则不是。正式证明可以在论文中找到。

# Define EGNN model, x, edge_index, etc.
# ...
GraphSAGE_model = GraphSAGE(in_channels=in_channels+pos_dim,
                            hidden_channels=hidden_channels,
                            num_layers=num_layers,
                            out_channels=hidden_channels)

# Adapt graphSAGE to take in positions as part of node features.
defGeo_GraphSAGE_model(x, pos, edge_index):
    x = GraphSAGE_model(torch.concat([x, pos], dim=-1), edge_index)
return x

# Define a random rotation and translation.
R = torch.randn((pos_dim, pos_dim))
U, _, Vt = torch.linalg.svd(R)
R = U @ Vt  # Ensure it's a proper rotation matrix.
t = torch.randn((1, pos_dim))

pos_transformed = (pos @ R) + t
for name, model inzip(["EGNN""Geo GraphSAGE"], 
                       [EGNN_model, Geo_GraphSAGE_model]):
    out = model(x, pos, edge_index)
    out_transformed = model(x, pos_transformed, edge_index)
if torch.allclose(out_transformed, out, atol=1e-5):
print(f"The {name} model is E(3) equivariant.")
else:
print(f"The {name} model is NOT E(3) equivariant.")
The EGNN model is E(3) equivariant.
The Geo GraphSAGE model is NOT E(3) equivariant.


下面显示的 EGNN 层,以及 EGNN 模型实现可以在我们的GitHub 存储库中找到,或者目前可以在 PyGeometric 存储库中找到 PR。

classEGNNConv(MessagePassing):
"""
    Equivariant Graph Neural Network Layer.
    """

def__init__(self, 
                 nn_edge        : Callable,
                 nn_node        : Callable,
                 pos_dim        : int,
                 nn_pos         : Optional[Callable] = None,
                 skip_connection: bool = False,
                 **kwargs
):
#...

defforward(self,
                x          : Tensor,
                pos        : Tensor,
                edge_index : Adj, 
                edge_attr  : Optional[Tensor] = None,
                size       : Size = None
) -> Tuple[Tensor, Tensor]:

# Perform message passing.
        message = self.propagate(edge_index, x=(x, x), pos=(pos, pos), edge_attr=edge_attr, size=size)
if self.nn_pos isnotNone:
            (message_node, message_pos) = torch.split(message, [message.size(-1) - self.pos_dim, self.pos_dim], dim=-1)
            out_pos = pos + message_pos
else:
            message_node = message
            out_pos = pos

# Update node features.
        out_node = self.nn_node(torch.cat([x, message_node], dim=-1))
if self.skip_connection:
            out_node = out_node + x
return (out_node, out_pos)

defmessage(self, 
                x_i      : Tensor, 
                x_j      : Tensor, 
                pos_i    : Tensor, 
                pos_j    : Tensor, 
                edge_attr: Optional[Tensor] = None
) -> Tensor:

        pos_diff = pos_i - pos_j
        square_norm = torch.sum(pos_diff ** 2, dim=-1).unsqueeze(-1)
if edge_attr isnotNone:
            out = torch.cat([x_j, x_i, square_norm, edge_attr], dim=-1)
else:
            out = torch.cat([x_j, x_i, square_norm], dim=-1)

        out = self.nn_edge(out)

if self.nn_pos isnotNone:
            out_pos_j = self.nn_pos(out)
            pos_diff = pos_diff/(torch.sqrt(square_norm + eps) + eps)
            message_pos_j = pos_diff * out_pos_j
            out = torch.cat([out, message_pos_j], dim=-1)

return out

defaggregate(self, 
                  inputs  : Tensor, 
                  index   : Tensor, 
                  dim_size: Optional[int] = None
) -> Tensor:

        out_pos = None
if self.nn_pos isnotNone:
            (message_node, message_pos) = torch.split(inputs, [inputs.size(-1) - self.pos_dim, self.pos_dim], dim=-1)
            out_pos = torch_scatter.scatter_mean(message_pos, index, dim=0, dim_size=dim_size)
else:
            message_node = inputs

        out = torch_scatter.scatter_add(message_node, index, dim=0, dim_size=dim_size)

if out_pos isnotNone:
            out = torch.cat([out, out_pos], dim=-1)
return out


对于节点特征(nn_node)和节点位置(nn_pos)函数,我们使用一个线性层,后接Swish非线性激活函数,再接另一个线性层。对于边缘位置(nn_edge),我们使用一个带有Swish函数的两层多层感知器(MLP)。节点位置更新仅适用于解码器,用于推断掩码原子的坐标;在编码器中,由于位置已知且固定,因此节点位置更新被关闭。


模型训练与超参数调优


模型训练分为两个步骤。首先,我们使用基于图神经网络(GNN)的编码器和解码器,在自监督三维重建任务上预训练一个自编码器。其次,我们为预训练的编码器添加一个头部,并在下游亲和力结合回归任务上进行微调。我们使用 Google Vertex AI 超参数调优框架进行了广泛的超参数调优,以确定这两个步骤的最佳配置。该框架采用贝叶斯优化方法。


预训练最佳配置和主要发现


超参数搜索揭示了模型性能的显著差异。我们运行了 22 次迭代,每次迭代 100 个 epoch,以确定最佳参数。性能最佳的配置实现了 2.484 的验证损失。相比之下,性能最差的配置产生了 6.573 的损失,性能下降了 164.5%,这可能是由于批次大小过大和模型容量不足造成的。下表总结了超参数调优结果。



我们选择了验证损失最低的超参数配置,并运行了最后一轮预训练,迭代1000次,以生成预训练嵌入。相应的架构如下图所示。



超参数微调


我们对微调任务进行了超参数优化,主要目的是验证我们选择的回归头架构和池化方法。对于回归头,我们测试了三种架构:MLP、EGNN 以及 MLP/EGNN 组合,并采用了三种不同的全局池化方法(均值池化、求和池化和最大值池化)。我们运行了 50 次迭代,每次迭代 100 个 epoch,以最大化结合亲和力预测的 Pearson 相关系数。


优化最佳配置及主要发现


与预训练相比,微调对超参数选择的敏感性较低。性能最佳的配置达到了 0.72 的皮尔逊相关系数,而最差的配置则达到了 0.530(下降了 36%)。


最优回归头采用比预训练更简单、更窄的架构(64 个隐藏维度,2 个 EGNN 卷积层)。这表明预训练特征已经包含丰富的信息,只需进行极少的转换即可用于下游任务。显著更高的最优批大小(256 对比预训练时的 16)表明,即使使用更大的批大小,预训练后的梯度更新也更加稳定。最佳配置结合了 EGNN 头、全局均值池化和线性外层。这表明均值聚合优于求和池化和最大池化。在此配置下,完整的微调模型(预训练编码器 + 回归头)包含约 150 万个参数。



实验结果


不同的训练策略


为了研究预训练步骤的好处,我们比较了四种不同的训练策略:


  • 冻结编码器:冻结预训练编码器并微调头部。

  • 完全微调:从一开始就同时微调预训练编码器和头部。

  • 延迟微调:在初始几个 epoch 内冻结编码器,然后以较低的学习率对整个模型进行微调。

  • 从零开始:从头开始训练模型。


前三种策略均包括 1000 个 epoch 的预训练和 1000 个 epoch 的微调。为了公平比较,从零开始构建的模型训练了 2000 个 epoch,相当于前三种策略的组合。

在PDBbind CASF-2016测试数据集上,四种不同训练策略的预测准确率如下。预训练、延迟、微调的模型准确率最高。


如上表所示,冻结编码器策略的性能相对较差。其轻量级头部表达能力不足,无法从嵌入特征中准确回归结合亲和力。通过同时微调编码器和头部,模型取得了显著更好的结果。


从一开始就对编码器进行微调的一个缺点是,它可能会丢失一些关于分子结构的通用知识。由于头部是用随机权重初始化的,反向传播产生的梯度会迫使编码器过度补偿头部的不足,这可能会导致编码器“忘记”一些有价值的丰富特征。延迟微调策略可以缓解这个问题。其工作原理如下:


  • 为了确保模型头部训练后能够产生稳定的结果,我们会冻结编码器几个epoch。如下所示,对于我们的模型,编码器冻结至第45 epoch。

  • 然后,学习率降低,在本例中从 1E-4 降低到 1E-5,并对整个模型进行微调。


因此,在四种训练策略中,延迟微调策略表现最佳,测试均方根误差 (RMSE) 为 1.376,皮尔逊相关系数 (R) 为 0.787。值得注意的是,简单地降低其他训练策略的学习率并不能获得更好的结果,这证实了“延迟”是关键因素。


前 200 个 epoch 的 epoch 与验证数据集上的性能指标对比。在训练头部模型时延迟编码器微调可以获得更平滑的收敛效果。


尽管从零开始的模型运行的轮数是原来的两倍,但其性能仍不理想:均方根误差为 1.439,皮尔逊相关系数为 0.759。相比之下,延迟微调的模型在不到 200 个轮次内即可达到相同的性能。


这项研究凸显了我们预训练步骤的优势。它证实了我们的核心假设:理解蛋白质-配体复合物的三维结构对于预测结合亲和力至关重要。


不同训练策略的训练轮数与性能指标对比。微调策略能够在显著减少的训练轮数内实现更准确的预测。


与其他模型进行基准测试


为了更好地理解我们模型的性能,我们将其与在 PDBbind CleanSplit 数据集上训练的三个性能最佳的模型进行了比较:


  • GEMS: 一种基于几何 GNN 的模型,它利用丰富的语言模型嵌入,包括 ESM2 模型的 8M 参数变体 [6] 和 ChemBERTa-2 模型的 77M 参数变体 [7] 作为初始节点特征。[1]

  • GenScore:一种基于图Transformer的模型。[8]

  • Pafnucy:一种常被用作基线的 3D 卷积神经网络模型。在原始的、存在泄漏的 PDBbind 数据集上,Pafnucy 取得了优异的性能,均方误差 (MSE) 为 1.046,皮尔逊相关系数 (Pearson R) 为 0.906,正如 GEMS 的作者所展示的那样。[9]

本文对四个基于PDBbind CleanSplit训练数据集训练、并在PDBbind CASF-2016测试数据集上进行评估的不同模型进行了预测。结果表明,我们的模型在保持轻量级的同时,实现了具有竞争力的性能。


如上表所示,我们的模型极具竞争力,在这些基准测试中取得了第二高的皮尔逊相关系数。我们的结果表明,在 MAE 框架内精心预训练的 EGNN 架构可以生成准确、高效且独立的预测模型,而无需依赖像 GEMS 那样重量级的外部语言嵌入模型。


未来方向


尽管我们的结果令人鼓舞,但仍有许多方向值得探索。我们用于训练和验证的数据集包含约 16,500 个已标记的蛋白质-配体结构。由于实验测定双分子复合物结合亲和力的难度,这个数字相对较小。由于我们的预训练步骤是自监督的,不需要亲和力结合标签,因此可以使用更大的未标记蛋白质数据库。但在本项目中,由于计算资源有限,我们仅使用了一个易于管理的数据集来验证我们的方法。然而,下一步可以在蛋白质数据库(PDB)数据集上预训练一个真正通用且更丰富的模型,该数据集包含超过 24 万个分子!


结论


在这项工作中,我们开发了一种自监督预训练框架,该框架结合了掩码自编码器和E(n)等变图神经网络,用于学习蛋白质-配体复合物有意义的三维几何表示。我们的方法解决了计算药物发现领域的一个根本性挑战:标记结合亲和力数据的稀缺性,这限制了监督模型的性能。


我们的主要贡献包括:


  • 新型 MAE-EGNN 架构:我们展示了几何等变性在通过自监督重建学习分子三维结构方面的优势。通过将 EGNN 的 E(3) 不变性与掩码自编码相结合,我们的模型无需显式监督即可学习理解原子间的几何关系。

  • 有效的延迟微调策略。通过系统实验,我们发现延迟微调策略优于更简单的微调方法。在训练回归头的初始几个训练轮次中冻结预训练编码器,可以使回归头在编码器更新之前趋于稳定。这有助于后续的联合微调利用预训练知识,从而获得比其他方法更快收敛的显著结果。

  • 轻量级架构下的卓越性能:尽管我们的模型仅有 150 万个参数,且未使用任何外部语言模型嵌入,但在 CASF-2016 基准测试中仍取得了极具竞争力的结果,在所有评估模型中皮尔逊相关系数排名第二。这表明,精心设计的几何预训练可以与依赖大规模语言模型进行特征初始化(GEMS)、自注意力机制(GenScore)或 3D 体素化(Pafnucy)的方法相媲美(甚至超越)。


此外,我们广泛的超参数优化活动为预训练和微调阶段的模型设计提供了宝贵的见解。


  • 在预训练中,最佳配置和最差配置之间 164.5% 的性能差距表明,自监督重建任务对超参数的选择很敏感,需要仔细平衡模型容量和批次大小。

  • 在微调过程中,对超参数的敏感性较低(性能差距为 36%),这表明良好的预训练表示提供了稳定性,从而允许使用更简单、更窄的头部。


我们的研究结果有力地支持了我们的论断:理解三维分子结构能够提高下游结合亲和力预测的准确性。与从头开始训练相比,预训练延迟微调回归模型的优异性能(皮尔逊相关系数为0.787)表明,额外的自监督几何学习步骤确实具有实际价值。该模型在预训练过程中学习了可迁移的几何推理能力,并能有效地将其推广到亲和力预测中。

参考文献

  • David Graber, Peter Stockinger, Fabian Meyer, et al. GEMS — enhancing generalizable binding affinity prediction by removing data leakage and integrating language model embeddings into graph neural networks. bioRxiv, 2025.

  • Xin Pang, Wen-Xuan Mou, Yu-Qiang Ren, et al. Masked autoencoders for point cloud self-supervised learning. In Computer Vision — ECCV 2022, pages 693–709. Springer Nature, 2022.

  • Yijie He, Yixuan Wang, Dong-Jing Sha, et al. ProteinMAE: Self-supervised masked autoencoders for protein representation learning. arXiv preprint arXiv:2305.15571, 2023.

  • Vıctor Garcia Satorras, Emiel Hoogeboom, and Max Welling. E (n) equivariant graph neural networks. In International conference on machine learning, pages 9323–9332. PMLR, 2021.

  • William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems 30, pages 1024–1034. Curran Associates, Inc., 2017.

  • Zeming Lin, Halil Akin, Roshan Rao, et al. Evolutionary-scale prediction of atomic-level protein structure with a language model. In Science, volume 379, issue 6636, pages 1123–1130, 2023.

  • Walid Ahmad, Stuart M. Kearney, Benedikt P. Laner, et al. ChemBERTa-2: Towards a Foundation Model for Chemical ML. arXiv preprint arXiv:2209.01712, 2023.

  • C. Shen, et al. A generalized protein–ligand scoring framework with balanced scoring, docking, ranking and screening powers. In Chemical Science, volume 14, issue 30, pages 8129–8146. Royal Society of Chemistry, 2023.

  • Marta M. Stepniewska-Dziubinska, P. Zielenkiewicz, and P. Siedlecki. Development and evaluation of a deep learning model for protein–ligand binding affinity prediction. In Bioinformatics, volume 34, issue 21, pages 3666–3674. Oxford University Press, 2018.

  • Zhen Wang, Zheng Feng, Yanjun Li, Bowen Li, Yongrui Wang, Chulin Sha, Min He, and Xiaolin Li. Batmannet: bi-branch masked graph transformer autoencoder for molecular representation. Briefings in Bioinformatics, 25(1):bbad400, 11 2023. ISSN 1477–4054. doi:10.1093/bib/bbad400. URL https://doi.org/10.1093/bib/bbad400.

微信群

内容中包含的图片若涉及版权问题,请及时与我们联系删除