图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

IT技术2年前 (2022)发布 IT大王
0

1.ERNIESage运行实例介绍(1.8x版本)

本项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?contributionType=1

本项目主要是为了直接提供一个可以运行ERNIESage模型的环境,

https://github.com/PaddlePaddle/PGL/blob/develop/examples/erniesage/README.md

在很多工业应用中,往往出现如下图所示的一种特殊的图:Text Graph。顾名思义,图的节点属性由文本构成,而边的构建提供了结构信息。如搜索场景下的Text Graph,节点可由搜索词、网页标题、网页正文来表达,用户反馈和超链信息则可构成边关系。

图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

ERNIESage 由PGL团队提出,是ERNIE SAmple aggreGatE的简称,该模型可以同时建模文本语义与图结构信息,有效提升 Text Graph 的应用效果。其中 ERNIE 是百度推出的基于知识增强的持续学习语义理解框架。

ERNIESage 是 ERNIE 与 GraphSAGE 碰撞的结果,是 ERNIE SAmple aggreGatE 的简称,它的结构如下图所示,主要思想是通过 ERNIE 作为聚合函数(Aggregators),建模自身节点和邻居节点的语义与结构关系。ERNIESage 对于文本的建模是构建在邻居聚合的阶段,中心节点文本会与所有邻居节点文本进行拼接;然后通过预训练的 ERNIE 模型进行消息汇聚,捕捉中心节点以及邻居节点之间的相互关系;最后使用 ERNIESage 搭配独特的邻居互相看不见的 Attention Mask 和独立的 Position Embedding 体系,就可以轻松构建 TextGraph 中句子之间以及词之间的关系。

图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

使用ID特征的GraphSAGE只能够建模图的结构信息,而单独的ERNIE只能处理文本信息。通过PGL搭建的图与文本的桥梁,ERNIESage能够很简单的把GraphSAGE以及ERNIE的优点结合一起。以下面TextGraph的场景,ERNIESage的效果能够比单独的ERNIE以及GraphSAGE模型都要好。

图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

ERNIESage可以很轻松地在PGL中的消息传递范式中进行实现,目前PGL在github上提供了3个版本的ERNIESage模型:

  • ERNIESage v1: ERNIE 作用于text graph节点上;
  • ERNIESage v2: ERNIE 作用在text graph的边上;
  • ERNIESage v3: ERNIE 作用于一阶邻居及起边上;

主要会针对ERNIESageV1和ERNIESageV2版本进行一个介绍。

1.1算法实现

可能有同学对于整个项目代码文件都不太了解,因此这里会做一个比较简单的讲解。

核心部分包含:

  • 数据集部分
  1. data.txt – 简单的输入文件,格式为每行query \t answer,可作简单的运行实例使用。
  • 模型文件和配置部分
  1. ernie_config.json – ERNIE模型的配置文件。
  2. vocab.txt – ERNIE模型所使用的词表。
  3. ernie_base_ckpt/ – ERNIE模型参数。
  4. config/ – ERNIESage模型的配置文件,包含了三个版本的配置文件。
  • 代码部分
  1. local_run.sh – 入口文件,通过该入口可完成预处理、训练、infer三个步骤。
  2. preprocessing文件夹 – 包含dump_graph.py, tokenization.py。在预处理部分,我们首先需要进行建图,将输入的文件构建成一张图。由于我们所研究的是Text Graph,因此节点都是文本,我们将文本表示为该节点对应的node feature(节点特征),处理文本的时候需要进行切字,再映射为对应的token id。
  3. dataset/ – 该文件夹包含了数据ready的代码,以便于我们在训练的时候将训练数据以batch的方式读入。
  4. models/ – 包含了ERNIESage模型核心代码。
  5. train.py – 模型训练入口文件。
  6. learner.py – 分布式训练代码,通过train.py调用。
  7. infer.py – infer代码,用于infer出节点对应的embedding。
  • 评价部分
  1. build_dev.py – 用于将我们的验证集修改为需要的格式。
  2. mrr.py – 计算MRR值。

