def test_gine_conv_edge_dim():
    """Check GINEConv output shapes when ``edge_dim`` is passed explicitly.

    Node features are 16 wide and edge features are 8 wide.  The layer is
    built with ``edge_dim=8`` and exercised twice: once around a
    ``Seq(Lin, ReLU, Lin)`` stack and once around a single ``Lin``.
    """
    num_nodes, out_width = 4, 32

    x = torch.randn(num_nodes, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    num_edges = edge_index.size(1)
    edge_attr = torch.randn(num_edges, 8)

    # Case 1: a small sequential stack as the wrapped module.
    stacked = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    layer = GINEConv(stacked, train_eps=True, edge_dim=8)
    result = layer(x, edge_index, edge_attr)
    assert result.size() == (num_nodes, out_width)

    # Case 2: a bare linear module as the wrapped module.
    single = Lin(16, 32)
    layer = GINEConv(single, train_eps=True, edge_dim=8)
    result = layer(x, edge_index, edge_attr)
    assert result.size() == (num_nodes, out_width)
def __init__(self, hidden_channels, out_channels, num_layers=3, dropout=0.5):
    """Build the encoders, a stack of GINEConv layers and the output layer.

    Args:
        hidden_channels: Width passed to both encoders and used as the
            input/output width of every convolution's inner network.
        out_channels: Output width of the final linear layer.
        num_layers: Number of GINEConv layers to create.
        dropout: Stored on ``self.dropout``; not used in this method.
    """
    super().__init__()
    self.dropout = dropout

    self.atom_encoder = AtomEncoder(hidden_channels)
    self.bond_encoder = BondEncoder(hidden_channels)

    # Each inner network widens to twice the hidden size and narrows back.
    widened = 2 * hidden_channels
    conv_stack = []
    for _ in range(num_layers):
        inner = Sequential(
            Linear(hidden_channels, widened),
            BatchNorm(widened),
            ReLU(),
            Linear(widened, hidden_channels),
            BatchNorm(hidden_channels),
            ReLU(),
        )
        conv_stack.append(GINEConv(inner, train_eps=True))
    self.convs = torch.nn.ModuleList(conv_stack)

    self.lin = Linear(hidden_channels, out_channels)
def __init__(self, layers=5, vfeat=12, efeat=4, hidden=32, nclass=6):
    """Store the hyper-parameters and build the per-layer modules.

    Args:
        layers: Number of convolution / batch-norm / edge-encoder triples.
        vfeat: Stored on ``self.vfeat`` and forwarded to the parent
            constructor as its first argument.
        efeat: Input width of every edge encoder.
        hidden: Hidden width used by all layers; forwarded to the parent
            constructor as its second argument.
        nclass: Output width of ``fc2``.
    """
    super(GINEWOBN, self).__init__(vfeat, hidden)

    self.layers = layers
    self.vfeat = vfeat
    self.efeat = efeat
    self.hidden = hidden
    self.nclass = nclass

    self.convs = nn.ModuleList()
    self.bns = nn.ModuleList()
    self.edge_encoders = nn.ModuleList()

    doubled = 2 * hidden
    for _ in range(layers):
        # Inner network handed to the convolution.
        update_net = nn.Sequential(
            nn.Linear(hidden, doubled),
            nn.BatchNorm1d(doubled),
            nn.ReLU(),
            nn.Linear(doubled, hidden),
        )
        self.convs.append(GINEConv(update_net, train_eps=True))
        self.bns.append(nn.BatchNorm1d(hidden))

        # One encoder per layer mapping edge features to the hidden width.
        edge_net = nn.Sequential(
            nn.Linear(efeat, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )
        self.edge_encoders.append(edge_net)

    self.fc1 = nn.Linear(hidden, hidden)
    self.fc2 = nn.Linear(hidden, nclass)
def test_static_gine_conv():
    """Check GINEConv on a 3-D node tensor of shape ``(3, 4, 16)``.

    The same ``edge_index`` / ``edge_attr`` pair is used for the whole input
    and the output is expected to keep the two leading dimensions.
    """
    leading, num_nodes = 3, 4

    x = torch.randn(leading, num_nodes, 16)
    edge_index = torch.tensor([
        [0, 0, 0, 1, 2, 3],
        [1, 2, 3, 0, 0, 0],
    ])
    edge_attr = torch.randn(edge_index.size(1), 16)

    inner = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    layer = GINEConv(inner, train_eps=True)

    result = layer(x, edge_index, edge_attr)
    assert result.size() == (leading, num_nodes, 32)
def test_gine_conv():
    """Check GINEConv's repr and output shape with and without ``train_eps``.

    A graph with node ids 0..3 is used, with 16-wide node and edge features.
    The wrapped module maps 16 -> 32 -> 32, so every output is expected to be
    ``(num_nodes, 32)``.
    """
    in_channels, out_channels = 16, 32
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1

    x = torch.randn((num_nodes, in_channels))
    edge_attr = torch.randn((edge_index.size(1), in_channels))

    mlp = Seq(Lin(in_channels, 32), ReLU(), Lin(32, out_channels))

    conv = GINEConv(mlp, train_eps=True)
    # torch.nn.Module.__repr__ indents child modules by two spaces, so the
    # expected text must use a two-space indent for each child line.
    assert repr(conv) == (
        'GINEConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    assert conv(x, edge_index, edge_attr).size() == (num_nodes, out_channels)

    # Same wrapped module, this time with train_eps=False.
    conv = GINEConv(mlp, train_eps=False)
    assert conv(x, edge_index, edge_attr).size() == (num_nodes, out_channels)
def __init__(self, hidden_channels, out_channels, num_layers,
             dropout=0.0, inter_message_passing=True):
    """Build the atom-level and clique-level layer stacks.

    Args:
        hidden_channels: Width shared by every encoder, convolution,
            batch-norm and intermediate linear layer.
        out_channels: Output width of the final ``lin`` layer.
        num_layers: Number of layers in each per-layer module list.
        dropout: Stored on ``self.dropout``; not used in this method.
        inter_message_passing: Stored on ``self.inter_message_passing``;
            not used in this method.
    """
    super(Net, self).__init__()

    self.num_layers = num_layers
    self.dropout = dropout
    self.inter_message_passing = inter_message_passing

    def make_update_net():
        # Inner network given to each convolution: widen to twice the
        # hidden size, normalise, activate, narrow back.
        return Sequential(
            Linear(hidden_channels, 2 * hidden_channels),
            BatchNorm1d(2 * hidden_channels),
            ReLU(),
            Linear(2 * hidden_channels, hidden_channels),
        )

    self.atom_encoder = AtomEncoder(hidden_channels)
    self.clique_encoder = Embedding(4, hidden_channels)

    # Atom-level stack: one bond encoder, GINEConv and batch-norm per layer.
    self.bond_encoders = ModuleList()
    self.atom_convs = ModuleList()
    self.atom_batch_norms = ModuleList()
    for _ in range(num_layers):
        self.bond_encoders.append(BondEncoder(hidden_channels))
        atom_net = make_update_net()
        self.atom_convs.append(GINEConv(atom_net, train_eps=True))
        self.atom_batch_norms.append(BatchNorm1d(hidden_channels))

    # Clique-level stack: one GINConv and batch-norm per layer.
    self.clique_convs = ModuleList()
    self.clique_batch_norms = ModuleList()
    for _ in range(num_layers):
        clique_net = make_update_net()
        self.clique_convs.append(GINConv(clique_net, train_eps=True))
        self.clique_batch_norms.append(BatchNorm1d(hidden_channels))

    # Per-layer linear maps in both directions between the two stacks.
    self.atom2clique_lins = ModuleList()
    self.clique2atom_lins = ModuleList()
    for _ in range(num_layers):
        to_clique = Linear(hidden_channels, hidden_channels)
        self.atom2clique_lins.append(to_clique)
        to_atom = Linear(hidden_channels, hidden_channels)
        self.clique2atom_lins.append(to_atom)

    self.atom_lin = Linear(hidden_channels, hidden_channels)
    self.clique_lin = Linear(hidden_channels, hidden_channels)
    self.lin = Linear(hidden_channels, out_channels)
def __init__(self, feat_dim, hid_dim, dropout):
    """Build the feature encoder, two GINEConv layers and helper modules.

    Args:
        feat_dim: Input width of the node-feature encoder.
        hid_dim: Hidden width used by every other layer.
        dropout: Probability handed to ``nn.Dropout``.
    """
    super(GINEEncoder, self).__init__()

    def three_layer_mlp():
        # hid_dim -> hid_dim three times, with ReLU between the layers.
        return nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
        )

    self.x_encode = nn.Sequential(
        nn.Linear(feat_dim, hid_dim),
        nn.ReLU(),
    )

    self.nn1 = three_layer_mlp()
    self.nn2 = three_layer_mlp()
    self.conv1 = GINEConv(self.nn1, train_eps=True)
    self.conv2 = GINEConv(self.nn2, train_eps=True)

    self.dropout = nn.Dropout(dropout)
    self.act = F.relu

    self._reset_parameters()
    # NOTE(review): attr_encode is created after _reset_parameters() has
    # run, so that call cannot see it -- confirm this ordering is intended.
    self.attr_encode = nn.Embedding(8, hid_dim)
def test_gine_conv():
    """Exercise GINEConv with dense and sparse inputs, plus TorchScript.

    Two input forms are covered: a single node-feature tensor ``x1``, and a
    feature pair ``(x1, x2)`` / ``(x1, None)`` with ``size=(4, 2)``.  Each is
    called with an ``edge_index`` tensor and with a ``SparseTensor``, and the
    results are compared with each other.  When ``is_full_test()`` is true
    the same calls are repeated through ``torch.jit.script``.
    """
    def same(actual, expected):
        # Outputs are floating point and come from different code paths
        # (dense, sparse, scripted), so compare with a small tolerance
        # instead of exact list equality.
        return torch.allclose(actual, expected, atol=1e-6)

    x1 = torch.randn(4, 16)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.randn(row.size(0), 16)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    mlp = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    conv = GINEConv(mlp, train_eps=True)
    # torch.nn.Module.__repr__ indents child modules by two spaces, so the
    # expected text must use a two-space indent for each child line.
    assert repr(conv) == (
        'GINEConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')

    # --- single feature tensor -------------------------------------------
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert same(conv(x1, edge_index, value, size=(4, 4)), out)
    assert same(conv(x1, adj.t()), out)

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert same(jit(x1, edge_index, value), out)
        assert same(jit(x1, edge_index, value, size=(4, 4)), out)

        t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert same(jit(x1, adj.t()), out)

    # --- feature pair with size (4, 2) -----------------------------------
    adj = adj.sparse_resize((4, 2))
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert same(conv((x1, x2), edge_index, value, (4, 2)), out1)
    assert same(conv((x1, x2), adj.t()), out1)
    assert same(conv((x1, None), adj.t()), out2)

    if is_full_test():
        t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert same(jit((x1, x2), edge_index, value), out1)
        assert same(jit((x1, x2), edge_index, value, size=(4, 2)), out1)
        assert same(jit((x1, None), edge_index, value, size=(4, 2)), out2)

        t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert same(jit((x1, x2), adj.t()), out1)
        assert same(jit((x1, None), adj.t()), out2)
def __init__(self, D, C, G=0, task='graph'):
    """Store the size parameters and build the convolution and MLP layers.

    Args:
        D: Input width of the first convolution's MLP.
        C: Output width of the final MLP.
        G: Extra width added to the final MLP input when positive.
        task: Stored on ``self.task``; not used in this method.
    """
    super(GINENet, self).__init__()

    self.D = D
    self.C = C
    self.G = G
    self.task = task

    # Two convolutions: D -> D, then D -> 64.
    self.conv1 = GINEConv(MLP([self.D, self.D]))
    self.conv2 = GINEConv(MLP([self.D, 64]))

    # "Fusion" layer sized for the concatenated conv1 and conv2 outputs.
    self.lin1 = MLP([self.D + 64, 96])

    # The head input is the 96 fused channels plus G more when G > 0.
    self.Z = 96 + self.G if self.G > 0 else 96

    # Final layers concatenating everything.
    self.mlp1 = MLP([self.Z, self.Z, self.C])
def __init__(self, in_channels, out_channels, ins_dim, dropout=0.0):
    """Build five GINEConv layers and four batch-norm layers.

    Args:
        in_channels: Combined with ``ins_dim`` to give the input width of
            every convolution's inner network.
        out_channels: Output width of every convolution and batch-norm.
        ins_dim: Extra input width added to ``in_channels``.
        dropout: Stored on ``self.dropout``; not used in this method.
    """
    super(gine_seq, self).__init__()

    depth = 5
    conv_in = in_channels + ins_dim

    # Five convolutions, each around a Lin -> ReLU -> Lin network.
    conv_layers = []
    for _ in range(depth):
        inner = Seq(
            Lin(conv_in, out_channels),
            ReLU(),
            Lin(out_channels, out_channels),
        )
        conv_layers.append(GINEConv(inner))
    self.convs = torch.nn.ModuleList(conv_layers)

    # One batch-norm fewer than convolutions: none for the last output.
    norm_layers = []
    for _ in range(depth - 1):
        norm_layers.append(torch.nn.BatchNorm1d(out_channels))
    self.bns = torch.nn.ModuleList(norm_layers)

    self.dropout = dropout
def _build_base_mp_layers(self, hidden_channels, num_layers):
    """Create the atom encoder and the per-layer message-passing modules.

    Args:
        hidden_channels: Width shared by every module built here.
        num_layers: Number of bond-encoder / GINEConv / batch-norm triples.
    """
    self.atom_encoder = blocks.AtomEncoder(hidden_channels)

    self.bond_encoders = ModuleList()
    self.atom_convs = ModuleList()
    self.atom_batch_norms = ModuleList()

    expanded = 2 * hidden_channels
    for _ in range(num_layers):
        self.bond_encoders.append(blocks.BondEncoder(hidden_channels))

        # Inner network for the convolution: widen, normalise, activate,
        # then narrow back to the hidden width.
        update_net = Sequential(
            Linear(hidden_channels, expanded),
            BatchNorm1d(expanded),
            ReLU(),
            Linear(expanded, hidden_channels),
        )
        self.atom_convs.append(GINEConv(update_net, train_eps=True))
        self.atom_batch_norms.append(BatchNorm1d(hidden_channels))

    self.atom_lin = Linear(hidden_channels, hidden_channels)
def __init__(self):
    """Create an 8 -> 16 linear layer and a GINEConv around a 16 -> 32 one."""
    super().__init__()
    self.lin1 = Linear(8, 16)

    # The convolution wraps a single linear module as its inner network.
    inner = Linear(16, 32)
    self.conv1 = GINEConv(nn=inner)