def __init__(self, in_channels, hidden_channels, out_channels, mode="cat"):
    """Build two stacked dense GIN layers merged by jumping knowledge.

    Args:
        in_channels: Width of the input node features fed to the first layer.
        hidden_channels: Output width of the first GIN layer.
        out_channels: Output width of the second GIN layer and of the
            final linear projection.
        mode: Aggregation mode handed to ``JumpingKnowledge``. With
            ``"cat"`` both layer outputs are concatenated, so the projection
            consumes ``hidden_channels + out_channels`` features; every other
            mode is treated as producing ``out_channels`` features.
            NOTE(review): a non-"cat" mode presumably needs
            ``hidden_channels == out_channels`` because the two layer outputs
            would be combined element-wise; confirm against JumpingKnowledge.
    """
    super(Block, self).__init__()

    def two_layer_mlp(fan_in, fan_out):
        # Linear -> ReLU -> Linear, with the inner width equal to fan_out.
        return torch.nn.Sequential(
            Linear(fan_in, fan_out),
            ReLU(),
            Linear(fan_out, fan_out),
        )

    # Both MLPs are created before either conv, exactly as before, so seeded
    # runs initialise parameters in the same sequence.
    first_mlp = two_layer_mlp(in_channels, hidden_channels)
    second_mlp = two_layer_mlp(hidden_channels, out_channels)
    self.conv1 = DenseGINConv(first_mlp, train_eps=True)
    self.conv2 = DenseGINConv(second_mlp, train_eps=True)
    self.jump = JumpingKnowledge(mode)

    jumped_channels = (
        hidden_channels + out_channels if mode == "cat" else out_channels
    )
    self.lin = Linear(jumped_channels, out_channels)
def test_dense_sage_conv():
    """Check DenseGINConv against the sparse GINConv on a padded dense batch.

    Five nodes (a triangle 0-1-2 plus an edge 3-4) are run through the sparse
    conv. They are then padded with one zero node and reshaped into two dense
    graphs of three nodes each. The dense output, with the padding row
    dropped, must match the sparse output. Both the default constructor and
    ``train_eps=True`` are checked. Previously the default instance was built
    and immediately overwritten, so it was never exercised.

    NOTE(review): despite its name this test exercises DenseGINConv, not
    DenseSAGEConv; consider renaming it to ``test_dense_gin_conv``.
    """
    in_channels, out_channels = (16, 32)
    # Named ``mlp`` rather than ``nn`` to avoid shadowing the usual torch.nn alias.
    mlp = Seq(Lin(in_channels, 32), ReLU(), Lin(32, out_channels))
    sparse_conv = GINConv(mlp)
    # All convs share the same MLP, so they compute with identical weights.
    dense_convs = [DenseGINConv(mlp), DenseGINConv(mlp, train_eps=True)]
    for dense_conv in dense_convs:
        assert dense_conv.__repr__() == (
            'DenseGINConv(nn=Sequential(\n'
            '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
            '  (1): ReLU()\n'
            '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
            '))')

    x = torch.randn((5, in_channels))
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],
                               [1, 2, 0, 2, 0, 1, 4, 3]])
    sparse_out = sparse_conv(x, edge_index)

    # Append one zero "padding" node, then split into 2 graphs x 3 nodes.
    x = torch.cat([x, x.new_zeros(1, in_channels)], dim=0).view(2, 3, in_channels)
    adj = torch.tensor([
        [[0, 1, 1], [1, 0, 1], [1, 1, 0]],
        [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
    ], dtype=torch.float)
    # Boolean mask (uint8 masks are deprecated in PyTorch); the last node of
    # the second graph is padding.
    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)

    for dense_conv in dense_convs:
        dense_out = dense_conv(x, adj, mask)
        assert dense_out.size() == (2, 3, out_channels)
        # Flatten back to 6 rows and drop the padding node before comparing.
        dense_out = dense_out.view(6, out_channels)[:-1]
        assert torch.allclose(sparse_out, dense_out, atol=1e-04)
