import torch
from torch import nn


class GraphConvolution(nn.Module):
    """Graph convolution over a packed ``[adjacency; node features]`` tensor.

    The input stacks the adjacency matrix ``A`` on top of the channels-first
    node features ``H`` along dim 1. The output uses the same packed layout,
    with ``A`` on top of the new features. A layer's output can therefore be
    fed to another ``GraphConvolution`` whose ``input_feature_num`` equals this
    layer's ``output_feature_num``.

    Args:
        node_num: number of graph nodes ``n``.
        input_feature_num: number of input feature channels.
        output_feature_num: number of output feature channels.
        add_bias: if True, add a learnable ``(output_feature_num, node_num)``
            bias to the convolved features.
        dtype: dtype of the learnable weight and bias.
        batch_normal: if True, apply ``nn.InstanceNorm1d`` (default settings:
            no affine parameters, no running statistics) to the convolved
            features. Despite the name, this is instance normalisation.
    """

    def __init__(self, node_num: int, input_feature_num: int,
                 output_feature_num: int, add_bias: bool = True,
                 dtype: torch.dtype = torch.float, batch_normal: bool = True):
        super().__init__()
        # shapes
        self.graph_num = node_num
        self.input_feature_num = input_feature_num
        self.output_feature_num = output_feature_num
        self.add_bias = add_bias
        self.batch_normal = batch_normal

        # params
        self.weight = nn.Parameter(
            torch.empty(output_feature_num, input_feature_num, dtype=dtype))
        # Always registered so state_dict keys do not depend on add_bias;
        # forward() only uses it when add_bias is True.
        self.bias = nn.Parameter(
            torch.empty(output_feature_num, node_num, dtype=dtype))
        if batch_normal:
            # The normalised tensor is (batch, output_feature_num, n), so its
            # channel count is output_feature_num. With affine=False the value
            # is not used numerically, but recent PyTorch versions warn on
            # every forward pass when it does not match the input.
            self.norm = nn.InstanceNorm1d(output_feature_num)

        # torch.empty leaves the memory uninitialised; give the parameters
        # well-defined starting values.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise the weight with Xavier-uniform and the bias with zeros."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def set_trainable(self, train: bool = True) -> None:
        """Enable (``train=True``) or disable gradients for every parameter of this module."""
        self.requires_grad_(train)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply the graph convolution to a packed input.

        Args:
            inp: tensor of shape ``(batch, n + input_feature_num, n)`` where
                ``n = inp.shape[-1]`` (it should equal ``node_num``).
                Rows ``[:n]`` hold the adjacency matrix ``A`` with shape
                ``(batch, n, n)``. Rows ``[n:]`` hold the node features ``H``,
                laid out channels-first with shape
                ``(batch, input_feature_num, n)``.

        Returns:
            Tensor of shape ``(batch, n + output_feature_num, n)``: the
            unchanged ``A`` stacked on top of ``W @ H @ A``. The bias is added
            if ``add_bias`` is set. The result is then instance-normalised over
            the node axis if ``batch_normal`` is set.

        Raises:
            ValueError: if ``inp`` is not 3-dimensional.
        """
        if inp.dim() != 3:
            raise ValueError(
                "expected a 3-D packed input (batch, n + in_features, n), "
                f"got shape {tuple(inp.shape)}")
        n = inp.shape[-1]
        adjacent, node_feature = inp[:, :n, :], inp[:, n:, :]
        # (out, in) @ (batch, in, n) -> (batch, out, n); then mix along nodes with A.
        x = torch.matmul(self.weight, node_feature)
        x = torch.matmul(x, adjacent)
        if self.add_bias:
            x = x + self.bias
        if self.batch_normal:
            x = self.norm(x)
        return torch.cat((adjacent, x), dim=1)