要在这个项目中运行模型其实很简单,只要运行下方的入口命令就ok啦!但是,需要注意的是,由于ERNIESage模型比较大,所以如果AIStudio中的CPU版本运行模型容易出问题。因此,在运行部署环境时,建议选择GPU的环境。

另外,如果提示出现了GPU空间不足等问题,我们可以通过调小对应yaml文件中的batch_size来调整,也可以修改ERNIE模型的配置文件ernie_config.json,将num_hidden_layers设小一些。在这里,我仅提供了ERNIESageV2版本的gpu运行过程,如果同学们想运行其他版本的模型,可以根据需要修改下方的命令。

运行完毕后,会产生较多的文件,这里进行简单的解释。

  1. workdir/ – 这个文件夹主要会存储和图相关的数据信息。
  2. output/ – 主要的输出文件夹,包含了以下内容:(1)模型文件,根据config文件中的save_per_step可调整保存模型的频率,如果设置得比较大则可能训练过程中不会保存模型; (2)last文件夹,保存了停止训练时的模型参数,在infer阶段我们会使用这部分模型参数;(3)part-0文件,infer之后的输入文件中所有节点的Embedding输出。

为了可以比较清楚地知道Embedding的效果,我们直接通过MRR简单判断一下data.txt计算出来的Embedding结果,此处将data.txt同时作为训练集和验证集。

1.2 核心模型代码讲解

首先,我们可以通过查看models/model_factory.py来判断在本项目有多少种ERNIESage模型。

from models.base import BaseGNNModel
from models.ernie import ErnieModel
from models.erniesage_v1 import ErnieSageModelV1
from models.erniesage_v2 import ErnieSageModelV2
from models.erniesage_v3 import ErnieSageModelV3
class Model(object):
    @classmethod
    def factory(cls, config):
        name = config.model_type
        if name == "BaseGNNModel":
            return BaseGNNModel(config)
        if name == "ErnieModel":
            return ErnieModel(config)
        if name == "ErnieSageModelV1":
            return ErnieSageModelV1(config)
        if name == "ErnieSageModelV2":
            return ErnieSageModelV2(config)
        if name == "ErnieSageModelV3":
            return ErnieSageModelV3(config)
        else:
            raise ValueError

可以看到一共有ERNIESage模型一共有3个版本,另外我们也提供了基本的GNN模型和ERNIE模型,感兴趣的同学可以自行查阅。

接下来,我主要会针对ERNIESageV1和ERNIESageV2这两个版本的模型进行关键部分的讲解,主要的不同其实就是消息传递机制(Message Passing)部分的不同。

1.2.1 ERNIESageV1关键代码

# ERNIESageV1的Message Passing代码
# 查找路径:erniesage_v1.py(__call__中的self.gnn_layers) -> base.py(BaseNet类中的gnn_layers方法) -> message_passing.py
# erniesage_v1.py
def __call__(self, graph_wrappers):
    inputs = self.build_inputs()
    feature = self.build_embedding(graph_wrappers, inputs[-1])  # 将节点的文本信息利用ERNIE模型建模,生成对应的Embedding作为feature
    features = self.gnn_layers(graph_wrappers, feature)  # GNN模型的主要不同,消息传递机制入口
    outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
    src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
    outputs.append(src_real_index)
    return inputs, outputs
# base.py -> BaseNet
def gnn_layers(self, graph_wrappers, feature):
    features = [feature]
    initializer = None
    fc_lr = self.config.lr / 0.001
    for i in range(self.config.num_layers):
        if i == self.config.num_layers - 1:
            act = None
        else:
            act = "leaky_relu"
        feature = get_layer(  
            self.config.layer_type, # 对于ERNIESageV1, 其layer_type="graphsage_sum",可以到config文件夹中查看
            graph_wrappers[i],
            feature,
            self.config.hidden_size,
            act,
            initializer,
            learning_rate=fc_lr,
            name="%s_%s" % (self.config.layer_type, i))
        features.append(feature)
    return features
