Example #1
    def forward(self, x, batch: Optional[torch.Tensor] = None):
        x = self.datanorm * x
        x = self.inputnet(x)
        
        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv1.flow))
        x = self.edgeconv1(x, edge_index)        
        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        edge_attr = None
        x, edge_index, batch, edge_attr = max_pool(cluster, x, edge_index, batch)

        # Additional layer by Shamik
        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv3.flow))
        x = self.edgeconv3(x, edge_index)
        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        edge_attr = None
        x, edge_index, batch, edge_attr = max_pool(cluster, x, edge_index, batch)
        
        edge_index = to_undirected(knn_graph(x, self.k, batch, loop=False, flow=self.edgeconv2.flow))
        x = self.edgeconv2(x, edge_index)
        
        weight = normalized_cut_2d(edge_index, x)
        cluster = graclus(edge_index, weight, x.size(0))
        x, batch = max_pool_x(cluster, x, batch)

        if batch is not None:
            x = global_max_pool(x, batch)
        
        return self.output(x).squeeze(-1)
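This forward relies on members built in an `__init__` the example does not show (`datanorm`, `inputnet`, `edgeconv1`-`edgeconv3`, `output`, `k`). A minimal sketch of a compatible constructor; every layer size here is an illustrative assumption, not taken from the source:

import torch
import torch.nn as nn
from torch_geometric.nn import EdgeConv

class Net(nn.Module):
    def __init__(self, input_dim=4, hidden_dim=64, k=16):
        super().__init__()
        self.k = k
        # Fixed per-feature scaling applied before the input MLP.
        self.datanorm = nn.Parameter(torch.ones(input_dim), requires_grad=False)
        self.inputnet = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ELU())

        def make_edgeconv():
            # EdgeConv's inner MLP sees concatenated (x_i, x_j - x_i) features.
            return EdgeConv(nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim), nn.ELU()), aggr='max')

        self.edgeconv1 = make_edgeconv()
        self.edgeconv2 = make_edgeconv()
        self.edgeconv3 = make_edgeconv()
        self.output = nn.Linear(hidden_dim, 1)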
Example #2
    def forward(self, x, pos, batch=None):

        # add dummy features in case there is none
        if x is None:
            x = torch.ones((pos.shape[0], 1), device=pos.device)

        # first block
        x = self.mlp_input(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_input(x, pos, edge_index)

        # backbone
        for i in range(len(self.transformers_down)):
            x, pos, batch = self.transition_down[i](x, pos, batch=batch)

            edge_index = knn_graph(pos, k=self.k, batch=batch)
            x = self.transformers_down[i](x, pos, edge_index)

        # GlobalAveragePooling
        x = global_mean_pool(x, batch)

        # Class score
        out = self.mlp_output(x)

        return F.log_softmax(out, dim=-1)
Example #3
    def forward(self, x, ret_activations=False, relu_activations=False):
        batch_size = x.size(0)
        x = x.reshape(batch_size * self.num_hits, self.node_feat_size)
        zeros = torch.zeros(batch_size * self.num_hits,
                            dtype=int).to(self.device)
        zeros[torch.arange(batch_size) * self.num_hits] = 1
        batch = torch.cumsum(zeros, 0) - 1

        for i in range(self.num_edge_convs):
            edge_index = knn_graph(
                x[:, :2], self.k, batch) if i == 0 else knn_graph(
                    x, self.k, batch
                )  # using only angular coords for knn in first edgeconv block
            x = torch.cat(
                (self.edge_convs[i](x, edge_index), x), dim=1
            )  # concatenating with original features i.e. skip connection

        x = global_mean_pool(x, batch)
        x = self.fc1(x)

        if ret_activations:
            if relu_activations:
                return F.relu(x)
            return x  # for Frechet ParticleNet Distance

        x = self.dropout_layer(F.relu(x))

        return self.fc2(
            x
        )  # no softmax because pytorch cross entropy loss includes softmax
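The closing comment is worth spelling out: `fc2` returns raw logits because `nn.CrossEntropyLoss` applies `log_softmax` internally. A minimal sketch of the intended pairing (shapes are illustrative assumptions):

import torch
import torch.nn as nn

logits = torch.randn(8, 5)                     # raw fc2 output: (batch, classes)
targets = torch.randint(0, 5, (8,))            # integer class labels
loss = nn.CrossEntropyLoss()(logits, targets)  # softmax happens inside the loss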
Example #4
    def forward(self, data):
        data.x = self.datanorm * data.x
        data.x = self.inputnet(data.x)

        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data)

        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_max_pool(x, batch)

        return self.output(x).squeeze(-1)
Example #5
    def forward(self, x, pos, batch=None):

        # add dummy features in case there is none
        if x is None:
            x = torch.ones((pos.shape[0], 1), device=pos.device)

        out_x = []
        out_pos = []
        out_batch = []

        # first block
        x = self.mlp_input(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_input(x, pos, edge_index)

        # save outputs for skip connections
        out_x.append(x)
        out_pos.append(pos)
        out_batch.append(batch)

        # backbone down: reduce cardinality and augment dimensionality
        for i in range(len(self.transformers_down)):
            x, pos, batch = self.transition_down[i](x, pos, batch=batch)
            edge_index = knn_graph(pos, k=self.k, batch=batch)
            x = self.transformers_down[i](x, pos, edge_index)

            out_x.append(x)
            out_pos.append(pos)
            out_batch.append(batch)

        # summit
        x = self.mlp_summit(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_summit(x, pos, edge_index)

        # backbone up: augment cardinality and reduce dimensionality
        n = len(self.transformers_down)
        for i in range(n):
            x = self.transition_up[-i - 1](x=out_x[-i - 2],
                                           x_sub=x,
                                           pos=out_pos[-i - 2],
                                           pos_sub=out_pos[-i - 1],
                                           batch_sub=out_batch[-i - 1],
                                           batch=out_batch[-i - 2])

            edge_index = knn_graph(out_pos[-i - 2],
                                   k=self.k,
                                   batch=out_batch[-i - 2])
            x = self.transformers_up[-i - 1](x, out_pos[-i - 2], edge_index)

        # Class score
        out = self.mlp_output(x)

        return F.log_softmax(out, dim=-1)
Example #6
def test_knn_graph(dtype, device):
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    edge_index = knn_graph(x, k=2, flow='target_to_source')
    assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
                                      (2, 3), (3, 0), (3, 2)])

    edge_index = knn_graph(x, k=2, flow='source_to_target')
    assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
                                      (3, 2), (0, 3), (2, 3)])
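These tests use two helpers from the surrounding test module that the excerpt omits. Plausible definitions (assumptions, not copied from the source):

import torch

def tensor(x, dtype, device):
    # Build a tensor with the parametrized dtype/device, passing None through.
    return None if x is None else torch.tensor(x, dtype=dtype, device=device)

def to_set(edge_index):
    # Order-independent view of a (2, E) edge index for comparisons.
    return set((i, j) for i, j in edge_index.t().tolist())

Note how the two asserts encode that the two `flow` conventions return the same edge set with the rows swapped.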
Example #7
    def forward(self, data):
        k = self.k
        device = self.device
        mode = self.mode
        pos_idx = self.pos_idx
        # changing x dtype to float; change back after saving graphs properly
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)

        a = self.conv_add(x, edge_index)

        edge_index = knn_graph(x=a[:, pos_idx], k=k, batch=batch).to(device)
        # check if this recalculation of edge indices is correct; maybe it can be done over all of x
        b = self.conv_add2(a, edge_index)

        edge_index = knn_graph(x=b[:, pos_idx], k=k, batch=batch).to(device)

        c = self.conv_add3(b, edge_index)

        edge_index = knn_graph(x=c[:, pos_idx], k=k, batch=batch).to(device)

        d = self.conv_add4(c, edge_index)

        x = torch.cat((x, a, b, c, d), dim=1)
        del a, b, c, d
        x = self.nn1(x)
        x = self.relu(x)
        x = self.nn2(x)

        a, _ = scatter_max(x, batch, dim=0)
        b, _ = scatter_min(x, batch, dim=0)
        c = scatter_sum(x, batch, dim=0)
        d = scatter_mean(x, batch, dim=0)
        x = torch.cat((a, b, c, d), dim=1)
        
        x = self.relu(x)
        x = self.nn3(x)
        
        x = self.relu(x)
        x = self.nn4(x)
        
        if mode == 'angle':
            x[:,0] = self.tanh(x[:,0])
            x[:,1] = self.tanh(x[:,1])
        

        return x
Example #8
    def forward(self, x, batch=None):
        spatial = self.lin_s(x)
        to_propagate = self.lin_flr(x)

        if self.neighbor_algo == "knn":
            edge_index = knn_graph(spatial,
                                   self.k,
                                   batch,
                                   loop=False,
                                   flow=self.flow,
                                   cosine=False)
        elif self.neighbor_algo == "radius":
            edge_index = radius_graph(spatial,
                                      self.radius,
                                      batch,
                                      loop=False,
                                      flow=self.flow,
                                      max_num_neighbors=self.k)
        else:
            raise ValueError("Unknown neighbor algo {}".format(
                self.neighbor_algo))

        reference = spatial.index_select(0, edge_index[1])
        neighbors = spatial.index_select(0, edge_index[0])

        distancessq = torch.sum((reference - neighbors)**2, dim=-1)
        # Factor 10 gives a better initial spread
        distance_weight = torch.exp(-10. * distancessq)

        prop_feat = self.propagate(edge_index,
                                   x=to_propagate,
                                   edge_weight=distance_weight)

        return edge_index, self.lin_fout(torch.cat([prop_feat, x], dim=-1))
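The weighting decays exponentially with squared distance in the learned space, so only close neighbors contribute meaningfully to the aggregation. A toy illustration of the kernel values (the factor 10 comes from the code above):

import torch

d2 = torch.tensor([0.0, 0.1, 0.5])  # squared distances in the learned space
print(torch.exp(-10. * d2))         # tensor([1.0000, 0.3679, 0.0067])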
Example #9
    def forward(self, x, pos, batch=None):
        """"""
        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (N, D), K = pos.size(), self.kernel_size

        row, col = knn_graph(pos, K * self.dilation, batch, loop=True)

        if self.dilation > 1:
            dil = self.dilation
            index = torch.randint(
                K * dil, (N, K), dtype=torch.long, device=row.device)
            arange = torch.arange(N, dtype=torch.long, device=row.device)
            arange = arange * (K * dil)
            index = (index + arange.view(-1, 1)).view(-1)
            row, col = row[index], col[index]

        pos = pos[col] - pos[row]

        x_star = self.mlp1(pos.view(N * K, D))
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        x_star = x_star.transpose(1, 2).contiguous()
        x_star = x_star.view(N, self.in_channels + self.hidden_channels, K, 1)

        transform_matrix = self.mlp2(pos.view(N, K * D))
        transform_matrix = transform_matrix.view(N, 1, K, K)

        x_transformed = torch.matmul(transform_matrix, x_star)
        x_transformed = x_transformed.view(N, -1, K)

        out = self.conv(x_transformed)

        return out
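The `dilation > 1` branch keeps K random edges out of each node's K * dilation candidates, relying on `knn_graph` returning edges grouped by node. A small self-contained check of that indexing trick (names and sizes are illustrative):

import torch

N, K, dil = 4, 2, 3
row = torch.arange(N).repeat_interleave(K * dil)          # stand-in edge rows
index = torch.randint(K * dil, (N, K), dtype=torch.long)  # K picks per node
index = (index + torch.arange(N).view(-1, 1) * (K * dil)).view(-1)
print(row[index].view(N, K))  # row i contains only edges belonging to node i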
Example #10
    def test_cluster(self):
        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = torch_cluster.knn_graph(x, k=2, batch=batch, loop=False)
        test_edge_index = torch.LongTensor([[2, 1, 3, 0, 3, 0, 1, 2],
                                            [0, 0, 1, 1, 2, 2, 3, 3]])
        self.assertTrue(torch.all(torch.eq(test_edge_index, edge_index)))
Example #11
    def forward(self, x: Tensor, pos: Tensor, batch: Optional[Tensor] = None):
        """"""
        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (N, D), K = pos.size(), self.kernel_size

        edge_index = knn_graph(pos,
                               K * self.dilation,
                               batch,
                               loop=True,
                               flow='target_to_source',
                               num_workers=self.num_workers)

        if self.dilation > 1:
            edge_index = edge_index[:, ::self.dilation]

        row, col = edge_index[0], edge_index[1]

        pos = pos[col] - pos[row]

        x_star = self.mlp1(pos)
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        x_star = x_star.transpose(1, 2).contiguous()

        transform_matrix = self.mlp2(pos.view(N, K * D))

        x_transformed = torch.matmul(x_star, transform_matrix)

        out = self.conv(x_transformed)

        return out
Example #12
def get_graph_feature(x, k, batch=None):
    batch_size = batch.max() + 1 if batch is not None else 1
    # knn
    edges = knn_graph(x, k, batch=batch)
    x = torch.cat([x[edges[1]] - x[edges[0]], x[edges[0]]], dim=1)
    x = x.view(batch_size, -1, x.size(1))  # assumes equal-sized graphs in the batch

    return x.permute(0, 2, 1).contiguous()
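A usage sketch with made-up shapes; note that the final reshape only works when every graph in the batch has the same number of nodes:

import torch
from torch_cluster import knn_graph

pts = torch.randn(10, 3)                 # two clouds of 5 points each
batch = torch.tensor([0] * 5 + [1] * 5)
feat = get_graph_feature(pts, k=3, batch=batch)
print(feat.shape)  # torch.Size([2, 6, 15]): (batch, 2 * dims, points * k)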
Example #13
def knn(x, k=2, loop=False, dtype=None, device=None):
    N, D = x.shape
    batch = torch.zeros(N, dtype=torch.long, device=x.device)
    edge_index = knn_graph(x, k, batch=batch, loop=loop).to(device)
    edge_val = torch.ones(edge_index.shape[-1], dtype=dtype, device=device)
    return SparseTensor(
        row=edge_index[0], col=edge_index[1], value=edge_val, sparse_sizes=(N, N)
    )
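A quick usage sketch, assuming `SparseTensor` is imported from `torch_sparse` as in the snippet's module:

import torch

adj = knn(torch.randn(4, 2), k=2)
print(adj.sizes(), adj.nnz())  # [4, 4] 8: one stored value per directed kNN edge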
Example #14
def test_knn_graph(dtype, device):
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    row, col = knn_graph(x, k=2, flow='target_to_source')
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = knn_graph(x, k=2, flow='source_to_target')
    row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
    assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
Example #15
    def forward(self, batch):

        edge_index = knn_graph(batch.pos,
                               k=self.k,
                               batch=batch.batch,
                               loop=False)
        batch.edge_index = edge_index

        return batch
Example #16
    def forward(self, x, batch=None):
        """"""
        edge_index = knn_graph(x,
                               self.k,
                               batch,
                               loop=False,
                               flow=self.flow,
                               cosine=True)
        return super(DynamicEdgeConv, self).forward(x, edge_index)
Example #17
def test_knn_graph_large(dtype, device):
    x = torch.randn(1000, 3, dtype=dtype, device=device)

    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)

    tree = scipy.spatial.cKDTree(x.cpu().numpy())
    _, col = tree.query(x.cpu(), k=5)
    truth = set([(i, j) for i, ns in enumerate(col) for j in ns])

    assert to_set(edge_index.cpu()) == truth
Example #18
    def forward(self, pos, batch):
        edge_index = knn_graph(pos, k=16, batch=batch, loop=True)
        h = self.conv1(h=pos, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv2(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()

        h = global_max_pool(h, batch)
        h = self.classifier(h)
        y = torch.sigmoid(h)
        return y
Example #19
    def forward(self, data):
        x = data.x
        if self.k_graph:
            data.edge_index = knn_graph(data.x, self.k_graph)
        x1 = F.elu(self.conv1(x, data.edge_index))
        # x = F.dropout(x, p=0.6, training=self.training)
        x2 = self.conv2(x1, data.edge_index)
        # x = self.conv3(x, data.edge_index)
        x = self.pool(torch.cat([x1, x2], dim=1), data.batch)
        x = self.mlp(x)
        return x
Example #20
    def forward(self, data):
        # device = self.device
        # mode   = self.mode
        k = self.k
        device = self.device
        pos_idx = self.pos_idx
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)
        x = self.GGconv1(x, edge_index)
        x = self.relu(x)

        x = self.nn1(x)
        x = self.relu(x)

        y = self.resblock1(x)
        x = x + y

        z = self.resblock2(x)
        x = x + z

        del y, z

        x = self.nn2(x)
        x = self.relu(x)

        x = self.GGconv2(x, edge_index)
        x = self.relu(x)

        p = self.resblock3(x)
        x = x + p

        o = self.resblock4(x)
        x = x + o
        del p, o

        x = self.nn3(x)
        x = self.relu(x)

        a, _ = scatter_max(x, batch, dim=0)
        b, _ = scatter_min(x, batch, dim=0)
        c = scatter_sum(x, batch, dim=0)
        d = scatter_mean(x, batch, dim=0)
        x = torch.cat((a, b, c, d), dim=1)
        # print ("cat size",x.size())
        del a, b, c, d

        x = self.nncat(x)
        x = self.relu(x)
        # if(torch.sum(torch.isnan(x)) != 0):
        # print('NAN ENCOUNTERED AT NN2')

        # print ("xsize %s batchsize %s a size %s b size %s y size %s end forward" %(x.size(),batch.size(),a.size(),b.size(),data.y[:,0].size()))
        return x
Example #21
def test_knn_graph(dtype, device):
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    row, col = knn_graph(x, k=2)
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)

    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
Example #22
    def forward(self, data):
        pos, batch, eidx = data.pos, data.batch, data.edge_index
        x1 = self.conv1(pos, eidx)
        x2 = self.conv2(x1, batch)
        x = self.lin1(torch.cat([x1, x2], dim=1))
        x = global_max_pool(x, batch)
        x = self.mlp(x)
        out_knn = knn_graph(x, self.k_global + 1, batch=None, loop=True)[0]
        # assuming k_global < min streamline length
        out_knn = x[out_knn.view(-1, self.k_global + 1)].mean(1)
        # pseudo_class = F.log_softmax(out_knn)
        out = self.lin2(out_knn)
        return out
Example #23
    def forward(self, data):
        # Use the coords for the first knn step
        print('data.x:', data.x.size())
        print('data.batch:', data.batch.size())
        clustering1 = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv1.flow))
        print('clustering1:', clustering1.size())
        out1 = self.edgeconv1(data.features, clustering1)
        print('out1:', out1.size())

        raise Exception('stop')

        # Now use the outputted features of the previous layer for the knn
        clustering2 = to_undirected(
            knn_graph(out1,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv2.flow))
        out2 = self.edgeconv2(out1, clustering2)

        clustering3 = to_undirected(
            knn_graph(out2,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv3.flow))
        out3 = self.edgeconv3(out2, clustering3)

        # Cat all outputs together
        edgeconv_out = torch.cat([data.features, out1, out2, out3], dim=-1)

        # Run the output layer
        return self.output(edgeconv_out).squeeze(-1)
Example #24
    def load_edges(self, idx, nhits):
        if self.distance_weighted:
            distEdgeTensor = self.dist_pos_matrix(idx, nhits)
            self.edge_index = distEdgeTensor[0]
            self.edge_attr = distEdgeTensor[1]

        else:
            if self.fully_connected:
                edge_index = torch.ones([nhits, nhits], dtype=torch.int64)
                self.edge_index = edge_index.to_sparse()._indices()
            else:
                pos = torch.as_tensor(self.event_data[idx, :nhits, 2:5],
                                      dtype=torch.float)
                self.edge_index = knn_graph(pos, k=self.k_neighbours)
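One subtlety in the fully connected branch: `torch.ones(nhits, nhits).to_sparse()._indices()` keeps the diagonal, so that graph contains self-loops. If those are unwanted (an assumption, not something the source does), they can be masked out:

import torch

nhits = 3
edge_index = torch.ones([nhits, nhits], dtype=torch.int64).to_sparse()._indices()
mask = edge_index[0] != edge_index[1]  # drop diagonal (self-loop) entries
print(edge_index[:, mask].shape)       # torch.Size([2, 6]) == nhits * (nhits - 1)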
Example #25
    def forward(self, x, pos, batch=None, edge_index=None):
        if edge_index is None:
            edge_index = knn_graph(x, self.k, batch, loop=False)
            edge_index = edge_index.to(x.device)

        if self.pool:
            new_adj, new_feat, new_pos, new_batch, index, values, origsize, newsize = mgpool(x, pos, edge_index, batch)
            return self.layers(new_feat, new_pos, new_adj, new_batch, self.k), new_pos, new_batch, (
                index, values, origsize, newsize)
        else:
            new_pos = pos
            new_batch = batch
            new_feat = x
            new_adj = edge_index

        return self.layers(new_feat, new_pos, new_adj, new_batch, self.k), new_pos, new_batch
Example #26
    def test_cluster(self):
        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).to(device)
        batch = torch.tensor([0, 0, 0, 0]).to(device)
        edge_index = torch_cluster.knn_graph(x, k=2, batch=batch,
                                             loop=False).to(device)
        test_edge_index = torch.LongTensor([[2, 1, 3, 0, 3, 0, 1, 2],
                                            [0, 0, 1, 1, 2, 2, 3, 3]]).to(device)
        edge_list = edge_index.tolist()
        test_edge_list = test_edge_index.tolist()
        del edge_index, test_edge_index
        # need to transpose the edges to (ei, ej) format
        edge_list = [(edge_list[0][i], edge_list[1][i])
                     for i in range(len(edge_list[0]))]
        test_edge_list = [(test_edge_list[0][i], test_edge_list[1][i])
                          for i in range(len(test_edge_list[0]))]
        self.assertCountEqual(edge_list, test_edge_list)
Example #27
    def _featurize_as_graph(self, protein):
        name = protein['name']
        with torch.no_grad():
            coords = torch.as_tensor(protein['coords'],
                                     device=self.device,
                                     dtype=torch.float32)
            seq = torch.as_tensor(
                [self.letter_to_num[a] for a in protein['seq']],
                device=self.device,
                dtype=torch.long)

            mask = torch.isfinite(coords.sum(dim=(1, 2)))
            coords[~mask] = np.inf

            X_ca = coords[:, 1]
            edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)

            pos_embeddings = self._positional_embeddings(edge_index)
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
            rbf = _rbf(E_vectors.norm(dim=-1),
                       D_count=self.num_rbf,
                       device=self.device)

            dihedrals = self._dihedrals(coords)
            orientations = self._orientations(X_ca)
            sidechains = self._sidechains(coords)

            node_s = dihedrals
            node_v = torch.cat(
                [orientations, sidechains.unsqueeze(-2)], dim=-2)
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
            edge_v = _normalize(E_vectors).unsqueeze(-2)

            node_s, node_v, edge_s, edge_v = map(
                torch.nan_to_num, (node_s, node_v, edge_s, edge_v))

        data = torch_geometric.data.Data(x=X_ca,
                                         seq=seq,
                                         name=name,
                                         node_s=node_s,
                                         node_v=node_v,
                                         edge_s=edge_s,
                                         edge_v=edge_v,
                                         edge_index=edge_index,
                                         mask=mask)
        return data
Example #28
    def forward(self, x, labels=None, epoch=None):
        x = F.leaky_relu(self.dense(x), negative_slope=self.args.leaky_relu_alpha)

        batch_size = x.size(0)
        x = x.reshape(batch_size * self.args.num_hits, self.args.graphcnng_layers[0])
        zeros = torch.zeros(batch_size * self.args.num_hits, dtype=int).to(self.args.device)
        zeros[torch.arange(batch_size) * self.args.num_hits] = 1
        batch = torch.cumsum(zeros, 0) - 1

        for i in range(len(self.layers)):
            edge_index = knn_graph(x, self.args.num_knn, batch)
            edge_attr = x[edge_index[0]] - x[edge_index[1]]
            x = self.bn_layers[i](self.layers[i](x, edge_index, edge_attr))
            if i < (len(self.layers) - 1):
                x = F.leaky_relu(x, negative_slope=self.args.leaky_relu_alpha)

        if self.args.graphcnng_tanh:
            x = torch.tanh(x)

        return x.reshape(batch_size, self.args.num_hits, self.args.node_feat_size)
Example #29
    def forward(self, pos, batch):
        # Compute the kNN graph:
        # Here, we need to pass the batch vector to the function call in order
        # to prevent creating edges between points of different examples.
        # We also add `loop=True` which will add self-loops to the graph in
        # order to preserve central point information.
        edge_index = knn_graph(pos, k=16, batch=batch, loop=True)

        # Start bipartite message passing.
        h = self.conv1(h=pos, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv2(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv3(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()
        out = self.classifier(h)

        return out
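A self-contained check of the `loop=True` behaviour the comments describe, namely that each point is returned as one of its own neighbors:

import torch
from torch_cluster import knn_graph

pos = torch.randn(8, 3)
edge_index = knn_graph(pos, k=3, loop=True)
print(bool((edge_index[0] == edge_index[1]).any()))  # True: self-loops present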
Example #30
    def forward(self, x, batch=None):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        row, col = knn_graph(x, self.k, batch, loop=False)
        x_row, x_col = x.index_select(0, row), x.index_select(0, col)
        out = torch.cat([x_row, x_col - x_row], dim=1)
        out = self.nn(out)
        out = out.view(-1, self.k, out.size(-1))

        if self.aggr == 'add':
            out = out.sum(dim=1)
        elif self.aggr == 'mean':
            out = out.mean(dim=1)
        else:
            out = out.max(dim=1)[0]

        return out