def forward(self, data):
        """Classify a batch of point clouds with a two-stage PointNet++ encoder.

        Each set-abstraction stage samples centroids with farthest point
        sampling, links every centroid to at most 64 input points inside a
        ball, and applies a bipartite point convolution
        (``local_sa1`` then ``local_sa2``). A global MLP, a max over 128 points
        per cloud and a three-layer MLP head follow.

        Returns:
            Per-class log-probabilities (``log_softmax`` over the last dim).
        """
        pos, batch = data.pos, data.batch

        # (FPS ratio, ball radius, conv) per stage; the original comments say
        # this goes to 512 and then 128 points (presumably from 1024-point clouds).
        stages = ((0.5, 0.2, self.local_sa1), (0.25, 0.4, self.local_sa2))
        x = None
        for ratio, ball_r, sa_conv in stages:
            centroids = fps(pos, batch, ratio=ratio)
            sampled_pos = pos[centroids]
            sampled_batch = batch[centroids]
            # radius() yields (index into sampled set, index into full set);
            # the conv wants source -> target, i.e. neighbour -> centroid.
            tgt, src = radius(pos, sampled_pos, ball_r, batch, sampled_batch,
                              max_num_neighbors=64)
            x = F.relu(sa_conv(x, (pos, sampled_pos),
                               torch.stack([src, tgt], dim=0)))
            pos, batch = sampled_pos, sampled_batch

        x = self.global_sa(torch.cat([x, pos], dim=1))
        # Assumes every cloud contributes exactly 128 points at this stage.
        x = x.view(-1, 128, self.lin1.in_features).max(dim=1).values

        for hidden in (self.lin1, self.lin2):
            x = F.dropout(F.relu(hidden(x)), p=0.5, training=self.training)
        return F.log_softmax(self.lin3(x), dim=-1)
Esempio n. 2
0
    def forward(self, data):
        """Classify point clouds with PointNet++ on the un-subsampled graph.

        Each stage samples centroids with FPS and links input points to
        centroids within a ball (at most 64 neighbours). Centroid ids are
        mapped back to full-cloud node ids, so the convolution runs over all
        points and the centroid rows are selected afterwards.

        Returns:
            Per-class log-probabilities (``log_softmax`` over the last dim).
        """
        pos, batch = data.pos, data.batch

        x = None
        for ratio, ball_r, sa_conv in ((0.5, 0.1, self.local_sa1),
                                       (0.25, 0.2, self.local_sa2)):
            keep = fps(pos, batch, ratio=ratio)
            # radius(x=pos[keep], y=pos) -> (index into pos, index into pos[keep]).
            point_ids, centroid_ids = radius(pos[keep], pos, ball_r,
                                             batch[keep], batch,
                                             max_num_neighbors=64)
            # Translate centroid ids back into full-cloud node ids.
            edges = torch.stack([point_ids, keep[centroid_ids]], dim=0)
            x = F.relu(sa_conv(x, pos, edges))
            x, pos, batch = x[keep], pos[keep], batch[keep]

        x = self.global_sa(torch.cat([x, pos], dim=1))
        # Fixed layout: 128 points per cloud, 1024 global features.
        x = x.view(-1, 128, 1024).max(dim=1).values

        for hidden in (self.lin1, self.lin2):
            x = F.dropout(F.relu(hidden(x)), p=0.5, training=self.training)
        return F.log_softmax(self.lin3(x), dim=-1)
Esempio n. 3
0
    def forward(self, data):
        """Classify point clouds with graph convs followed by one down-sampling step.

        Node input features are ``data.norm`` concatenated with ``data.pos``.
        Two graph convolutions run on ``data.edge_index`` with unit edge
        weights. ``self.con_int`` then aggregates features onto
        farthest-point-sampled centroids (ball radius 0.4, at most 64
        neighbours). The graph is passed through ``self.filter_adj`` with the
        centroid indices, and a third graph convolution plus
        ``global_max_pool`` produce the output.

        Returns:
            ``(out, critical_points)``. ``out`` is ``log_softmax`` over dim 1.
            ``critical_points`` is the second value returned by this project's
            ``global_max_pool`` (presumably arg-max indices — confirm).
        """
        # Per-point input features: normals followed by coordinates.
        node_in = torch.cat([data.norm, data.pos], dim=1)
        x, batch = node_in, data.batch

        edge_index = data.edge_index
        # data.edge_attr is deliberately not used: every edge gets weight 1.
        edge_weight = torch.ones((edge_index.size(1),), dtype=x.dtype,
                                 device=edge_index.device)

        # First conv with full points.
        x = F.dropout(x, training=self.training, p=0.2)
        x = F.relu(self.conv1(x, edge_index, edge_weight))

        # Second conv with full points; the raw input features are re-injected.
        x = torch.cat([x, node_in], dim=1)
        x = F.dropout(x, training=self.training, p=0.2)
        x = F.relu(self.conv2(x, edge_index, edge_weight))

        # First down-sampling: FPS centroids, then a bipartite aggregation of
        # each centroid's ball (neighbour -> centroid edges).
        idx = fps(data.pos, batch, ratio=0.5)
        row, col = radius(data.pos, data.pos[idx], 0.4, batch, batch[idx],
                          max_num_neighbors=64)
        edge_index_int = torch.stack([col, row], dim=0)
        x = self.con_int(x, (data.pos, data.pos[idx]), edge_index_int)

        batch = batch[idx]

        # Presumably keeps only edges among the sampled nodes — see filter_adj.
        edge_index, edge_weight = self.filter_adj(edge_index, edge_weight, idx,
                                                  data.pos.size(0))

        x = torch.cat([x, node_in[idx]], dim=1)
        x = F.relu(self.conv3(x, edge_index, edge_weight))

        out, critical_points = global_max_pool(x, batch)
        out = self.lin1(out)
        out = F.log_softmax(out, dim=1)
        return out, critical_points