class Block(torch.nn.Module):
    """Two dense GIN layers whose outputs are merged by jumping knowledge.

    Each layer is a ``DenseGINConv`` with a trainable epsilon and a
    Linear -> ReLU -> Linear MLP. The ReLU-activated outputs of both layers
    are combined by ``JumpingKnowledge(mode)`` and projected to
    ``out_channels`` by a final ``Linear``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
        """Create the two GIN layers, the jumping-knowledge merge and the projection.

        Args:
            in_channels: Width of the input node features.
            hidden_channels: Output width of the first GIN layer.
            out_channels: Output width of the second GIN layer and the block.
            mode: ``JumpingKnowledge`` aggregation mode. ``'cat'`` makes the
                projection consume ``hidden_channels + out_channels``
                features; any other mode is treated as producing
                ``out_channels`` features.
                NOTE(review): non-'cat' modes presumably require
                ``hidden_channels == out_channels``; confirm.
        """
        super(Block, self).__init__()

        def two_layer_mlp(fan_in, fan_out):
            # Linear -> ReLU -> Linear, inner width equal to fan_out.
            return torch.nn.Sequential(
                Linear(fan_in, fan_out),
                ReLU(),
                Linear(fan_out, fan_out),
            )

        # Same construction order as before (both MLPs, then both convs), so
        # seeded parameter initialisation is unchanged.
        first_mlp = two_layer_mlp(in_channels, hidden_channels)
        second_mlp = two_layer_mlp(hidden_channels, out_channels)
        self.conv1 = DenseGINConv(first_mlp, train_eps=True)
        self.conv2 = DenseGINConv(second_mlp, train_eps=True)
        self.jump = JumpingKnowledge(mode)

        jumped_channels = (
            hidden_channels + out_channels if mode == 'cat' else out_channels
        )
        self.lin = Linear(jumped_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise conv1, conv2 and the final projection, in that order.

        NOTE(review): ``self.jump`` is not reset here; that only matters if
        the chosen mode carries learnable parameters.
        """
        for module in (self.conv1, self.conv2, self.lin):
            module.reset_parameters()

    def forward(self, x, adj, mask=None, add_loop=True):
        """Apply both GIN layers and merge their activated outputs.

        ``x``, ``adj``, ``mask`` and ``add_loop`` are forwarded unchanged to
        each ``DenseGINConv``. Presumably these are dense batched tensors
        (batch, nodes, features) / (batch, nodes, nodes); confirm with caller.
        """
        layer_outputs = []
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, adj, mask, add_loop))
            layer_outputs.append(hidden)
        return self.lin(self.jump(layer_outputs))
def __init__(self, in_dim, out_dim, dim=32):
    """Build the encoder: three dense GIN layers, each followed by a BatchNorm1d, plus a 2-layer linear head.

    Args:
        in_dim: Input feature width of the first GIN MLP; also the size of
            every BatchNorm1d.
        out_dim: Output width of the final linear layer ``fc2``.
        dim: Hidden width shared by all GIN MLPs and ``fc1``.
    """
    super(Encoder, self).__init__()

    def gin_mlp(width_in):
        # Linear(width_in -> dim) -> ReLU -> Linear(dim -> dim)
        return Sequential(Linear(width_in, dim), ReLU(), Linear(dim, dim))

    # NOTE(review): the batch norms are sized by in_dim, not dim. If forward
    # applies them to (batch, num_nodes, dim) tensors, BatchNorm1d normalises
    # over dimension 1 (nodes), which implies num_nodes == in_dim; confirm.
    self.conv1 = DenseGINConv(gin_mlp(in_dim))
    self.bn1 = torch.nn.BatchNorm1d(in_dim)
    self.conv2 = DenseGINConv(gin_mlp(dim))
    self.bn2 = torch.nn.BatchNorm1d(in_dim)
    self.conv3 = DenseGINConv(gin_mlp(dim))
    self.bn3 = torch.nn.BatchNorm1d(in_dim)

    # Readout head.
    self.fc1 = Linear(dim, dim)
    self.fc2 = Linear(dim, out_dim)
