import torch
from torch import nn

class GraphConvolution(nn.Module):
    """Graph convolution layer computing x' = norm(W @ x @ A + b).
    The adjacency A is carried through input and output so layers can be stacked."""

    def __init__(self, node_num, input_feature_num, output_feature_num, add_bias=True, dtype=torch.float,
                 batch_normal=True):
        super().__init__()
        # shapes
        self.node_num = node_num
        self.input_feature_num = input_feature_num
        self.output_feature_num = output_feature_num
        self.add_bias = add_bias
        self.batch_normal = batch_normal

        # params: weight maps input features to output features,
        # bias is one value per (output feature, node) pair
        self.weight = nn.Parameter(torch.empty(output_feature_num, input_feature_num, dtype=dtype))
        if add_bias:
            self.bias = nn.Parameter(torch.empty(output_feature_num, node_num, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        if batch_normal:
            # per-sample normalization over the node dimension of each output channel
            self.norm = nn.InstanceNorm1d(output_feature_num)

    def reset_parameters(self):
        # torch.empty leaves tensors uninitialized; give them defined starting values
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def set_trainable(self, train=True):
        for param in self.parameters():
            param.requires_grad = train

    def forward(self, inp: torch.Tensor):
        """
        @param inp: adjacency (batch, node_num, node_num) concatenated along dim 1
                    with node features (batch, input_feature_num, node_num),
                    i.e. (batch, node_num + input_feature_num, node_num)
        @return: adjacency concatenated with the convolved features,
                 (batch, node_num + output_feature_num, node_num)
        """
        b, c, n = inp.shape  # n == node_num
        adjacent, node_feature = inp[:, 0:n, :], inp[:, n:, :]
        x = torch.matmul(self.weight, node_feature)  # (batch, output_feature_num, node_num)
        x = torch.matmul(x, adjacent)                # aggregate features over neighbours
        if self.add_bias:
            x = x + self.bias
        if self.batch_normal:
            x = self.norm(x)

        return torch.cat((adjacent, x), dim=1)
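

# A minimal usage sketch (not part of the original listing): it builds the
# (adjacency ++ features) input the layer expects and runs one forward pass.
# The sizes below are illustrative assumptions, not values from the source.
if __name__ == "__main__":
    batch, nodes, in_feats, out_feats = 2, 5, 8, 16
    layer = GraphConvolution(nodes, in_feats, out_feats)

    adjacency = torch.rand(batch, nodes, nodes)    # (batch, node_num, node_num)
    features = torch.rand(batch, in_feats, nodes)  # (batch, input_feature_num, node_num)
    inp = torch.cat((adjacency, features), dim=1)  # (batch, node_num + in_feats, node_num)

    out = layer(inp)
    # the adjacency is passed through unchanged, so layers of this type stack
    print(out.shape)  # torch.Size([2, 21, 5]) == (batch, node_num + out_feats, node_num)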
 