0%

pytorch GCN 实现以及使用

从最为广泛和简单的理解来看, 在图上利用多个个节点数据进行计算的操作都可以称之为图网络

一张图由数个节点组成, 各个节点之间存在单向或双向的边。图只在概念上存在,实际上一张图由节点和邻接矩阵共同表示。

节点

图中的一个概念上的点, 每个点都携带着一定量的数据。例如以下数据结构, 其中每一行代表一个点, 每一行中的数字代表该点所携带的数据。

1
2
3
4
5
# 4点, 每点包含一个5维的数据
[0.4252, 0.2733, 0.5442, 0.7236, 0.0515]
[0.5121, 0.2056, 0.8560, 0.3010, 0.3110]
[0.0684, 0.5282, 0.8454, 0.0913, 0.9803]
[0.4211, 0.5779, 0.2952, 0.3368, 0.8389]

邻接矩阵

邻接矩阵是一个 n * n 的矩阵, n 表示节点数目, 每一行都表示一个节点和其他节点是否相连。 对于一个无向图, 他的邻接矩阵总是对称的。例如以下数据结构。

1
2
3
4
5
# 4点邻接矩阵
[1., 0., 1., 0.]
[0., 1., 1., 0.]
[1., 1., 1., 0.]
[0., 0., 0., 1.]

我们观察第一行

1
[1., 0., 1., 0.]

表示节点0, 和节点0/2相连。

GCN 定义

如何实现一个GCN模块

注意网络输入尺寸为b*(n+c)*n, 前n个channel构成的b*n*n表示邻接矩阵。后b*c*n表示节点数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn
class GraphConvolution(nn.Module):

def __init__(self, node_num, input_feature_num, output_feature_num, add_bias=True, dtype=torch.float,
batch_normal=True):
super().__init__()
# shapes
self.graph_num = node_num
self.input_feature_num = input_feature_num
self.output_feature_num = output_feature_num
self.add_bias = add_bias
self.batch_normal = batch_normal

# params
self.weight = nn.Parameter(torch.empty(self.output_feature_num, input_feature_num, dtype=dtype))
self.bias = nn.Parameter(torch.empty(self.output_feature_num, self.graph_num, dtype=dtype))
if batch_normal:
self.norm = nn.InstanceNorm1d(node_num)

def set_trainable(self, train=True):
for param in self.parameters():
param.requires_grad = train

def forward(self, inp: torch.Tensor):
"""
@param inp : adjacent: (batch, graph_num, graph_num) cat node_feature: (batch, graph_num, in_feature_num) -> (batch, graph_num, graph_num + in_feature_num)
@return:
"""
b, c, n = inp.shape
adjacent, node_feature = inp[:, 0:n, :], inp[:, n:, :]
x = torch.matmul(self.weight, node_feature)
x = torch.matmul(x, adjacent)
if self.add_bias:
x = x + self.bias
if self.batch_normal:
x = self.norm(x)

return torch.cat((adjacent, x), dim=1)

残差GCN

GCN存在的问题

训练中需要注意的

  1. 在训练GCN时必须谨慎使用Norm方法, 否则很可能造成网络不收敛
  2. GCN梯度退化较为严重, 尽量使用残差结构