# message_passing.py
def graphsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
    """doc"""
    msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
    neigh_feature = gw.recv(msg, sum_recv)                # Recv
    self_feature = feature
    self_feature = fluid.layers.fc(self_feature,
                                   hidden_size,
                                   act=act,
                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0", initializer=initializer,
                                   learning_rate=learning_rate),
                                    bias_attr=name+"_l.b_0"
                                   )
    neigh_feature = fluid.layers.fc(neigh_feature,
                                    hidden_size,
                                    act=act,
                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0", initializer=initializer,
                                   learning_rate=learning_rate),
                                    bias_attr=name+"_r.b_0"
                                    )
    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
    output = fluid.layers.l2_normalize(output, axis=1)
    return output

通过上述代码片段可以看到,关键的消息传递机制代码就是graphsage_sum函数,其中send、recv部分如下。

def copy_send(src_feat, dst_feat, edge_feat):
    """doc"""
    return src_feat["h"]
msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv)                # Recv

通过代码可以看到,ERNIESageV1版本,其主要是针对节点邻居,直接将当前节点的邻居节点特征求和。再看到graphsage_sum函数中,将邻居节点特征进行求和后,得到了neigh_feature。随后,我们将节点本身的特征self_feature和邻居聚合特征neigh_feature通过fc层后,直接concat起来,从而得到了当前gnn layer层的feature输出。

1.2.2ERNIESageV2关键代码

ERNIESageV2的消息传递机制代码主要在erniesage_v2.py和message_passing.py,相对ERNIESageV1来说,代码会相对长了一些。

为了使得大家对下面有关ERNIE模型的部分能够有所了解,这里先贴出ERNIE的主模型框架图。
图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

具体的代码解释可以直接看注释。

# ERNIESageV2的Message Passing代码
# 下面的函数都在erniesage_v2.py的ERNIESageV2类中
# ERNIESageV2的调用函数
def __call__(self, graph_wrappers):
    inputs = self.build_inputs()
    feature = inputs[-1]
    features = self.gnn_layers(graph_wrappers, feature) 
    outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
    src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
    outputs.append(src_real_index)
    return inputs, outputs
# 进入self.gnn_layers函数
def gnn_layers(self, graph_wrappers, feature):
    features = [feature]
    initializer = None
    fc_lr = self.config.lr / 0.001
    for i in range(self.config.num_layers):
        if i == self.config.num_layers - 1:
            act = None
        else:
            act = "leaky_relu"
        feature = self.gnn_layer(
            graph_wrappers[i],
            feature,
            self.config.hidden_size,
            act,
            initializer,
            learning_rate=fc_lr,
            name="%s_%s" % ("erniesage_v2", i))
        features.append(feature)
    return features