Esempio n. 4
0
    def forward(self, x, pos, batch, norm=None):
        """Down-sample a point set with FPS and convolve features onto the centroids.

        Args:
            x: Input node features, passed to ``self.conv``.
            pos: Node positions.
            batch: Batch assignment vector.
            norm: Node normals; only used by the ``'PPFConv'`` branch.

        Returns:
            ``((x, pos, batch), idx)``. ``x``, ``pos`` and ``batch`` are
            restricted to the FPS centroids, and ``idx`` holds the centroid
            indices into the input.

        Raises:
            ValueError: if ``self.conv_name`` is not ``'PointConv'``,
                ``'GraphConv'`` or ``'PPFConv'``.
        """
        # Pool points with FPS, keeping about Npt * ratio centroids.
        idx = fps(pos, batch, ratio=self.ratio)

        # For each centroid, find points within ``self.r`` (at most ``self.K``).
        # ``row`` indexes the centroids (0..len(idx)-1); ``col`` indexes ``pos``.
        row, col = radius(pos,
                          pos[idx],
                          self.r,
                          batch,
                          batch[idx],
                          max_num_neighbors=self.K)

        if self.conv_name == 'PointConv':
            # Bipartite conv: targets live in centroid index space.
            edge_index = torch.stack([col, row], dim=0)
            x = self.conv(x, (pos, pos[idx]), edge_index)
        elif self.conv_name in ('GraphConv', 'PPFConv'):
            # These convs run on the full point set, so centroid targets must be
            # mapped back to their original node ids. Otherwise the rows picked
            # by ``[idx]`` would not be the rows that received the messages.
            full_edge_index = torch.stack([col, idx[row]], dim=0)
            if self.conv_name == 'GraphConv':
                x = self.conv(x, full_edge_index)[idx]
            else:
                x = self.conv(x, pos, norm, full_edge_index)[idx]
        else:
            # Previously an unknown name left x un-subsampled and misaligned
            # with the returned pos/batch.
            raise ValueError(f"Unsupported conv_name: {self.conv_name!r}")

        pos, batch = pos[idx], batch[idx]
        return (x, pos, batch), idx
Esempio n. 5
0
 def forward(self, x, pos, batch):
     """Set-abstraction step: FPS down-sampling followed by a ball-query conv.

     Returns:
         ``(x, pos, batch)`` for the sampled centroids.
     """
     keep = fps(pos, batch, ratio=self.ratio)
     sub_pos, sub_batch = pos[keep], batch[keep]
     # radius() gives (centroid index, neighbour index); the conv expects
     # source -> target, i.e. neighbour -> centroid.
     ctr, nbr = radius(pos, sub_pos, self.r, batch, sub_batch,
                       max_num_neighbors=64)
     out = self.conv(x, (pos, sub_pos), torch.stack([nbr, ctr], dim=0))
     return out, sub_pos, sub_batch
