GNN图神经网络的快速构建及训练(TensorFlow)


图神经网络(简称 GNN)已经成为一种强大的技术,可以利用图的连接性(如旧算法 DeepWalk 和 Node2Vec)以及各个节点和边上的输入特征。
GNN
可以对整个图(该分子是否以某种方式反应?)、单个节点(本文档的主题是什么,给出其引文?)或潜在边(该产品是否可能一起购买?)。除了对图进行预测之外,GNN
也是一种强大的工具,用于弥合与更典型的神经网络用例之间的鸿沟。它们以连续的方式对图的离散关系信息进行编码,以便它可以自然地包含在另一个深度学习系统中。
近期,谷歌发布
TensorFlow GNN 1.0 (TF-GNN),这是一个经过生产测试的库,用于大规模构建 GNN。它支持 TensorFlow
中的建模和训练,以及从庞大的数据存储中提取输入图。TF-GNN
是从头开始构建的异构图,其中类型和关系由不同的节点和边集表示。现实世界的对象及其关系以不同的类型出现,而 TF-GNN
的异构焦点使得表征它们变得很自然。
pip install tensorflow_gnn
在 TensorFlow 内部,此类图由 tfgnn.GraphTensor 类型的对象表示。这是一种复合张量类型(一个 Python 类中张量的集合),在 tf.data.Dataset 、 tf.function 等中被接受为一等公民。它存储图结构及其附加到节点、边和整个图的特征。GraphTensors 的可训练转换可以定义为高级 Keras API 中的 Layers 对象,或者直接使用 tfgnn.GraphTensor 原语。
TF-GNN:对上下文中的对象进行预测
为了便于说明,让我们看一下 TF-GNN 的一个典型应用:预测由大型数据库的交叉引用表定义的图中某种类型节点的属性。例如,计算机科学 (CS) arXiv 论文的引文数据库具有一对多引用和多对一引用关系,我们希望在其中预测每篇论文的主题领域。
与大多数神经网络一样,GNN
在包含许多标记示例(约数百万个)的数据集上进行训练,但每个训练步骤仅包含一小批训练示例(例如数百个)。为了扩展到数百万级,GNN
在底层图中的相当小的子图流上进行训练。每个子图都包含足够的原始数据来计算其中心标记节点的 GNN
结果并训练模型。这个过程(通常称为子图采样)对于 GNN 训练极其重要。大多数现有工具以批量方式完成采样,生成用于训练的静态子图。TF-GNN
提供了通过动态和交互采样来改进这一点的工具。

如图所示,子图采样的过程,其中从较大的图中采样小的、易于处理的子图,以创建用于 GNN 训练的输入示例。

TF-GNN
1.0 首次推出了灵活的 Python API,用于在所有相关规模上配置动态或批量子图采样:在 Colab
笔记本(如本例)中交互式地对存储在单个训练主机主内存中的小数据集进行高效采样,或由 Apache Beam
分发,用于存储在网络文件系统上的巨大数据集(多达数亿个节点和数十亿条边)。有关详细信息,请分别参阅我们的内存采样和基于波束采样的用户指南。
在这些相同的采样子图上,GNN
的任务是计算根节点处的隐藏(或潜在)状态;隐藏状态聚合并编码根节点邻域的相关信息。一种经典方法是消息传递神经网络。在每轮消息传递中,节点沿着传入边缘接收来自邻居的消息,并从中更新自己的隐藏状态。经过
n 轮后,根节点的隐藏状态反映了 n 个边内所有节点的聚合信息(如下图,n =
2)。消息和新的隐藏状态由神经网络的隐藏层计算。在异构图中,对于不同类型的节点和边使用单独训练的隐藏层通常是有意义的。
如图所示,这是一个简单的消息传递神经网络,在每个步骤中,节点状态从外部节点传播到内部节点,并在内部节点中进行池化以计算新的节点状态。一旦到达根节点,就可以做出最终的预测。

