def test_avg_pool_x():
    """Check ``avg_pool_x`` with and without an explicit ``size`` argument."""
    features = torch.Tensor([
        [1, 2],
        [3, 4],
        [5, 6],
        [7, 8],
        [9, 10],
        [11, 12],
    ])
    cluster = torch.tensor([0, 1, 0, 1, 2, 2])
    batch = torch.tensor([0, 0, 0, 0, 1, 1])

    # Default call: the result is indexable; item 0 holds the pooled
    # features and item 1 the pooled batch vector.
    result = avg_pool_x(cluster, features, batch)
    pooled, pooled_batch = result[0], result[1]
    assert pooled.tolist() == [[3, 4], [5, 6], [10, 11]]
    assert pooled_batch.tolist() == [0, 0, 1]

    # With ``size=2`` the expected output has four rows, the last one being
    # all zeros; the second return value is not checked here.
    padded, _ = avg_pool_x(cluster, features, batch, size=2)
    assert padded.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]]
Exemple #2
0
    def forward(self, data, return_hidden_feature=False):
        """Propagate over covalent and non-covalent edges, then pool.

        Args:
            data: graph batch exposing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. The object is modified in place (device move,
                symmetrized edges, added self loops).
            return_hidden_feature: when truthy (or when
                ``self.always_return_hidden_feature`` is truthy) the
                intermediate features are returned alongside the prediction.

        Returns:
            ``self.output(pool_x)`` by default; otherwise the tuple
            ``(avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x,
            output_x)``.
        """
        # Move every graph tensor to the GPU when one is available.
        if torch.cuda.is_available():
            for field in ("x", "edge_attr", "edge_index", "batch"):
                setattr(data, field, getattr(data, field).cuda())

        # The graph has to be undirected.
        edges_are_undirected = is_undirected(data.edge_index)
        if not edges_are_undirected:
            data.edge_index = to_undirected(data.edge_index)

        # Nodes must be able to propagate messages to themselves.
        if not contains_self_loops(data.edge_index):
            flat_edge_attr = data.edge_attr.view(-1)
            with_loops = add_self_loops(data.edge_index, flat_edge_attr)
            data.edge_index, data.edge_attr = with_loops

        # Derive the two edge sets from the same (edge_index, edge_attr) pair.
        cov_edge_index, cov_edge_attr = self.covalent_neighbor_threshold(
            data.edge_index, data.edge_attr)
        noncov_edge_index, noncov_edge_attr = (
            self.non_covalent_neighbor_threshold(data.edge_index,
                                                 data.edge_attr))

        # Covalent propagation feeds the non-covalent propagation.
        covalent_x = self.covalent_propagation(data.x, cov_edge_index,
                                               cov_edge_attr)
        non_covalent_x = self.non_covalent_propagation(
            covalent_x, noncov_edge_index, noncov_edge_attr)

        # Zero the rows whose input feature column 14 equals -1 (protein
        # atoms, per the original author's note) so that only the remaining
        # rows contribute to the pooled sum. This writes into
        # ``non_covalent_x`` itself, so the averaged features computed below
        # also see the zeroed rows.
        protein_rows = data.x[:, 14] == -1
        non_covalent_x[protein_rows] = 0
        pool_x = self.global_add_pool(non_covalent_x, data.batch)

        # Plain prediction path.
        wants_hidden = (return_hidden_feature
                        or self.always_return_hidden_feature)
        if not wants_hidden:
            return self.output(pool_x)

        # Hidden-feature path: per-graph averages of both propagation
        # results plus the hidden activations of the output head.
        avg_covalent_x, _ = avg_pool_x(data.batch, covalent_x, data.batch)
        avg_non_covalent_x, _ = avg_pool_x(data.batch, non_covalent_x,
                                           data.batch)
        fc0_x, fc1_x, output_x = self.output(pool_x,
                                             return_hidden_feature=True)
        return (avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x,
                output_x)
