def forward(self, x: Tuple[UFloatTensor,  # (N, x, dims)
                            UFloatTensor]  # (N, x, C_in)
             ) -> Tuple[UFloatTensor,        # (N, P, dims)
                        UFloatTensor]:       # (N, P, C_out)
     """
     Reduce a point cloud to ``self.P`` representative points and compute
     their features with ``self.pointcnn``.

     Representatives are picked with ``farthest_point_sample`` (not randomly).
     If ``self.P`` does not lie strictly between 0 and the number of input
     points, no sampling happens and every input point is a representative.

     :param x: (pts, fts) where
       - pts: point coordinates; pts[:, p_idx, :] is the position of point p_idx.
       - fts: per-point features; fts[:, p_idx, :] is the feature of the
         point at pts[:, p_idx, :].
     :return: (rep_pts, rep_pts_fts): the representative points and the
       features ``self.pointcnn`` produces for them.
     """
     cloud, cloud_feats = x
     should_sample = 0 < self.P < cloud.size()[1]
     if not should_sample:
         # Nothing to drop: keep every input point as a representative.
         representatives = cloud
     else:
         chosen = farthest_point_sample(cloud, self.P)  # [N, P] indices
         representatives = index_points(cloud, chosen)  # [N, P, dims]
     # self.pointcnn is given (representatives, all points, all features);
     # presumably queries plus their support set. Confirm against its definition.
     rep_feats = self.pointcnn((representatives, cloud, cloud_feats))  # [N, P, C_out]
     return representatives, rep_feats
    def forward(self, xyz, points):
        """
        Multi-scale set abstraction.

        Centroids are chosen with ``farthest_point_sample``. For every radius
        in ``self.radius_list``, neighbours are gathered around each centroid
        with ``query_ball_point`` (using the matching ``self.nsample_list``
        entry). Neighbour coordinates are made relative to their centroid and
        concatenated after the neighbour features, if any. The result is run
        through that scale's conv/BN/ReLU stack and max-pooled over
        neighbours. The per-scale outputs are concatenated along channels.

        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Switch to channels-last: (B, N, C) and (B, N, D).
        xyz = xyz.permute(0, 2, 1)
        feats = None if points is None else points.permute(0, 2, 1)

        batch, _, coord_dim = xyz.shape
        n_centroids = self.npoint
        centroids = index_points(xyz, farthest_point_sample(xyz, n_centroids))  # (B, S, C)
        # Broadcastable shape for expressing each group relative to its centroid.
        centroid_offset = centroids.view(batch, n_centroids, 1, coord_dim)

        scale_outputs = []
        for scale, radius in enumerate(self.radius_list):
            neighbour_idx = query_ball_point(radius, self.nsample_list[scale],
                                             xyz, centroids)
            local_xyz = index_points(xyz, neighbour_idx) - centroid_offset
            if feats is None:
                grouped = local_xyz
            else:
                grouped = torch.cat([index_points(feats, neighbour_idx), local_xyz],
                                    dim=-1)

            grouped = grouped.permute(0, 3, 2, 1)  # [B, D, K, S]
            for layer_idx, conv in enumerate(self.conv_blocks[scale]):
                norm = self.bn_blocks[scale][layer_idx]
                grouped = F.relu(norm(conv(grouped)))
            # Max over the K neighbours of each centroid -> [B, D', S].
            scale_outputs.append(grouped.max(dim=2)[0])

        return centroids.permute(0, 2, 1), torch.cat(scale_outputs, dim=1)
    def forward(self, xyz1, xyz2, points1, points2):
        """
        Propagate features from the sampled set ``xyz2`` back onto ``xyz1``.

        Each point of ``xyz1`` gets an inverse-distance-weighted average of
        the ``points2`` features of its (up to) three nearest points in
        ``xyz2``. When S == 1 the single source point's features are copied.
        The result is optionally concatenated after ``points1`` and passed
        through the conv/BN/ReLU layers in ``self.mlp_convs`` /
        ``self.mlp_bns``. Inputs and output are channels-first.

        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D1, N], or None
            points2: input points data, [B, D2, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        Raises:
            ValueError: if ``xyz2`` holds no points (S == 0).
        """
        xyz1 = xyz1.permute(0, 2, 1)  # (B, N, C)
        xyz2 = xyz2.permute(0, 2, 1)  # (B, S, C)
        points2 = points2.permute(0, 2, 1)  # (B, S, D2)

        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape
        if S == 0:
            raise ValueError("xyz2 must contain at least one point to interpolate from")

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)  # (B, N, D2)
        else:
            # Use up to three neighbours. With S == 2 only two exist, and a
            # hard-coded 3 would make the weight reshape fail.
            k = min(3, S)
            dists = square_distance(xyz1, xyz2)  # (B, N, S)
            dists, idx = dists.sort(dim=-1)  # ascending, (B, N, S)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # (B, N, k)

            # Epsilon avoids division by zero for coincident points.
            dist_recip = 1.0 / (dists + 1e-8)  # (B, N, k)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)  # (B, N, 1)
            weight = dist_recip / norm  # (B, N, k)
            interpolated_points = torch.sum(
                index_points(points2, idx) * weight.unsqueeze(-1), dim=2
            )  # (B, N, D2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)  # (B, N, D1)
            new_points = torch.cat([points1, interpolated_points],
                                   dim=-1)  # (B, N, D1+D2)
        else:
            new_points = interpolated_points  # (B, N, D2)

        new_points = new_points.permute(0, 2, 1)  # (B, D1+D2, N)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points  # (B, mlp[-1], N)
    def forward(self, xyz1, xyz2, points1, points2):
        """
        Interpolate point features from a sparse set onto a dense one, then
        refine them with ``self.pnfpdensesnet``.

        Each point of ``xyz1`` gets an inverse-distance-weighted average of
        the ``points2`` features of its (up to) three nearest points in
        ``xyz2``. When S == 1 the single source point's features are copied
        to every target. Inputs and output are channels-last.

        Input:
            xyz1: input points position data, [B, N, C]
            xyz2: sampled input points position data, [B, S, C]
            points1: input points data, [B, N, D], or None to skip the
                concatenation
            points2: input points data, [B, S, D']
        Return:
            new_points_dense: upsampled points data, [B, N, D'']
        Raises:
            ValueError: if ``xyz2`` holds no points (S == 0). This replaces an
                ``assert``, which is stripped under ``python -O``.
        """
        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape
        if S == 0:
            raise ValueError("xyz2 must contain at least one point to interpolate from")

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)  # (B, N, D')
        else:
            # Use up to three neighbours, so S == 2 is supported too.
            k = min(3, S)
            dists = square_distance(xyz1, xyz2)  # (B, N, S)
            dists, idx = dists.sort(dim=-1)  # ascending, (B, N, S)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]
            # Floor tiny distances so a target coinciding with a source
            # point does not get an infinite weight.
            weight = 1.0 / dists.clamp(min=1e-10)  # [B, N, k]
            weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # [B, N, k]
            interpolated_points = torch.sum(
                index_points(points2, idx) * weight.unsqueeze(-1), dim=2
            )  # (B, N, D')

        if points1 is not None:
            # points1 is already channels-last here, so no permute is needed.
            new_points = torch.cat([points1, interpolated_points],
                                   dim=-1)  # [B, N, D+D']
        else:
            new_points = interpolated_points

        # self.pnfpdensesnet is fed a channels-first [B, D+D', N] tensor.
        new_points = new_points.permute(0, 2, 1)  # [B, D+D', N]
        new_points_dense = self.pnfpdensesnet(new_points)  # [B, D'', N]
        new_points_dense = new_points_dense.permute(0, 2, 1)  # [B, N, D'']

        return new_points_dense