Esempio n. 6
0
    def forward(self, points, features, batch):
        """Encode local neighbourhoods around farthest-point-sampled anchors.

        About ``1 / self.nb_neighbors`` of the points are kept as anchors.
        Every point within ``self.radius`` of an anchor contributes its offset
        (scaled by the radius), the anchor's neighbourhood encoding and its own
        features. These go through ``fc1``, are max-pooled per anchor and pass
        through ``fc1_global``.

        Returns:
            ``(anchor_points, features, anchor_batch)``. The features
            concatenate the neighbourhood encoding with the pooled global
            features.
        """
        anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
        anchors = points[anchor_idx]
        anchor_batch = batch[anchor_idx]

        # (anchor id, neighbour id) pairs within self.radius.
        cluster, neighbour = gnn.radius(x=points,
                                        y=anchors,
                                        batch_x=batch,
                                        batch_y=anchor_batch,
                                        r=self.radius)

        # Neighbour offsets from their anchor, normalised by the ball radius.
        offsets = (points[neighbour] - anchors[cluster]) / self.radius
        encoded = self.neighborhood_enc(offsets, cluster)

        per_edge = torch.cat([offsets, encoded[cluster], features[neighbour]],
                             dim=1)
        pooled = gnn.global_max_pool(x=F.relu(self.fc1(per_edge)),
                                     batch=cluster)
        global_feat = F.relu(self.fc1_global(pooled))

        return anchors, torch.cat([encoded, global_feat], dim=1), anchor_batch
Esempio n. 7
0
    def forward(self, points, batch):
        """Encode the neighbourhood of each farthest-point-sampled anchor.

        About ``1 / self.nb_neighbors`` of the points are kept as anchors.
        Offsets of all points within ``self.radius`` of an anchor, divided by
        the radius, are fed to ``self.neighborhood_encoder`` together with
        their anchor ids.

        Returns:
            ``(anchor_points, features, anchor_batch)``.
        """
        anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
        anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

        cluster, neighbour = gnn.radius(x=points,
                                        y=anchors,
                                        batch_x=batch,
                                        batch_y=anchor_batch,
                                        r=self.radius)
        offsets = (points[neighbour] - anchors[cluster]) / self.radius

        return anchors, self.neighborhood_encoder(offsets, cluster), anchor_batch
Esempio n. 8
0
    def forward(self, points, batch):
        """PointNet-style encoder over ball neighbourhoods of FPS anchors.

        Normalised offsets of each anchor's neighbours go through three ReLU
        layers (``fc1``..``fc3``). The result is max-pooled per anchor and then
        passes through three more ReLU layers
        (``fc1_global``..``fc3_global``).

        Returns:
            ``(anchor_points, anchor_features, anchor_batch)``.
        """
        anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
        anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

        cluster, neighbour = gnn.radius(x=points,
                                        y=anchors,
                                        batch_x=batch,
                                        batch_y=anchor_batch,
                                        r=self.radius)
        h = (points[neighbour] - anchors[cluster]) / self.radius

        # Per-neighbour MLP.
        for layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(layer(h))

        h = gnn.global_max_pool(x=h, batch=cluster)

        # Per-anchor MLP.
        for layer in (self.fc1_global, self.fc2_global, self.fc3_global):
            h = F.relu(layer(h))

        return anchors, h, anchor_batch
 def forward(self, x, pos, batch):
   """Fuse two messages computed on a shared radius graph.

   The graph links points within ``self.r`` (at most 64 neighbours) inside
   each batch element. ``self.cdf1`` and ``self.cdf2`` are evaluated on it,
   and their outputs are concatenated along the last dim.

   Returns:
       ``(conf, msg, edge_index)`` where ``conf = self.lin1(msg)``.
   """
   edge_index = radius(pos, pos, self.r, batch, batch, max_num_neighbors=64)
   msg = torch.cat([self.cdf1(x, pos, edge_index),
                    self.cdf2(x, pos, edge_index)], dim=-1)
   return self.lin1(msg), msg, edge_index