Exemple #3
0
    def forward(self, data):
        """Compute per-graph logits with two EdgeConv + graclus stages.

        Args:
            data: graph batch exposing ``x`` and ``batch``. Its ``x``,
                ``edge_index`` and ``edge_attr`` attributes are overwritten
                during the first stage.

        Returns:
            ``self.output`` applied to the mean-pooled features, with the
            last dimension squeezed.
        """
        data.x = self.inputnet(data.x)

        # Stage 1: undirected kNN graph in feature space, then EdgeConv.
        stage1_knn = knn_graph(data.x,
                               self.k,
                               data.batch,
                               loop=False,
                               flow=self.edgeconv1.flow)
        data.edge_index = to_undirected(stage1_knn)
        data.x = self.edgeconv1(data.x, data.edge_index)

        # Cluster with graclus on normalized-cut weights and average-pool
        # the whole data object; edge attributes are cleared first.
        stage1_weight = normalized_cut_2d(data.edge_index, data.x)
        stage1_cluster = graclus(data.edge_index, stage1_weight,
                                 data.x.size(0))
        data.edge_attr = None
        data = avg_pool(stage1_cluster, data)

        # Stage 2: rebuild the kNN graph on the pooled nodes, then EdgeConv.
        stage2_knn = knn_graph(data.x,
                               self.k,
                               data.batch,
                               loop=False,
                               flow=self.edgeconv2.flow)
        data.edge_index = to_undirected(stage2_knn)
        data.x = self.edgeconv2(data.x, data.edge_index)

        # Second clustering pools the features and batch vector only.
        stage2_weight = normalized_cut_2d(data.edge_index, data.x)
        stage2_cluster = graclus(data.edge_index, stage2_weight,
                                 data.x.size(0))
        pooled_x, pooled_batch = avg_pool_x(stage2_cluster, data.x,
                                            data.batch)

        # One feature vector per graph, then the output head.
        graph_x = global_mean_pool(pooled_x, pooled_batch)
        return self.output(graph_x).squeeze(-1)
Exemple #4
0
    def forward(self, sample):
        """Run the convolution stack, pool the nodes and apply the output head.

        Args:
            sample: object exposing ``x`` (node features) and ``edge_index``.
                The sample is treated as a single graph: every node gets
                batch index 0.

        Returns:
            ``self.output`` applied to the flattened (``view(-1)``) node
            features that remain after the final pooling step.

        Notes:
            ``self.final_pooling`` selects the last pooling step
            (``"avg_pool_x"``, ``"max_pool_x"``, ``"sort_pooling"``,
            ``"topk"``, ``"asap"`` or ``"sag"``). Any other value leaves the
            node features unpooled, as before.
        """
        x, edge_index = sample.x, sample.edge_index

        # The input layer is called with ``self.empty_edges`` rather than
        # the sample's own edges.
        x = F.gelu(self.dense_input(x, self.empty_edges))

        x = F.gelu(self.input(x, edge_index))
        x = F.gelu(self.conv1(x, edge_index))
        # NOTE: an optional pooling stage after conv1 (``topkpool1``) is
        # currently disabled.
        x = F.gelu(self.conv2(x, edge_index))

        if self.pooling_layers > 0:
            # Single-graph batch vector, built directly instead of looping
            # over the rows of ``x`` in Python.
            batch = torch.zeros(x.size(0), dtype=torch.long,
                                device=self.device)
            pooled = self.topkpool2(x, edge_index, batch=batch)
            x, edge_index = pooled[0], pooled[1]

        x = F.gelu(self.conv3(x, edge_index))
        # NOTE: a repeated pool + conv4 loop for large graphs is currently
        # disabled.

        batch = torch.zeros(x.size(0), dtype=torch.long, device=self.device)

        # Author's note: with sort pooling the output layer learns the order
        # of the pooled nodes; the cluster-based variants below assign node
        # ``i`` to cluster ``i % self.final_nodes`` instead.
        if self.final_pooling == "avg_pool_x":
            cluster = (torch.arange(x.size(0), device=self.device)
                       % self.final_nodes)
            x, _ = avg_pool_x(cluster, x, batch)
        elif self.final_pooling == "sort_pooling":
            x = global_sort_pool(x, batch, self.final_nodes)
        elif self.final_pooling in ("topk", "asap", "sag"):
            x = self.last_pooling_layer(x, edge_index)[0]
        elif self.final_pooling == "max_pool_x":
            cluster = (torch.arange(x.size(0), device=self.device)
                       % self.final_nodes)
            x, _ = max_pool_x(cluster, x, batch)

        return self.output(x.view(-1))