训练设置是通过将输出层放置在标记节点的 GNN 隐藏状态之上、计算损失(以测量预测误差)并通过反向传播更新模型权重来完成的,就像在任何神经网络训练中一样。
除了监督训练(即最小化由标签定义的损失)之外,GNN
还可以以无监督方式进行训练(即没有标签)。这让我们可以计算节点及其特征的离散图结构的连续表示(或嵌入)。这些表示通常用于其他机器学习系统。通过这种方式,由图编码的离散关系信息可以包含在更典型的神经网络用例中。TF-GNN
支持异构图无监督目标的细粒度规范。
构建 GNN 架构
TF-GNN 库支持在不同抽象级别构建和训练 GNN。
在最高级别,用户可以采用与库捆绑在一起的任何以
Keras 层表示的预定义模型。除了研究文献中的一小部分模型之外,TF-GNN
还提供了一个高度可配置的模型模板,该模板提供了精心挑选的建模选择,我们发现这些选择可以为我们的许多内部问题提供强有力的基线。模板实现了 GNN
层;用户只需要初始化 Keras 层。
import tensorflow_gnn as tfgnnfrom tensorflow_gnn.models import mt_albisdef model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec): """Builds a GNN as a Keras model.""" graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec) # Encode input features (callback omitted for brevity). graph = tfgnn.keras.layers.MapFeatures( node_sets_fn=set_initial_node_states)(graph) # For each round of message passing... for _ in range(2): # ... create and apply a Keras layer. graph = mt_albis.MtAlbisGraphUpdate( units=128, message_dim=64, attention_type="none", simple_conv_reduce_type="mean", normalization_type="layer", next_state_type="residual", state_dropout_rate=0.2, l2_regularization=1e-5, )(graph) return tf.keras.Model(inputs, graph)
在最低级别,用户可以根据在图周围传递数据的原语从头开始编写
GNN 模型,例如将数据从节点广播到其所有传出边,或将数据从其所有传入边汇集到节点中(例如,计算传入消息的总和)。TF-GNN
的图数据模型在涉及特征或隐藏状态时平等对待节点、边和整个输入图,从而不仅可以直接表达以节点为中心的模型(如上面讨论的
MPNN),还可以表达更通用的 GraphNet 形式。这可以(但不一定)使用 Keras 作为核心 TensorFlow
之上的建模框架来完成。有关更多详细信息和中级建模,请参阅 TF-GNN 用户指南和模型集合。
 模型训练
虽然高级用户可以自由地进行自定义模型训练,但 TF-GNN Runner 还提供了一种简洁的方法来协调常见情况下 Keras 模型的训练。一个简单的调用可能如下所示:
from tensorflow_gnn import runnerrunner.run( task=runner.RootNodeBinaryClassification("papers", ...), model_fn=model_fn, trainer=runner.KerasTrainer(tf.distribute.MirroredStrategy(), model_dir="/tmp/model"), optimizer_fn=tf.keras.optimizers.Adam, epochs=10, global_batch_size=128, train_ds_provider=runner.TFRecordDatasetProvider("/tmp/train*"), valid_ds_provider=runner.TFRecordDatasetProvider("/tmp/validation*"), gtspec=...,)
Runner 为 ML 难题提供了即用型解决方案,例如分布式训练和云 TPU 上固定形状的 tfgnn.GraphTensor 填充。除了对单个任务进行训练(如上所示)之外,它还支持对多个(两个或更多)任务进行联合训练。例如,无监督任务可以与有监督任务混合,以告知具有应用程序特定归纳偏差的最终连续表示(或嵌入)。调用者只需用任务映射替换任务参数:
from tensorflow_gnn import runnerfrom tensorflow_gnn.models import contrastive_lossesrunner.run( task={ "classification": runner.RootNodeBinaryClassification("papers", ...), "dgi": contrastive_losses.DeepGraphInfomaxTask("papers"), }, ...)
此外,TF-GNN
Runner 还包括用于模型归因的集成梯度的实现。集成梯度输出是一个 GraphTensor,与观察到的 GraphTensor
具有相同的连接性,但其特征被梯度值取代,其中较大的值比较小的值在 GNN 预测中贡献更多。用户可以检查梯度值,以了解其 GNN 使用最多的特征。
到顶部