Esempio n. 10
0
    def forward(self, data):
        """Predict one scalar per graph from local and global message passing.

        Two edge sets are built. Local edges are ``data.edge_index`` with self
        loops removed. Global edges come from a radius graph with cutoff
        ``self.cutoff`` (at most 500 neighbours), also with self loops removed.
        Edge lengths and two-hop / one-hop angles feed RBF and SBF embeddings.
        ``self.n_layer`` rounds of global and local layers then update the
        node embeddings. The per-node output ``t`` of every local layer is
        summed and pooled per graph with ``global_add_pool``.

        Returns:
            The pooled outputs flattened with ``view(-1)``.
        """
        x = data.x  # atom type indices, used to look up embeddings
        edge_index = data.edge_index
        pos = data.pos
        batch = data.batch
        # Initialize node embeddings
        h = torch.index_select(self.embeddings, 0, x.long())    # input, dim, index (int or long)

        # Get the edges and pairwise distances in the local layer
        edge_index_l, _ = remove_self_loops(edge_index)
        j_l, i_l = edge_index_l
        dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt()

        # Get the edges and pairwise distances in the global layer
        row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)   # radius
        edge_index_g = torch.stack([row, col], dim=0)
        edge_index_g, _ = remove_self_loops(edge_index_g)
        j_g, i_g = edge_index_g
        dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()

        # Compute the node indices for defining the angles
        idx_i_1, idx_j, idx_k, idx_kj, idx_ji,   idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0))

        # Compute the two-hop angles.
        # torch.cross is given dim=-1 explicitly. Without it torch uses the
        # first dimension of size 3, which is the row dimension whenever
        # exactly three triplets exist, and the implicit form is deprecated.
        pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j]       # ij, jk (kj, ji)
        a = (pos_ji_1 * pos_kj).sum(dim=-1)                 # dot product: |u||v| cos
        b = torch.cross(pos_ji_1, pos_kj, dim=-1).norm(dim=-1)  # cross product: |u||v| sin
        angle_1 = torch.atan2(b, a)                         # angle = arctan(sin / cos)

        # Compute the one-hop angles
        pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1]    # ij, jj (jj, ji)
        a = (pos_ji_2 * pos_jj).sum(dim=-1)
        b = torch.cross(pos_ji_2, pos_jj, dim=-1).norm(dim=-1)
        angle_2 = torch.atan2(b, a)

        # Get the RBF and SBF embeddings
        rbf_g = self.rbf_g(dist_g)
        rbf_l = self.rbf_l(dist_l)
        sbf_1 = self.sbf(dist_l, angle_1, idx_kj)   # 2-hop
        sbf_2 = self.sbf(dist_l, angle_2, idx_jj)   # 1-hop

        rbf_g = self.rbf_g_mlp(rbf_g)
        rbf_l = self.rbf_l_mlp(rbf_l)
        sbf_1 = self.sbf_1_mlp(sbf_1)
        sbf_2 = self.sbf_2_mlp(sbf_2)

        # Perform the message passing schemes
        node_sum = 0

        for layer in range(self.n_layer):
            h = self.global_layers[layer](h, rbf_g, edge_index_g)
            h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l)
            node_sum += t

        # Readout: sum the per-node outputs of each graph.
        output = global_add_pool(node_sum, batch)
        return output.view(-1)
 def forward(self, x, pos, batch):
     """Set-abstraction step: FPS down-sampling followed by a ball-query conv.

     Returns:
         ``(x, pos, batch)`` for the sampled centroids.
     """
     sel = fps(pos, batch, ratio=self.ratio)
     sel_pos, sel_batch = pos[sel], batch[sel]
     # TODO: FIGURE OUT THIS WITH RESPECT TO NUMBER OF POINTS
     # (the neighbour cap of 64 is fixed regardless of cloud size).
     ctr, nbr = radius(pos, sel_pos, self.r, batch, sel_batch,
                       max_num_neighbors=64)
     # Edges run neighbour -> centroid.
     x = self.conv(x, (pos, sel_pos), torch.stack([nbr, ctr], dim=0))
     return x, sel_pos, sel_batch
 def find_neighbours(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
     """Ball-query the neighbours of ``y`` in ``x`` at one configured scale.

     Args:
         x, y, batch_x, batch_y: Forwarded unchanged to ``radius``.
         scale_idx: Selects the entry of ``self._radius`` and
             ``self._max_num_neighbors`` to use. Must satisfy
             ``0 <= scale_idx < self.num_scales``.

     Returns:
         Whatever ``radius`` returns for the chosen scale.

     Raises:
         ValueError: if ``scale_idx`` is outside ``[0, self.num_scales)``.
     """
     # Negative indices are rejected too; Python indexing would otherwise
     # silently wrap them around to the last scales.
     if not 0 <= scale_idx < self.num_scales:
         raise ValueError("Scale %i is out of bounds %i" %
                          (scale_idx, self.num_scales))
     return radius(x,
                   y,
                   self._radius[scale_idx],
                   batch_x,
                   batch_y,
                   max_num_neighbors=self._max_num_neighbors[scale_idx])
Esempio n. 13
0
 def forward(self, data):
     """Set abstraction on an ``(x, pos, batch)`` tuple.

     Samples centroids with FPS, groups up to ``self.max_num_neighbors``
     points within ``self.radius`` around each, and applies ``self.conv``.

     Returns:
         A new ``(x, pos, batch)`` tuple for the sampled centroids.
     """
     x, pos, batch = data
     sel = fps(pos, batch, ratio=self.ratio)
     pos_s, batch_s = pos[sel], batch[sel]
     ctr, nbr = radius(pos, pos_s, self.radius, batch, batch_s,
                       max_num_neighbors=self.max_num_neighbors)
     # Edges run neighbour -> centroid.
     x = self.conv(x, (pos, pos_s), torch.stack([nbr, ctr], dim=0))
     return (x, pos_s, batch_s)
 def forward(self, x, pos, batch):
     """Compute Point-wise Rigid Transformation.

     Builds a radius graph over ``pos`` (radius 0.1, at most 64 neighbours,
     within each batch element) and applies ``self.conv`` to it. The former
     6-D (position + first three feature columns) tensor was computed but
     never used, so it has been dropped.

     Returns:
         The output of ``self.conv(x, pos, edge_index)``.
     """
     row, col = radius(pos, pos, 0.1, batch, batch, max_num_neighbors=64)
     # Messages flow neighbour (col) -> query point (row).
     edge_index = torch.stack([col, row], dim=0)
     return self.conv(x, pos, edge_index)
Esempio n. 15
0
    def forward(self, x, pos, batch):
        """Radius-graph convolution that keeps every point (no down-sampling).

        Each point is linked to at most ``self.sample_size`` points within
        ``self.r`` in the same batch element.

        Returns:
            ``(x, pos, batch)`` with ``pos`` and ``batch`` passed through unchanged.
        """
        query, neighbour = radius(pos, pos, self.r, batch, batch,
                                  max_num_neighbors=self.sample_size)
        x = self.conv(x, (pos, pos), torch.stack([neighbour, query], dim=0))
        return x, pos, batch
Esempio n. 16
0
 def radius_graph_3d(self, pos, batch, r, direction='source2target',
                     max_num_neighbors=32):
     """Build a radius graph over ``pos`` within each batch element.

     Args:
         pos: Point positions, used as both query and neighbour set.
         batch: Batch assignment vector for ``pos``.
         r: Ball radius.
         direction: ``'source2target'`` returns rows
             ``[neighbour, query]``; any other value returns
             ``[query, neighbour]``.
         max_num_neighbors: Maximum neighbours kept per query point. It was
             previously hard-coded to 32, which stays the default.

     Returns:
         An edge index built with ``torch.stack`` of the two index vectors.
     """
     target_idx, source_idx = radius(pos,
                                     pos,
                                     r,
                                     batch,
                                     batch,
                                     max_num_neighbors=max_num_neighbors)
     if direction == 'source2target':
         return torch.stack([source_idx, target_idx], dim=0)
     return torch.stack([target_idx, source_idx], dim=0)
 def forward(self, x, pos, norm, batch):
     """Set abstraction that also carries per-point normals.

     Samples centroids with FPS and connects each to at most 32 input points
     within ``self.r``. ``self.conv`` then receives the positions and the
     normals of both the input and the sampled sets.

     Returns:
         ``(x, pos, norm, batch)`` restricted to the sampled centroids.
     """
     keep = fps(pos, batch, ratio=self.ratio)
     # A radius (ball) graph is used here; a nearest-neighbour graph would
     # be the alternative.
     ctr, nbr = radius(pos, pos[keep], self.r, batch, batch[keep],
                       max_num_neighbors=32)
     edges = torch.stack([nbr, ctr], dim=0)
     x = self.conv(x, (pos, pos[keep]), (norm, norm[keep]), edges)
     return x, pos[keep], norm[keep], batch[keep]
    def forward(self, x, pos, batch):
        """Predict a per-point confidence from multi-scale 3-D / 6-D neighbourhoods.

        One message is computed per entry of
        ``zip(self.r3d, self.ratios, self.r6d)``:

        1. ``x`` is averaged over each point's neighbours in the 6-D space
           ``[pos, x]`` (radius ``r6d``, at most 64 neighbours).
        2. ``self.dst`` is applied to these averages on a 3-D radius graph
           (radius ``r3d``, at most 64 neighbours).

        The per-scale messages are stacked along a new last dimension and
        mapped through ``self.lin1`` and a sigmoid.

        Returns:
            Column 0 of the sigmoid output.
        """
        pos_ext = torch.cat([pos, x], dim=1)
        msg_list = []
        # ``self.ratios`` is unused (the FPS step was disabled), but it stays
        # in the zip so the number of scales is unchanged.
        for r3d, _ratio, r6d in zip(self.r3d, self.ratios, self.r6d):
            row6d, col6d = radius(pos_ext,
                                  pos_ext,
                                  r6d,
                                  batch,
                                  batch,
                                  max_num_neighbors=64)
            # Mean of the neighbours' features (col6d), gathered at each
            # query point (row6d).
            x_centers = scatter_mean(x.index_select(index=col6d, dim=0),
                                     row6d.unsqueeze(1),
                                     dim=0)

            row3d, col3d = radius(pos,
                                  pos,
                                  r3d,
                                  batch,
                                  batch,
                                  max_num_neighbors=64)
            edge_index3d = torch.stack([col3d, row3d], dim=0)

            msg_list.append(self.dst(x_centers, edge_index3d).unsqueeze(-1))
        msg = torch.cat(msg_list, dim=-1)

        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        conf = torch.sigmoid(self.lin1(msg))
        return conf[:, 0]
    def forward(self, data):
        """PointNet++-style classifier written against an older ``radius`` API.

        ``radius`` is called positionally with ``48`` as the neighbour cap, and
        its result is used directly as the edge index.

        Returns:
            Per-class log-probabilities (``log_softmax`` over the last dim).
        """
        pos, batch = data.pos, data.batch

        x = None
        for ratio, ball_r, sa_conv in ((0.5, 0.1, self.local_sa1),
                                       (0.25, 0.2, self.local_sa2)):
            keep = fps(pos, batch, ratio=ratio)
            edge_index = radius(pos[keep], pos, ball_r, batch[keep], batch, 48)
            x = F.relu(sa_conv(x, pos, edge_index))
            # NOTE(review): only pos/batch are subsampled here, while x keeps
            # whatever row count sa_conv returns for the full ``pos``. Compare
            # with the variant that does ``x = x[idx]``, and confirm that the
            # row counts line up at the concatenation below.
            pos, batch = pos[keep], batch[keep]

        x = self.global_sa(torch.cat([x, pos], dim=1))
        # Fixed layout: 128 points per cloud, 1024 global features.
        x = x.view(-1, 128, 1024).max(dim=1).values

        for hidden in (self.lin1, self.lin2):
            x = F.dropout(F.relu(hidden(x)), p=0.5, training=self.training)
        return F.log_softmax(self.lin3(x), dim=-1)
    def forward(self, data):
        """Sample, group and apply PointNet to an ``(x, pos, batch)`` tuple.

        Returns:
            ``(x1, pos1, batch1)`` for the FPS-sampled centroids.
        """
        x, pos, batch = data

        # Sample.
        centroids = fps(pos, batch, ratio=self.sample_ratio)
        c_pos, c_batch = pos[centroids], batch[centroids]

        # Group: build the neighbour -> centroid graph.
        ctr, nbr = radius(pos, c_pos, self.radius, batch, c_batch,
                          max_num_neighbors=self.max_num_neighbors)

        # Apply PointNet.
        x1 = self.point_conv(x, (pos, c_pos), torch.stack([nbr, ctr], dim=0))
        return x1, c_pos, c_batch
Esempio n. 21
0
    def forward(self, x, pos, batch):
        """Set abstraction: FPS down-sampling, ball grouping, PointNet conv.

        Returns:
            ``(x, pos, batch)`` for the sampled points.
        """
        keep = gnn.fps(pos, batch, self.ratio)
        kept_pos, kept_batch = pos[keep], batch[keep]

        # gnn.radius(x, y) returns (index into y, index into x). Stacking
        # (dense, sparse) makes every edge point from an input neighbour to
        # its sampled centre: a -> b  <=>  a is in N(b).
        sparse_ids, dense_ids = gnn.radius(pos,
                                           kept_pos,
                                           self.radius,
                                           batch,
                                           kept_batch,
                                           max_num_neighbors=64)
        edge_index = torch.stack((dense_ids, sparse_ids), dim=0)

        return self.point_conv(x, (pos, kept_pos), edge_index), kept_pos, kept_batch
Esempio n. 22
0
    def forward(self, x, pos, batch):
        """Multi-scale grouping set abstraction.

        Samples about ``self.sample_points`` centroids per point cloud with
        FPS. For each scale ``i``, it groups at most
        ``self.group_sample_size[i]`` neighbours within ``self.r_list[i]``
        around every centroid and applies ``self.conv_list[i]``. The per-scale
        features are concatenated along dim 1.

        Returns:
            ``(new_x, centroid_pos, centroid_batch)``.
        """
        # fps() applies its ratio to every cloud. The old ratio,
        # sample_points / len(pos), divided by the total point count of the
        # whole mini-batch, so each cloud got sample_points / B centroids.
        # Divide by the points per cloud instead.
        # NOTE(review): assumes all clouds in the batch have the same size.
        num_clouds = int(batch.max()) + 1
        idx = fps(pos, batch, self.sample_points * num_clouds / pos.size(0))
        sub_pos, sub_batch = pos[idx], batch[idx]

        x_list = []
        for r, max_nbrs, conv in zip(self.r_list, self.group_sample_size,
                                     self.conv_list):
            row, col = radius(pos,
                              sub_pos,
                              r,
                              batch,
                              sub_batch,
                              max_num_neighbors=max_nbrs)
            # Edges run neighbour (col) -> centroid (row).
            edge_index = torch.stack([col, row], dim=0)
            x_list.append(conv(x, (pos, sub_pos), edge_index))

        new_x = torch.cat(x_list, 1)
        return new_x, sub_pos, sub_batch
Esempio n. 23
0
        def forward(self, x, pos, batch):
            """PointNet++ set abstraction: sampling, grouping and PointNet layers.

            Returns:
                ``(x, pos, batch)`` for the sampled centroids.
            """
            # Sampling layer: farthest point sampling.
            chosen = fps(pos, batch, ratio=self.ratio)
            chosen_pos, chosen_batch = pos[chosen], batch[chosen]

            # Grouping layer: up to 64 neighbours within self.r per centroid.
            ctr, nbr = radius(pos, chosen_pos, self.r, batch, chosen_batch,
                              max_num_neighbors=64)

            # PointNet layer over neighbour -> centroid edges.
            out = self.conv(x, (pos, chosen_pos), torch.stack([nbr, ctr], dim=0))
            return out, chosen_pos, chosen_batch
Esempio n. 24
0
    def forward(self, pos, batch):
        """Classify point clouds with one PointNet++ set-abstraction layer.

        1. Farthest point sampling keeps 25% of the points.
        2. Each sampled point is linked to input points within radius 0.3,
           using ``radius``'s default neighbour cap.
        3. Bipartite message passing maps input points onto sampled points.
        4. Features are max-pooled per cloud.
        5. ``self.classifier`` produces the output.
        """
        idx = fps(pos, batch, ratio=0.25)
        sampled_pos, sampled_batch = pos[idx], batch[idx]

        # radius() yields (sampled index, input index); flip to source -> target.
        tgt, src = radius(pos, sampled_pos, 0.3, batch, sampled_batch)
        assign_index = torch.stack([src, tgt], dim=0)

        h = self.conv(pos, sampled_pos, assign_index)
        return self.classifier(global_max_pool(h, sampled_batch))
    def forward(self, x, pos, batch):
        """Pool features onto FPS centroids chosen in the joint feature/position space.

        Centroids are sampled from ``[x, pos]``. Each centroid's new feature
        is the mean of ``x`` over its neighbours within ``self.r`` (at most 64)
        in that joint space.

        Returns:
            ``(x_centers, pos, batch)`` for the sampled centroids.
        """
        pos6d = torch.cat([x, pos], dim=-1)
        idx = fps(pos6d, batch, ratio=self.ratio)

        row6d, col6d = radius(pos6d,
                              pos6d[idx],
                              self.r,
                              batch,
                              batch[idx],
                              max_num_neighbors=64)
        # Average neighbour features (col6d) into their centroid (row6d).
        # dim_size pins one output row per centroid, so x_centers stays
        # aligned with pos[idx] / batch[idx] even if trailing centroids have
        # no neighbours.
        x_centers = scatter_mean(x.index_select(index=col6d, dim=0),
                                 row6d.unsqueeze(1),
                                 dim=0,
                                 dim_size=idx.size(0))
        pos, batch = pos[idx], batch[idx]
        return x_centers, pos, batch
Esempio n. 26
0
    def forward(self, x, pos, batch):
        """Set abstraction: FPS down-sampling followed by a ball-query conv.

        Each centroid is linked to at most 64 input points within ``self.r``.
        The shape-debugging ``print`` calls that ran on every forward pass have
        been removed.

        Returns:
            ``(x, pos, batch)`` for the sampled centroids.
        """
        idx = fps(pos, batch, ratio=self.ratio)
        row, col = radius(pos,
                          pos[idx],
                          self.r,
                          batch,
                          batch[idx],
                          max_num_neighbors=64)
        # Edges run neighbour (col) -> centroid (row).
        edge_index = torch.stack([col, row], dim=0)
        x = self.conv(x, (pos, pos[idx]), edge_index)

        pos, batch = pos[idx], batch[idx]
        return x, pos, batch
Esempio n. 27
0
 def find_neighbours(self, x, y, batch_x=None, batch_y=None):
     """Find the neighbours of ``y`` in ``x`` within ``self._radius``.

     The message-passing format uses ``radius``. The dense and partial-dense
     formats use ``tp.ball_query``. Both are capped at
     ``self._max_num_neighbors``.

     Raises:
         NotImplementedError: for any other ``self._conv_type``.
     """
     if self._conv_type == ConvolutionFormat.MESSAGE_PASSING.value:
         return radius(x,
                       y,
                       self._radius,
                       batch_x,
                       batch_y,
                       max_num_neighbors=self._max_num_neighbors)
     # The previous test, ``== DENSE.value or PARTIAL_DENSE.value``, was truthy
     # whenever PARTIAL_DENSE.value is non-empty. Unknown conv types therefore
     # fell through to ball_query instead of reaching NotImplementedError.
     if self._conv_type in (ConvolutionFormat.DENSE.value,
                            ConvolutionFormat.PARTIAL_DENSE.value):
         return tp.ball_query(self._radius,
                              self._max_num_neighbors,
                              x,
                              y,
                              mode=self._conv_type,
                              batch_x=batch_x,
                              batch_y=batch_y)
     raise NotImplementedError
Esempio n. 28
0
    def forward(self, x, pos, batch):
        """Set abstraction with separate FPS ratios for training and evaluation.

        Returns:
            ``(x, pos, batch, x_idx, y_idx)``. The first three are for the
            sampled centroids. ``x_idx`` and ``y_idx`` are the raw
            ball-query indices (input point, centroid).
        """
        ratio = self.ratio_train if self.training else self.ratio_test
        idx = fps(pos, batch, ratio=ratio)
        centroid_pos, centroid_batch = pos[idx], batch[idx]

        # Ball query: y_idx indexes the centroids, x_idx the input points.
        y_idx, x_idx = radius(pos, centroid_pos, self.r, batch, centroid_batch,
                              max_num_neighbors=128)

        x = self.conv(x, (pos, centroid_pos), torch.stack([x_idx, y_idx], dim=0))
        return x, centroid_pos, centroid_batch, x_idx, y_idx
  def forward(self, x, pos, batch):
    """Pool points by averaging over 6-D (position + feature) neighbourhoods.

    Centroids are picked by FPS with ``random_start=False`` in the
    concatenated ``[pos, x]`` space. Each centroid's new features and new
    position are the means over its neighbours within ``self.r6d`` (at most
    128) in that space.

    Returns:
        ``((x, pos, batch[idx]), edge_index)``. ``edge_index`` has rows
        (neighbour index, centroid index).
    """
    joint = torch.cat([pos, x], dim=-1)
    idx = fps(joint, batch, ratio=self.ratio, random_start=False)
    ctr, nbr = radius(joint, joint[idx], self.r6d, batch, batch[idx],
                      max_num_neighbors=128)
    edge_index = torch.stack([nbr, ctr], dim=0)

    def neighbourhood_mean(values):
      # Gather neighbour rows and average them into their centroid slot.
      return scatter_mean(values.index_select(index=nbr, dim=0),
                          ctr.unsqueeze(1),
                          dim=0)

    return (neighbourhood_mean(x), neighbourhood_mean(pos), batch[idx]), edge_index
Esempio n. 30
0
    def build_bipartite_graph(self,
                              pos,
                              batch,
                              ratio,
                              method='radius',
                              r=0.1,
                              k=32,
                              dilation=1):
        '''
        Build a bipartite graph from input points to FPS-sampled centroids.

        :param pos: (torch.Tensor) The position of the points
        :param batch: (torch.Tensor) The batch index for each point
        :param ratio: (float) The FPS sample ratio
        :param method: (str) Which method to use, 'radius' or 'knn'
        :param r: (float) If 'radius' is used, the radius of the support domain
        :param k: (int) If 'radius' is used, the neighbour cap per centroid;
            if 'knn' is used, the number of neighbours kept
        :param dilation: (int) If 'knn' is used, the dilation rate: k * dilation
            neighbours are searched and k of them are drawn at random
        :return: (tuple) ``(edge_index, sampled_pos, sampled_batch)``.
            ``edge_index`` rows are (neighbour index into ``pos``, centroid
            index into the sampled set).
        :raises ValueError: if ``method`` is not 'radius' or 'knn'
        '''
        # Raise rather than assert: asserts are stripped under ``python -O``.
        if method not in ('radius', 'knn'):
            raise ValueError(
                f"method must be 'radius' or 'knn', got {method!r}")
        idx = fps(pos, batch, ratio=ratio)
        if method == 'radius':
            row, col = radius(pos,
                              pos[idx],
                              r,
                              batch,
                              batch[idx],
                              max_num_neighbors=k)
        else:
            row, col = knn(pos, pos[idx], k * dilation, batch, batch[idx])
            if dilation > 1:
                # Draw k of the k * dilation neighbours of every centroid at random.
                # NOTE(review): this assumes knn returned exactly k * dilation
                # neighbours per centroid in contiguous blocks, which may not
                # hold when a cloud has fewer points. Confirm.
                n = idx.shape[0]
                index = torch.randint(k * dilation, (n, k),
                                      dtype=torch.long,
                                      device=row.device)
                arange = torch.arange(n, dtype=torch.long, device=row.device)
                arange = arange * (k * dilation)
                index = (index + arange.view(-1, 1)).view(-1)
                row, col = row[index], col[index]

        # Flip to neighbour -> centroid (source -> target).
        edge_index = torch.stack([col, row], dim=0)
        return edge_index, pos[idx], batch[idx]