Example #1
import torch
from torch_geometric.nn import (global_add_pool, global_max_pool,
                                global_mean_pool)


def test_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0] * N_1 + [1] * N_2)

    out = global_add_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:N_1].sum(dim=0))
    assert torch.allclose(out[1], x[N_1:].sum(dim=0))

    out = global_add_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.sum(dim=0, keepdim=True))

    out = global_mean_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:N_1].mean(dim=0))
    assert torch.allclose(out[1], x[N_1:].mean(dim=0))

    out = global_mean_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.mean(dim=0, keepdim=True))

    out = global_max_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:N_1].max(dim=0)[0])
    assert torch.allclose(out[1], x[N_1:].max(dim=0)[0])

    out = global_max_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.max(dim=0, keepdim=True)[0])
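All three readouts are segment-wise reductions of the node features over the
`batch` assignment vector. A minimal, self-contained sketch that checks
`global_mean_pool` against a manual per-graph mean:

import torch
from torch_geometric.nn import global_mean_pool

x = torch.randn(10, 4)
batch = torch.tensor([0] * 4 + [1] * 6)  # node -> graph assignment

out = global_mean_pool(x, batch)  # shape: [num_graphs, 4]
manual = torch.stack([x[batch == i].mean(dim=0) for i in batch.unique()])
assert torch.allclose(out, manual)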
Example #2
    def forward(self, data):
        x, edge_index, node_depth, batch = data.x, data.edge_index, data.node_depth, data.batch

        x = self.node_encoder(x, node_depth.view(-1))

        edge_weight = None
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = F.relu(x)
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, edge_weight, batch, _ = pool(
                    x=x, edge_index=edge_index, edge_weight=edge_weight,
                    batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        # x = self.lin2(x)
        # return F.log_softmax(x, dim=-1)

        if self.num_class > 0:
            return self.graph_pred_linear(x)

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](x))
        return pred_list
Example #3
    def forward(self,
                z,
                edge_index,
                batch,
                x=None,
                edge_weight=None,
                node_id=None):
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs += [x]
        if self.jk:
            x = global_mean_pool(torch.cat(xs, dim=1), batch)
        else:
            x = global_mean_pool(xs[-1], batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        return x
Example #4
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch if hasattr(data, 'batch') else None

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))

        return x
Example #5
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if i != self.num_layers - 1:
                x = self.norm[i](x)

        if not self.global_pool:
            x = pyg_nn.global_mean_pool(x, batch)
        elif self.global_pool == 'max':
            x = pyg_nn.global_max_pool(x, batch)
        elif self.global_pool == 'mix':
            x1 = pyg_nn.global_mean_pool(x, batch)
            x2 = pyg_nn.global_max_pool(x, batch)
            x = torch.cat((x1, x2), 1)

        x = self.post_mp(x)
        emb = x
        out = F.log_softmax(x, dim=1)

        return emb, out
Example #6
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat(
            [gnn.global_max_pool(x, batch),
             gnn.global_mean_pool(x, batch)],
            dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.log_softmax(self.lin3(x), dim=-1)

        return x
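The pooling calls in this example unpack five values while Example #4 unpacks
six; TopKPooling/SAGPooling gained a sixth return value (the node scores)
across PyTorch Geometric releases. A small sketch, assuming a recent PyG
version, that unpacks robustly either way:

import torch
from torch_geometric.nn import TopKPooling

pool = TopKPooling(in_channels=16, ratio=0.5)
x = torch.randn(8, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])

# Recent PyG returns (x, edge_index, edge_attr, batch, perm, score);
# star-unpacking also tolerates the older five-value signature.
x, edge_index, edge_attr, batch, *rest = pool(x, edge_index)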
Example #7
    def forward(self, x, edge_index, batch):
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        if self.readout == 'mean':
            output_list = [global_mean_pool(x, batch)]
        else:
            output_list = [global_add_pool(x, batch)]
        hid_x = self.fc(x)

        for conv in self.conv_layers:
            hid_x = conv(hid_x, edge_index)
            if self.readout == 'mean':
                output_list.append(global_mean_pool(hid_x, batch))
            else:
                output_list.append(global_add_pool(hid_x, batch))

        score_over_layer = 0
        for layer, h in enumerate(output_list):
            h = self.bns_fc[layer](h)
            score_over_layer += F.relu(self.linears_prediction[layer](h))

        if self.dropout > 0:
            score_over_layer = F.dropout(score_over_layer, p=self.dropout,
                                         training=self.training)

        x = self.linear(score_over_layer)
        return F.log_softmax(x, dim=-1)
Example #8
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.item_embedding(x).squeeze(1)

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool1(x, edge_index, batch=batch)
        x1 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool2(x, edge_index, batch=batch)
        x2 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, *_ = self.pool3(x, edge_index, batch=batch)
        x3 = torch.cat([global_max_pool(x, batch),
                        global_mean_pool(x, batch)],
                       dim=1)

        x = x1 + x2 + x3

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.drop(x)

        x = torch.sigmoid(self.linear(x)).squeeze(1)

        return x
Example #9
    def forward(self, entry1_data, entry2_data, get_latent_variable=False):
        entry2_data, entry2_seq_data = entry2_data
        entry1_data, entry1_seq_data = entry1_data

        entry1_x, entry1_edge_index, entry1_edge_attr, entry1_batch = \
            entry1_data.x, entry1_data.edge_index, entry1_data.edge_attr, entry1_data.batch
        entry1_out = self.gconv1(entry1_x, entry1_edge_index, entry1_edge_attr, entry1_batch)
        entry1_mean = global_mean_pool(entry1_out, entry1_batch)
        entry1_seq_mean = self.gconv1_seq(entry1_seq_data)
        entry1_mean = t.cat([entry1_mean, entry1_seq_mean], dim=-1)

        entry2_x, entry2_edge_index, entry2_edge_attr, entry2_batch = \
            entry2_data.x, entry2_data.edge_index, entry2_data.edge_attr, entry2_data.batch
        entry2_out = self.gconv2(entry2_x, entry2_edge_index, entry2_edge_attr, entry2_batch)
        entry2_mean = global_mean_pool(entry2_out, entry2_batch)
        entry2_seq_mean = self.gconv2_seq(entry2_seq_data)
        entry2_mean = t.cat([entry2_mean, entry2_seq_mean], dim=-1)

        cat_features = t.cat([entry1_mean, entry2_mean], dim=-1)
        x = self.global_fc_nn(cat_features)
        if get_latent_variable:
            return x
        x = self.fc2(x)
        if self.out_activation_func == 'softmax':
            return F.softmax(x, dim=-1)  # or F.log_softmax(x, dim=-1)
        elif self.out_activation_func == 'sigmoid':
            return t.sigmoid(x)
        elif self.out_activation_func is None:
            return x
Example #10
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        #x = F.relu(self.conv1(x=x, edge_index = edge_index))
        if not self.use_weight:
            x = F.relu(self.conv1(x=x, edge_index=edge_index))
        else:
            x = F.relu(
                self.conv1(x=x, edge_index=edge_index, edge_weight=edge_attr))

        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            #x = F.relu(conv(x=x, edge_index = edge_index))
            if not self.use_weight:
                x = F.relu(conv(x=x, edge_index=edge_index))
            else:
                x = F.relu(
                    conv(x=x, edge_index=edge_index, edge_weight=edge_attr))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                #x, edge_index, _, batch, _, _ = pool(x, edge_index,batch=batch)
                if not self.use_weight:
                    x, edge_index, _, batch, _, _ = pool(x,
                                                         edge_index,
                                                         batch=batch)
                else:
                    x, edge_index, edge_attr, batch, _, _ = pool(
                        x, edge_index, edge_attr=edge_attr, batch=batch)

        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        #print(x.shape)
        return F.log_softmax(x, dim=-1)
Example #11
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if self.pooling_type != 'none':
                if self.pooling_type == 'complement':
                    complement = batched_negative_edges(edge_index=edge_index, batch=batch, force_undirected=True)
                    cluster = graclus(complement, num_nodes=x.size(0))
                elif self.pooling_type == 'graclus':
                    cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch

        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
Example #12
    def forward(self, x, edge_index, edge_weight, batch):
        x = self.lin1(x)
        x = x.relu()
        gcn1 = F.relu(self.gcn1(x, edge_index, edge_weight))
        x1, edge_index1, edge_attr1, batch1, _, _ = \
            self.pool1(gcn1, edge_index, edge_weight, batch=batch)
        global_pool1 = geo_nn.global_mean_pool(x1, batch1)
        # global_pool1 = torch.cat(
        #     [geo_nn.global_mean_pool(x1, batch1),
        #      geo_nn.global_max_pool(x1, batch1)],
        #     dim=1)

        gcn2 = F.relu(self.gcn2(x1, edge_index1, edge_attr1))
        x2, edge_index2, edge_attr2, batch2, _, _ = \
            self.pool2(gcn2, edge_index1, edge_attr1, batch=batch1)
        global_pool2 = geo_nn.global_mean_pool(x2, batch2)
        # global_pool2 = torch.cat(
        #     [geo_nn.global_mean_pool(x2, batch2),
        #      geo_nn.global_max_pool(x2, batch2)],
        #     dim=1)

        gcn3 = F.relu(self.gcn3(x2, edge_index2, edge_attr2))
        x3, edge_index3, edge_attr3, batch3, _, _ = \
            self.pool3(gcn3, edge_index2, edge_attr2, batch=batch2)
        global_pool3 = geo_nn.global_mean_pool(x3, batch3)
        # global_pool3 = torch.cat(
        #     [geo_nn.global_mean_pool(x3, batch3),
        #      geo_nn.global_max_pool(x3, batch3)],
        #     dim=1)

        x = global_pool1 + global_pool2 + global_pool3
        x = self.mlp(x)
        return x
Example #13
    def forward(self, h, edge_index, edge_attr, batch):

        if self.training:
            for l, layer in enumerate(self.layers):
                t1 = time.perf_counter()
                h = layer(h, edge_index, self.pseudo_proj[l](edge_attr))
                print("conv", l, "forward time: ", time.perf_counter() - t1)
                h.register_hook(hook_gcn)
                h = F.relu(h)
                h.register_hook(hook_relu)
                # h = self.dropout(h)

            t2 = time.perf_counter()
            h = global_mean_pool(h, batch)
            print("pooling forward time: ", time.perf_counter() - t2)
            h.register_hook(hook_pool)

            t3 = time.perf_counter()
            h = self.fc1(h)
            print("fc1 forward time: ", time.perf_counter() - t3)
            h.register_hook(hook)

            h = F.elu(h)
            h.register_hook(hook_relu)

            t4 = time.perf_counter()
            h = self.fc2(h)
            print("fc2 forward time: ", time.perf_counter() - t4)
            h.register_hook(hook)
            # h = self.readout(h)
            h = F.log_softmax(h, dim=-1)
            h.register_hook(hook)
        else:
            for l, layer in enumerate(self.layers):
                # t1 = time.perf_counter()
                h = layer(h, edge_index, self.pseudo_proj[l](edge_attr))
                # print("conv", l, "forward time: ", time.perf_counter() - t1)
                h = F.relu(h)
                # h = self.dropout(h)

            # t2 = time.perf_counter()
            h = global_mean_pool(h, batch)
            # print("pooling forward time: ", time.perf_counter() - t2)

            # t3 = time.perf_counter()
            h = self.fc1(h)
            # print("fc1 forward time: ", time.perf_counter() - t3)

            h = F.elu(h)

            # t4 = time.perf_counter()
            h = self.fc2(h)
            # print("fc2 forward time: ", time.perf_counter() - t4)
            # h = self.readout(h)
            h = F.log_softmax(h, dim=-1)
        return h
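`hook_gcn`, `hook_relu`, `hook_pool`, and `hook` are defined elsewhere in
this source; `Tensor.register_hook` registers a callback that receives the
gradient of that tensor during the backward pass. A hypothetical sketch of
such a hook:

import time

def make_backward_hook(name):
    # The callback receives the gradient w.r.t. the tensor it was
    # registered on; returning None leaves the gradient unchanged.
    def hook(grad):
        print(name, "backward at", time.perf_counter(),
              "grad shape", tuple(grad.shape))
    return hook

# e.g. h.register_hook(make_backward_hook("conv0"))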
Example #14
    def forward(self, x, edge, batch, type='mean_pool'):
        if type == 'mean_pool':
            return global_mean_pool(x, batch)
        elif type == 'max_pool':
            return global_max_pool(x, batch)
        elif type == 'sum_pool':
            return global_add_pool(x, batch)
        elif type == 'sag_pool':
            x1, _, _, batch, _, _ = self.sag_pool(x, edge, batch=batch)
            return global_mean_pool(x1, batch)
Example #15
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0:
                x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
Example #16
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, batch, _ = pool(x, edge_index, batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
Example #17
    def forward(self, data16, data20, datacol):
        x16, edge_index16, batch16 = data16.features, data16.edge_index, data16.batch
        x20, edge_index20, batch20 = data20.features, data20.edge_index, data20.batch
        xcol, edge_indexcol, batchcol = datacol.features, datacol.edge_index, datacol.batch
        x16 = self.conv1(x16, edge_index16)
        x16 = F.relu(x16)
        x16 = self.conv2(x16, edge_index16)
        x16 = F.relu(x16)
        x16 = self.conv3(x16, edge_index16)
        x16 = torch.transpose(x16, 0, 1)
        x16 = self.i1(x16.unsqueeze(0))
        x16 = self.i2(x16)
        x16 = torch.transpose(x16.squeeze(0), 0, 1)
        x16 = global_mean_pool(x16, batch16)

        x20 = self.conv1(x20, edge_index20)
        x20 = F.relu(x20)
        x20 = self.conv2(x20, edge_index20)
        x20 = F.relu(x20)
        x20 = self.conv3(x20, edge_index20)
        x20 = torch.transpose(x20, 0, 1)
        x20 = self.i1(x20.unsqueeze(0))
        x20 = self.i2(x20)
        x20 = torch.transpose(x20.squeeze(0), 0, 1)

        x20 = global_mean_pool(x20, batch20)

        xcol = self.colconv(xcol, edge_indexcol)
        xcol = F.relu(xcol)
        xcol = self.conv2(xcol, edge_indexcol)
        xcol = F.relu(xcol)
        xcol = self.conv3(xcol, edge_indexcol)
        xcol = torch.transpose(xcol, 0, 1)
        xcol = self.i1(xcol.unsqueeze(0))
        xcol = self.i2(xcol)
        xcol = torch.transpose(xcol.squeeze(0), 0, 1)

        xcol = global_mean_pool(xcol, batchcol)
        xcol = self.lin2(xcol)

        xcol = self.mlp2(xcol)

        x = torch.cat([x16, x20], dim=1)
        x = self.lin1(x)
        x = self.mlp(x)
        out = torch.cat([x, xcol], dim=1)
        out = self.finallin(out)
        return F.log_softmax(out, dim=-1)
Example #18
import torch
from torch_geometric.nn import (global_add_pool, global_max_pool,
                                global_mean_pool)


def test_permuted_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    px = x[perm]
    pbatch = batch[perm]
    px1 = px[pbatch == 0]
    px2 = px[pbatch == 1]

    out = global_add_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.sum(dim=0))
    assert torch.allclose(out[1], px2.sum(dim=0))

    out = global_mean_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.mean(dim=0))
    assert torch.allclose(out[1], px2.mean(dim=0))

    out = global_max_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.max(dim=0)[0])
    assert torch.allclose(out[1], px2.max(dim=0)[0])
Example #19
    def forward(self, pos, batch):
        x = pos.new_ones((pos.size(0), 1))

        radius = 0.2
        edge_index = radius_graph(pos, r=radius, batch=batch)
        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5
        pseudo = pseudo.clamp(min=0, max=1)
        x = F.elu(self.conv1(x, edge_index, pseudo))

        idx = fps(pos, batch, ratio=0.5)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 0.4
        edge_index = radius_graph(pos, r=radius, batch=batch)
        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5
        pseudo = pseudo.clamp(min=0, max=1)
        x = F.elu(self.conv2(x, edge_index, pseudo))

        idx = fps(pos, batch, ratio=0.25)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        radius = 1.0
        edge_index = radius_graph(pos, r=radius, batch=batch)
        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5
        pseudo = pseudo.clamp(min=0, max=1)
        x = F.elu(self.conv3(x, edge_index, pseudo))

        x = global_mean_pool(x, batch)

        x = F.elu(self.lin1(x))
        x = F.elu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
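`fps` (farthest point sampling) selects a well-spread subset of points per
graph and returns their indices into `pos`; `radius_graph` then rebuilds
edges among the survivors. A minimal sketch of the sampling step:

import torch
from torch_geometric.nn import fps

pos = torch.rand(10, 3)                    # 10 points in 3D
batch = torch.zeros(10, dtype=torch.long)  # all in one graph
idx = fps(pos, batch, ratio=0.5)           # indices of ~5 sampled points
sub_pos, sub_batch = pos[idx], batch[idx]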
Example #20
    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        row, col = data.edge_index
        data.edge_attr = (data.pos[col] -
                          data.pos[row]) / (2 * self.args.cutoff) + 0.5

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        row, col = data.edge_index
        data.edge_attr = (data.pos[col] -
                          data.pos[row]) / (2 * self.args.cutoff) + 0.5

        data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))

        x = global_mean_pool(data.x, data.batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training, p=self.args.disc_dropout)
        y = self.fc2(x)

        if (self.args.wgan):
            return y

        return torch.sigmoid(y)
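`normalized_cut_2d` is not shown in this snippet; a sketch matching the
helper used in PyTorch Geometric's MNIST/graclus examples, which weights
each edge by its normalized cut value from the 2D node positions:

import torch
from torch_geometric.utils import normalized_cut

def normalized_cut_2d(edge_index, pos):
    row, col = edge_index
    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)  # edge lengths
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))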
Example #21
    def forward(self, x, edge_index, batch, x_mord):
        # Graph-convolution branch
        x = self.graph_conv1(x, edge_index)
        x = x.relu()
        x = self.graph_conv2(x, edge_index)
        x = x.relu()
        x = self.graph_conv3(x, edge_index)

        x = global_mean_pool(x, batch)

        x_mord = F.relu(self.dense_fc1(x_mord))
        x_mord = self.dense_batch_norm1(x_mord)
        x_mord = self.dense_dropout(x_mord)

        x_mord = F.relu(self.dense_fc2(x_mord))
        x_mord = self.dense_batch_norm2(x_mord)
        x_mord = self.dense_dropout(x_mord)

        x_mord = F.relu(self.dense_fc3(x_mord))
        x_mord = self.dense_batch_norm3(x_mord)
        x_mord = self.dense_dropout(x_mord)

        x = torch.cat([x, x_mord], dim=1)

        return torch.sigmoid(self.linear(x))
Example #22
    def forward(self, data):
        x, edge_index, y, batch = data.x, data.edge_index, data.y, data.batch
        x = F.normalize(x, p=1., dim=-1)
        self.original_x = x
        self.label = y.long()

        e = self.conv1(x, edge_index)
        e = self.conv2(e, edge_index)
        e = self.conv3(e, edge_index)

        # 2. Readout layer
        c = global_mean_pool(e, batch)
        c = F.dropout(c, p=0.1, training=self.training)
        c = self.lin1(c)
        b = self.bn1(c)
        b = torch.tanh(b)
        b = F.dropout(b, p=0.1, training=self.training)
        c = self.lin2(b)
        b = self.bn2(c)
        b = torch.tanh(b)
        b = F.dropout(b, p=0.1, training=self.training)
        c = self.lin3(b)
        b = self.bn3(c)
        b = torch.tanh(b)
        c = self.lin4(b)

        return c
Example #23
    def forward(self, data):
        x = data.x

        x_1 = F.relu(self.conv1_1(x, data.edge_index_1))
        x_2 = F.relu(self.conv1_2(x, data.edge_index_2))
        x_1_r = self.mlp_1(torch.cat([x_1, x_2], dim=-1))
        x_1_r = self.bn1(x_1_r)

        x_1 = F.relu(self.conv2_1(x_1_r, data.edge_index_1))
        x_2 = F.relu(self.conv2_2(x_1_r, data.edge_index_2))
        x_2_r = self.mlp_2(torch.cat([x_1, x_2], dim=-1))
        x_2_r = self.bn2(x_2_r)

        x_1 = F.relu(self.conv3_1(x_2_r, data.edge_index_1))
        x_2 = F.relu(self.conv3_2(x_2_r, data.edge_index_2))
        x_3_r = self.mlp_3(torch.cat([x_1, x_2], dim=-1))
        x_3_r = self.bn3(x_3_r)

        x_1 = F.relu(self.conv4_1(x_3_r, data.edge_index_1))
        x_2 = F.relu(self.conv4_2(x_3_r, data.edge_index_2))
        x_4_r = self.mlp_4(torch.cat([x_1, x_2], dim=-1))
        x_4_r = self.bn4(x_4_r)

        x = torch.cat([x_1_r, x_2_r, x_3_r, x_4_r], dim=-1)
        x = global_mean_pool(x, data.batch)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x.view(-1)
Example #24
    def project(self, data, reg_hook=False):
        "Projects data up to last hidden layer for visualization."

        x, edge_index = data.x, data.edge_index

        for conv_layer in self.conv_encoder:
            x = conv_layer(x, edge_index)
            #x = F.relu(x)
            x = torch.tanh(x)

        if reg_hook:
            h = x.register_hook(self.activations_hook)

        if self.pooling == 'mean':
            x = global_mean_pool(x, data.batch)
        elif self.pooling == 'add':
            x = global_add_pool(x, data.batch)
        elif self.pooling == 'max':
            x = global_max_pool(x, data.batch)

        for dense_layer in self.linear_layers[:-1]:
            x = dense_layer(x)
            #x = torch.tanh(x)
            x = F.relu(x)

        x = self.linear_layers[-1](x)

        return x
Example #25
    def forward(self, data):
        """Run a forward pass.

        We do not use any activations just ot make sure an untrained
        netwrok will be still able to propagate gradients.

        Parameters
        ----------
        data : torch_gometric.data.Batch
            Batch graph.

        Return
        ------
        y : torch.tensor
            Per graph predictions of shape `(n_samples, dim)`.
        """
        x = data.x  # (n_nodes_batch, n_node_features)
        edge_index = data.edge_index  # (2, n_edges_batch)
        edge_features = data.edge_features  # (n_edges_batch, n_edge_features=1)
        batch = data.batch  # (n_nodes_batch,)

        x = self.conv(x, edge_index,
                      edge_features)  # (n_nodes_batch, n_channels)

        x = global_mean_pool(x, batch)  # (n_samples, n_channels)

        x = self.fc1(x)  # (n_samples, hidden_size)
        y = self.fc2(x)  # (n_samples, n_targets)

        return y
Example #26
    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.batchn1(x)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.batchn2(x)
        x = self.conv2(x, edge_index)
        x = x.relu()
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = self.batchn3(x)
        x = self.conv3(x, edge_index)
        x = x.relu()
        x = self.batchn4(x)
        x = self.conv4(x, edge_index)
        x = x.relu()
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x = self.batchn5(x)
        x = self.conv5(x, edge_index)
        x = x.relu()
        x = self.batchn6(x)
        x = self.conv6(x, edge_index)
        x = x.relu()
        # x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x = self.batchn7(x)
        x = self.conv7(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        # x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)
        
        return x
Example #27
    def forward(self, data):
        x, pos, batch = data.x, data.pos[:, :3], data.batch
        x = F.hardtanh(self.conv1(None, pos, batch))

        idx = fps(pos, batch, ratio=0.375)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.hardtanh(self.conv2(x, pos, batch))

        idx = fps(pos, batch, ratio=0.334)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        x = F.hardtanh(self.conv3(x, pos, batch))
        x = F.hardtanh(self.conv4(x, pos, batch))
        if self.pool == 'max':
            x = global_max_pool(x, batch)
        elif self.pool == 'mean':
            x = global_mean_pool(x, batch)

        x = F.hardtanh(self.lin1(x))
        x = F.hardtanh(self.lin2(x))
        x = self.lin3(x)
        return {
            'out': F.log_softmax(x, dim=-1)
        }
Example #28
    def forward(self, data):
        row, col = data.edge_index
        data.edge_attr = (data.pos[col] - data.pos[row]) / (2 * 28 *
                                                            cutoff) + 0.5

        # print(data.edge_index.shape)
        # print(data.edge_index[:, -20:])

        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        row, col = data.edge_index
        data.edge_attr = (data.pos[col] - data.pos[row]) / (2 * 28 *
                                                            cutoff) + 0.5

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        row, col = data.edge_index
        data.edge_attr = (data.pos[col] - data.pos[row]) / (2 * 28 *
                                                            cutoff) + 0.5

        data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))

        x = global_mean_pool(data.x, data.batch)
        return self.fc1(x)

        # Note: the lines below are unreachable due to the early return above.
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)
Example #29
    def forward(self, data):
        x = data.x
        x = x.long()

        x_new = torch.zeros(x.size(0), 492).to(device)  # `device` is a global in the source
        x_new[torch.arange(x.size(0)), x.view(-1)] = 1  # one-hot encode node labels

        x_1 = F.relu(self.conv1_1(x_new, data.edge_index_1))
        x_2 = F.relu(self.conv1_2(x_new, data.edge_index_2))
        x_1_r = self.mlp_1(torch.cat([x_1, x_2], dim=-1))
        x_1_r = self.bn1(x_1_r)

        x_1 = F.relu(self.conv2_1(x_1_r, data.edge_index_1))
        x_2 = F.relu(self.conv2_2(x_1_r, data.edge_index_2))
        x_2_r = self.mlp_2(torch.cat([x_1, x_2], dim=-1))
        x_2_r = self.bn2(x_2_r)

        x_1 = F.relu(self.conv3_1(x_2_r, data.edge_index_1))
        x_2 = F.relu(self.conv3_2(x_2_r, data.edge_index_2))
        x_3_r = self.mlp_3(torch.cat([x_1, x_2], dim=-1))
        x_3_r = self.bn3(x_3_r)

        x_1 = F.relu(self.conv4_1(x_3_r, data.edge_index_1))
        x_2 = F.relu(self.conv4_2(x_3_r, data.edge_index_2))
        x_4_r = self.mlp_4(torch.cat([x_1, x_2], dim=-1))
        x_4_r = self.bn4(x_4_r)

        x = torch.cat([x_1_r, x_2_r, x_3_r, x_4_r], dim=-1)
        x = global_mean_pool(x, data.batch)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x.view(-1)
Example #30
def pool_func(x, batch, mode="sum"):
    if mode == "sum":
        return global_add_pool(x, batch)
    elif mode == "mean":
        return global_mean_pool(x, batch)
    elif mode == "max":
        return global_max_pool(x, batch)
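A quick usage sketch of `pool_func`, assuming the `global_*_pool` operators
are imported from `torch_geometric.nn` as in the examples above:

import torch

x = torch.randn(7, 16)                        # 7 nodes, 16 features
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])   # 3 graphs
graph_emb = pool_func(x, batch, mode="mean")  # shape: [3, 16]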