Exemple #5
0
    def forward(self, data):
        """Classify a batch of point clouds.

        Args:
            data: batch exposing ``pos`` (point positions) and ``batch``
                (per-point graph index). Any edges stored on ``data`` are
                ignored; the graph is rebuilt from ``pos``.

        Returns:
            The result of ``self.sm`` applied to the last linear stage.
        """
        pos, batch = data.pos, data.batch

        # Build the first edges: kNN graph on the raw positions.
        edge_index = knn_graph(pos, self.k, batch, loop=False)

        # Extract features in 3d; only the third return value is used.
        _, _, features_dd, _ = self.ds1(pos, edge_index, None)

        # Graclus clustering, then pool positions (mean) and features (max)
        # over the same clusters.
        cluster = graclus(edge_index)
        pos_gra, batch_gra = avg_pool_x(cluster, pos, batch)
        features_gra, _ = max_pool_x(cluster, features_dd, batch)

        # Second kNN graph in feature space, built without tracking
        # gradients. The norm is taken over dim 2 of the pooled features.
        with torch.no_grad():
            edge_index_gra = knn_graph(features_gra.norm(dim=2),
                                       self.k,
                                       batch_gra,
                                       loop=False)

        # DD2
        _, _, features_dd2, _ = self.dd2(pos_gra, edge_index_gra, features_gra)

        y1 = self.nn1(features_dd2)

        # Using the batch vector as the cluster assignment max-pools all
        # nodes of each graph into one row.
        y1_pool, _ = max_pool_x(batch_gra, y1, batch_gra)
        y1_pool = self.bn1(torch.nn.functional.relu(y1_pool))

        y2 = self.bn2(torch.nn.functional.relu(self.nn2(y1_pool)))
        y3 = self.bn3(torch.nn.functional.relu(self.nn3(y2)))

        return self.sm(self.nn4(y3))
Exemple #6
0
    def forward(self, data, return_hidden_feature=False):
        """Propagate over covalent and non-covalent edges, then pool.

        Args:
            data: graph batch exposing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``. The object is modified in place (device move,
                symmetrized edges, added self loops).
            return_hidden_feature: when truthy, the intermediate features
                are returned alongside the prediction.

        Returns:
            ``self.output(pool_x)`` by default; otherwise the tuple
            ``(avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x,
            output_x)``.
        """
        # Imported locally so the guard below does not depend on the
        # enclosing module binding the name ``torch``.
        import torch

        # Move the graph tensors to the GPU only when one is available, so
        # the method also works on CPU-only hosts instead of raising.
        if torch.cuda.is_available():
            data.x = data.x.cuda()
            data.edge_attr = data.edge_attr.cuda()
            data.edge_index = data.edge_index.cuda()
            data.batch = data.batch.cuda()

        # make sure that we have undirected graph
        if not is_undirected(data.edge_index):
            data.edge_index = to_undirected(data.edge_index)

        # make sure that nodes can propagate messages to themselves
        if not contains_self_loops(data.edge_index):
            data.edge_index, data.edge_attr = add_self_loops(
                data.edge_index, data.edge_attr.view(-1))

        # NOTE: an experimental step that kept only the k closest neighbors
        # of each node (via a dense adjacency matrix and ``torch.topk``) is
        # currently disabled.

        # Derive the covalent and non-covalent edge sets from the same
        # (edge_index, edge_attr) pair.
        covalent_edge_index, covalent_edge_attr = self.covalent_neighbor_threshold(
            data.edge_index, data.edge_attr)
        (
            non_covalent_edge_index,
            non_covalent_edge_attr,
        ) = self.non_covalent_neighbor_threshold(data.edge_index,
                                                 data.edge_attr)

        # covalent_propagation and non_covalent_propagation
        covalent_x = self.covalent_propagation(data.x, covalent_edge_index,
                                               covalent_edge_attr)
        non_covalent_x = self.non_covalent_propagation(
            covalent_x, non_covalent_edge_index, non_covalent_edge_attr)

        # Zero the rows whose input feature column 14 equals -1 (protein
        # atoms, per the original author's note), then do a ligand-only
        # gather. ``non_covalent_ligand_only_x`` is an alias, not a copy, so
        # ``non_covalent_x`` is zeroed too and the average computed below
        # sees the zeroed rows.
        # NOTE(review): confirm that aliasing is intended.
        non_covalent_ligand_only_x = non_covalent_x
        non_covalent_ligand_only_x[data.x[:, 14] == -1] = 0
        pool_x = self.global_add_pool(non_covalent_ligand_only_x, data.batch)

        # fully connected and output layers
        if return_hidden_feature:
            # return prediction and atomistic features
            # (covalent result, non-covalent result, pool result)
            avg_covalent_x, _ = avg_pool_x(data.batch, covalent_x, data.batch)
            avg_non_covalent_x, _ = avg_pool_x(data.batch, non_covalent_x,
                                               data.batch)

            fc0_x, fc1_x, output_x = self.output(pool_x,
                                                 return_hidden_feature=True)

            return avg_covalent_x, avg_non_covalent_x, pool_x, fc0_x, fc1_x, output_x

        return self.output(pool_x)