接下来会进入ERNIESageV2主要的代码部分。
可以看到,在ernie_send函数用于将我们的邻居信息发送到当前节点。在ERNIESageV1中,我们在Send阶段对邻居节点通过ERNIE模型得到Embedding后,再直接求和,实际上当前节点和邻居节点之间的文本信息在消息传递过程中是没有直接交互的,直到最后才**concat**起来;而ERNIESageV2中,在Send阶段,源节点和目标节点的信息会直接concat起来,通过ERNIE模型得到一个统一的Embedding,这样就得到了源节点和目标节点的一个信息交互过程,这个部分可以查看下面的ernie_send函数。
gnn_layer函数中包含了三个函数:
1. ernie_send: 将src和dst节点对应文本concat后,过Ernie后得到需要的msg,更加具体的解释可以看下方代码注释。
2. build_position_ids: 主要是为了创建位置ID,提供给Ernie,从而可以产生position embeddings。
3. erniesage_v2_aggregator: gnn_layer的入口函数,包含了消息传递机制,以及聚合后的消息feature处理过程。
# 进入self.gnn_layer函数
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
    def build_position_ids(src_ids, dst_ids): # 此函数用于创建位置ID,可以对应到ERNIE框架图中的Position Embeddings
        # ...
        pass
    def ernie_send(src_feat, dst_feat, edge_feat): 
        """doc"""
        # input_ids,可以对应到ERNIE框架图中的Token Embeddings
        cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
        src_ids = L.concat([cls, src_feat["term_ids"]], 1)
        dst_ids = dst_feat["term_ids"]
        term_ids = L.concat([src_ids, dst_ids], 1)
        # sent_ids,可以对应到ERNIE框架图中的Segment Embeddings
        sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
        # position_ids,可以对应到ERNIE框架图中的Position Embeddings
        position_ids = build_position_ids(src_ids, dst_ids)
        term_ids.stop_gradient = True
        sent_ids.stop_gradient = True
        ernie = ErnieModel( # ERNIE模型
            term_ids, sent_ids, position_ids,
            config=self.config.ernie_config)
        feature = ernie.get_pooled_output() # 得到发送过来的msg,该msg是由src节点和dst节点的文本特征一起过ERNIE后得到的embedding
        return feature
    def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name):
        feature = L.unsqueeze(feature, [-1])
        msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)]) # Send
        neigh_feature = gw.recv(msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum")) # Recv,直接将发送来的msg根据dst节点来相加。
        # 接下来的部分和ERNIESageV1类似,将self_feature和neigh_feature通过concat、normalize后得到需要的输出。
        term_ids = feature
        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1], "int64", 1)
        term_ids = L.concat([cls, term_ids], 1)
        term_ids.stop_gradient = True
        ernie = ErnieModel(
            term_ids, L.zeros_like(term_ids),
            config=self.config.ernie_config)
        self_feature = ernie.get_pooled_output()
        self_feature = L.fc(self_feature,
                                        hidden_size,
                                        act=act,
                                        param_attr=F.ParamAttr(name=name + "_l.w_0",
                                        learning_rate=learning_rate),
                                        bias_attr=name+"_l.b_0"
                                        )
        neigh_feature = L.fc(neigh_feature,
                                        hidden_size,
                                        act=act,
                                        param_attr=F.ParamAttr(name=name + "_r.w_0",
                                        learning_rate=learning_rate),
                                        bias_attr=name+"_r.b_0"
                                        )
        output = L.concat([self_feature, neigh_feature], axis=1)
        output = L.l2_normalize(output, axis=1)
        return output
    return erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name)

2.总结

通过以上两个版本的模型代码简单的讲解,我们可以知道他们的不同点,其实主要就是在消息传递机制的部分有所不同。ERNIESageV1版本只作用在text graph的节点上,在传递消息(Send阶段)时只考虑了邻居本身的文本信息;而ERNIESageV2版本则作用在了边上,在Send阶段同时考虑了当前节点和其邻居节点的文本信息,达到更好的交互效果。

© 版权声明
好牛新坐标 广告
版权声明:
1、IT大王遵守相关法律法规,由于本站资源全部来源于网络程序/投稿,故资源量太大无法一一准确核实资源侵权的真实性;
2、出于传递信息之目的,故IT大王可能会误刊发损害或影响您的合法权益,请您积极与我们联系处理(所有内容不代表本站观点与立场);
3、因时间、精力有限,我们无法一一核实每一条消息的真实性,但我们会在发布之前尽最大努力来核实这些信息;
4、无论出于何种目的要求本站删除内容,您均需要提供根据国家版权局发布的示范格式
《要求删除或断开链接侵权网络内容的通知》:https://itdw.cn/ziliao/sfgs.pdf,
国家知识产权局《要求删除或断开链接侵权网络内容的通知》填写说明: http://www.ncac.gov.cn/chinacopyright/contents/12227/342400.shtml
未按照国家知识产权局格式通知一律不予处理;请按照此通知格式填写发至本站的邮箱 wl6@163.com

相关文章