使用 PyTorch Geometric 和 GCTConv实现异构图、二部图上的节点分类或者链路预测

news/2024/5/24 12:52:12/

解决问题描述

使用 PyTorch Geometric 和 Heterogeneous Graph Transformer 实现异构图上的节点分类
在二部图上应用GTN算法(使用torch_geometric的库HGTConv);

步骤解释

  1. 导入所需的 PyTorch 和 PyTorch Geometric 库。

  2. 定义 x1 和 x2 两种不同类型节点的特征,分别有 1000 个和 500 个节点,每个节点有两维特征。
    随机生成两种边 e1 和 e2 的索引(edge index)和权重(edge weight),其中 e1 从 n1 到 n2,e2 从 n2 到 n1。

  3. 定义异构图的元数据字典 meta_dict,其中 ‘n1’ 和 ‘n2’ 分别表示两种节点类型,而 (‘n1’, ‘e1’, ‘n2’) 表示从类型 ‘n1’ 的节点到类型 ‘n2’ 的节点有一条边,这条边的索引和权重分别为 edge_index_e1 和 edge_weight_e1。

  4. 利用元数据字典 meta_dict 创建异构图数据对象 data,并将节点特征和边索引添加到该对象中。

  5. 定义异构元数据列表 meta_list,其中包含所有节点类型和边类型的名称信息。

  6. 定义 HGTConv 层,并指定输入通道数、输出通道数、异构元数据列表以及头数等超参数。

  7. 将节点特征和边索引转换为字典形式,并利用 HGTConv

  8. 应用 HGTConv 到输入数据,得到输出结果 output_dict,其中包含了处理后的节点特征。最后打印输出 n1 和 n2 节点的输出形状。

详细代码

以下代码可以直接运行

import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import HGTConv# 定义节点特征
x1 = torch.randn(1000, 2)
x2 = torch.randn(500, 2)# 定义边索引(edge index)以及边权重(edge weight)
edge_index_e1 = torch.cat((torch.randint(0, 1000, size=(1, 4000)),torch.randint(0, 500, size=(1, 4000))),dim=0)
edge_weight_e1 = torch.rand(4000)
edge_index_e2=torch.flip(edge_index_e1, (0,))# 定义元数据字典,描述异构图的结构
meta_dict = {'n1': {'num_nodes': x1.shape[0], 'num_features': x1.shape[1]},'n2': {'num_nodes': x2.shape[0], 'num_features': x2.shape[1]},('n1', 'e1', 'n2'): {'edge_index': edge_index_e1, 'edge_weight': edge_weight_e1},
}# 创建异构图数据对象
data = HeteroData(meta_dict)# 将节点特征和边索引添加到异构图对象中
data['n1'].x = x1
data['n2'].x = x2
data[('n1', 'e1', 'n2')].edge_index = edge_index_e1
data[('n2', 'e1', 'n1')].edge_index = edge_index_e2# 定义异构元数据列表
meta_list= (['n1', 'n2'], [('n1', 'e1', 'n2'), ('n2', 'e1', 'n1')])# 定义 HGTConv 层
in_channels = {'n1': x1.shape[1],'n2': x2.shape[1],
}
out_channels = 16
heads = 4
conv = HGTConv(in_channels=in_channels, out_channels=out_channels, metadata=meta_list,heads=heads)# 将输入数据转换为字典形式
x_dict = {ntype: data[ntype].x for ntype in data.node_types}
edge_index_dict = {}
for etype in data.edge_types:edge_index_dict[etype] = data[etype].edge_index# 应用 HGTConv 到输入数据
output_dict = conv(x_dict, edge_index_dict)
print(output_dict['n1'].shape)
print(output_dict['n2'].shape)

之后如果是节点分类则:

output_dict的n1,n2特征编码分别接全连接层对应y1,y2

之后如果是链路预测则:

output_dict的n1,n2特征编码按照链路进行合并,进而预测

一些细节

data = HeteroData(meta_dict) 创建异构图对象
edge_index_e2=torch.flip(edge_index_e1, (0,)) 创建逆向的边,由于是二部图无向图所以需要

目录

    • 解决问题描述
    • 步骤解释
    • 详细代码
    • 一些细节


http://www.ppmy.cn/news/47651.html

相关文章

如何在 TensorFlow 中使用 GPU 加速深度学习计算?

一、前言 TensorFlow 是由 Google 开源的深度学习框架,它具有易用、高效、灵活等特点,被广泛应用于学术界和工业界中。而 GPU 是一种高性能的计算设备,可以加速深度学习的计算过程。本文将介绍如何在 TensorFlow 中使用 GPU 加速深度学习计算。 二、安装 TensorFlow 安装…

Python语言中的注释方法应用

Python语言中的注释方法 在Python编程中,与其他编程语言一样,有良好的注释部分,会让你的程序在后续的改进或优化中,变得便利。同时,给自己培养了良好的编程习惯。 在Python语言中,有两种注释方法。 1.单行…

DAY 43 Apache的配置与应用

虚拟Web主机 概述 虚拟web主机指的是在同一台服务器中运行多个web站点,其中每一个站点实际上并不独立占用整个服务器,因此被称为"虚拟"web主机。通过虚拟web主机服务可以充分利用服务器的硬件资源,从而大大降低网站构建及运行成本…

API 接口主流协议有哪些? 如何创建不同协议?

API 接口协议繁多,不同的协议有着不同的使用场景。70% 互联网应用开发者日常仅会接触到最通用的 HTTP 协议,相信大家希望了解更多其他协议的信息。我们今天会给大家介绍各种 API 接口主流协议和他们之间的关系。 1、API 接口主流协议有哪些? 接口协议分…