def _gcn(self, name, input_dim, hidden_dim, bias, activation='relu'):
    """Create one dense graph-convolution layer.

    Args:
        name: ``'SAGE'`` selects ``DenseSAGEConv``; any other value selects
            ``DenseGINConv``.
        input_dim: Input feature width.
        hidden_dim: Output feature width.
        bias: Passed to ``DenseSAGEConv`` only.
            NOTE(review): ignored on the GIN branch; confirm this is intended.
        activation: Passed to ``self._activation`` to build the activation
            placed between the two Linear layers of the GIN MLP.

    Returns:
        The constructed convolution module.
    """
    if name != 'SAGE':
        # GIN MLP: Linear -> activation -> Linear; built in this order.
        mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            self._activation(activation),
            nn.Linear(hidden_dim, hidden_dim),
        )
        return DenseGINConv(mlp)
    return DenseSAGEConv(input_dim, hidden_dim, normalize=True, bias=bias)
def __init__(self, in_dim, out_dim, dim=32):
    """Build the decoder: three dense GIN layers with batch norms and a 2-layer linear head.

    Args:
        in_dim: Size of the identity matrix stored in the ``id`` buffer. The
            first GIN MLP consumes ``in_dim + 2`` features, and every
            BatchNorm1d is sized by ``in_dim``.
        out_dim: Output width of the final linear layer ``fc2``.
        dim: Hidden width shared by all GIN MLPs and ``fc1``.
    """
    super(Decoder, self).__init__()
    self.in_dim = in_dim
    self.out_dim = out_dim
    # Identity matrix kept as a (persistent) buffer so it is stored in the
    # state_dict and moves with the module on .to()/.cuda().
    self.register_buffer('id', torch.eye(in_dim))
    # ``id_mat`` used to be a plain tensor attribute aliasing ``id``. Plain
    # attributes are not moved by .to()/.cuda(), so it would silently stay on
    # the construction device. A non-persistent buffer keeps the alias on the
    # module's device without adding a new key to the state_dict.
    self.register_buffer('id_mat', self.id, persistent=False)

    # NOTE(review): the batch norms are sized by in_dim, not dim. If forward
    # applies them to (batch, num_nodes, dim) tensors, this implies
    # num_nodes == in_dim; confirm.
    nn1 = Sequential(Linear(in_dim + 2, dim), ReLU(), Linear(dim, dim))
    self.conv1 = DenseGINConv(nn1)
    self.bn1 = torch.nn.BatchNorm1d(in_dim)

    nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
    self.conv2 = DenseGINConv(nn2)
    self.bn2 = torch.nn.BatchNorm1d(in_dim)

    nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
    self.conv3 = DenseGINConv(nn3)
    self.bn3 = torch.nn.BatchNorm1d(in_dim)

    # Readout head.
    self.fc1 = Linear(dim, dim)
    self.fc2 = Linear(dim, out_dim)
def test_dense_gin_conv_with_broadcasting():
    """A single unbatched adjacency (and mask) must broadcast over a batch of features."""
    batch_size, num_nodes, channels = 8, 3, 16
    expected_size = (batch_size, num_nodes, channels)

    mlp = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels))
    conv = DenseGINConv(mlp)

    x = torch.randn(batch_size, num_nodes, channels)
    # Fully connected 3-node graph without self-loops, shared by every sample.
    adj = torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes)
    # All nodes valid; shape (num_nodes,), again without a batch dimension.
    full_mask = torch.ones(num_nodes, dtype=torch.bool)

    for mask in (None, full_mask):
        assert conv(x, adj, mask).size() == expected_size