理解websocket连接的原理

背景 Websocket是一个持久化的协议,相对于HTTP这种非持久的无状态协议来说 一、问题 http long poll,或者ajax轮询都可以实现实时信息传递,为什么还需要websocket? 二、理解 ajax轮询:浏览器隔个几秒就发送一次请求&am…

json for modern c++

目录 json for modern c概述编译问题问题描述问题解决 读取JSON文件demo json for modern c GitHub - nlohmann/json: JSON for Modern C 概述 json for modern c是一个德国大牛nlohmann写的,该版本的json有以下特点: 1.直观的语法。 2.整个代码由一个…

Spring项目创建与 Spring Bean 的存储与读取

目录 一、创建Spring项目 1.1 创建Maven项目 1.2 添加 Spring 框架依赖 1.3 添加启动类 二、Bean对象的创建与存储 2.1 创建Bean 2.2 将Bean注册到容器 2.3 获取并使用Bean对象 2.3.1 创建Spring上下文 2.3.2 从Spring容器中获取Bean对象​编辑 延申(多种…

政企数智办公巡展回顾 | 通信赋能传统行业数智化转型的应用实践

在宏观政策引导、技术革新与企业内部数字化改革需求的共同驱使下,数智办公已经成为各行各业转型升级的必由之路。关注【融云 RongCloud】,了解协同办公平台更多干货。 近期,“连接无界 智赋未来” 融云 2023 政企数智办公巡展在北京、杭州相…

X进制转十进制黄金万能算法

单纯、混合进制通吃,真正的黄金万能的进制转换方法。 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣经”教程《 python 完全自学教程》,不仅仅是基础那么简单…… 地址:https:/…

Qt音视频开发27-ffmpeg视频旋转显示

一、前言 用手机或者平板拍摄的视频文件,很可能是旋转的,比如分辨率是1280x720,确是垂直的,相当于分辨率变成了720x1280,如果不做旋转处理的话,那脑袋必须歪着看才行,这样看起来太难受,所以一定要想办法解析到视频的旋转角度,然后根据这个角度重新绘制。在窗体那边也…

匿名管道与命名管道

匿名管道与命名管道 一,进程间通信什么是进程间通信进程间通信的目的管道的概念 二,匿名管道匿名管道的创建匿名管道使用匿名管道的特性以及四种场景匿名管道的原理通过匿名管道实现简易进程池。 三,命名管道命名管道的创建命名管道的使用命名…

应急响应 - Windows启动项分析,Windows计划任务分析,Windows服务分析

「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:对网络安全感兴趣的小伙伴可以关注专栏《网络安全入门到精通》 Windows应急响应 一、启动项分析1、msconfig2、gpedit.msc3、注册表4、msinfo325、启动菜…

深入拆解 Java 虚拟机-打卡|01 | Java代码是怎么运行的?

文章目录 Java代码是怎么运行的?几个为什么为什么在虚拟机中运行?Java 虚拟机具体又是怎样运行 Java 代码的呢?Java虚拟机的运行效率怎么样? 总结 Java代码是怎么运行的? 来来来,运行个"Hello word !“告诉我是…

C#内建接口:IComparable

目录 一、介绍 二、示例 注意:Array.Sort(people);调用了CompareTo方法 注意:WriteLine输出会调用ToString 三、笔试题实战 一、介绍 IComparable是一个接口,它定义了一个用于比较对象的方法CompareTo。在C#中,IComparable接…

第31天-贪心-第八章 ● 122.买卖股票的最佳时机II ● 55. 跳跃游戏 ● 45.跳跃游戏II

文章目录 1. 买卖股票的最佳时机2. 跳跃游戏3. 跳跃游戏 || 1. 买卖股票的最佳时机 - LeetCode链接 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股…

PHP快速入门09-正则相关,附一定要学会的20个高频使用案例

文章目录 前言一、正则表达式介绍二、正则高频案例20个2.1 检查字符串是否以字母开头2.2 检查字符串是否以数字开头2.3 检查字符串是否包含特定字符2.4 检查字符串是否以特定字符结尾2.5 检查字符串是否为纯数字2.6 检查字符串是否为纯字母2.7 检查字符串是否为有效的电子邮件地…

C/C++每日一练(20230417)

目录 1. 字母异位词分组 🌟🌟 2. 计算右侧小于当前元素的个数 🌟🌟🌟 3. 加一 🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 J…

前端开发中性能优化之较少http请求(缓存策略)

1.实现减少http请求逻辑如下 定义了一个fetchData函数,用于发起HTTP请求并返回响应结果。函数的实现逻辑如下: 将请求参数对象params转换为字符串,作为缓存对象的键cacheKey。 如果缓存对象中已经有该请求参数对应的结果,直接返回…

实在智能获评十大数字经济风云企业,2022余杭数字经济“群英榜”发布

4月17日,经专家评审、公开投票,由中共杭州市余杭区委组织部(区委两新工委)、中共杭州市余杭区经济和信息化局委员会主办评选的2022年度余杭区数字经济“群英榜”正式公示。其中,实在智能成功获评十大数字经济风云企业之…

Linux 操作系统中应该掌握的知识

下面是我从业整理的一部分需要掌握的内容: 1. 基本命令行操作 基本命令行操作:包括文件管理、进程管理、用户权限等方面的基本命令行操作。 下面是文件管理、进程管理和用户权限相关的一些命令和内容: 1.1 文件管理 ls:显示当…