Exemplo n.º 1
0
    def forward(self, center_coords, coords, features):
        """
        Apply a dilated k-nearest-neighbor point convolution around each center.

        Parameters
        ----------
        center_coords : torch.tensor
            Coordinates of the query (center) points.
        coords : torch.tensor
            Coordinates of all candidate neighbor points.
        features : torch.tensor or None
            Per-point features of ``coords``. Only read when
            ``self.point_feature_size != 0``.

        Returns
        -------
        res : torch.tensor
            Output of ``self.conv1`` with axis 3 squeezed. When
            ``self.with_global`` is set, it is preceded along dim 1 by
            ``self.linear1(center_coords)``.
        """
        # Query k * dilation neighbors, then keep every `dilation`-th one so the
        # neighborhood is spread out without increasing the neighbor count.
        knn_indexes, _ = k_nearest_neighbors(center_coords, coords,
                                             self.k * self.dilation)
        knn_indexes = knn_indexes[:, :, ::self.dilation]
        knn_coords = index2points(coords, knn_indexes)
        knn_local_coords = localize(center_coords, knn_coords)  # (B,C,N,k)

        # Lift neighbor coordinates (relative to their center) into features.
        feature_d = self.mlp_d(knn_local_coords)

        if self.point_feature_size == 0:
            feature_a = feature_d
        else:
            knn_features = index2points(features, knn_indexes)
            feature_a = torch.cat([feature_d, knn_features],
                                  dim=1)  # [B, C+add_C, N, k]

        if self.use_x_transformation:
            trans = self.x_trans(feature_a)
            fx = self.transform(feature_a, trans)
        else:
            fx = feature_a

        fx = self.conv1(fx)
        # Assumes conv1 reduces axis 3 to size 1 so squeeze removes it — TODO confirm.
        fx = torch.squeeze(fx, dim=3).contiguous()

        if self.with_global:
            fts_global = self.linear1(center_coords)
            # Concatenate global and local features along the channel axis.
            # (torch.cat; the former `torch.can` does not exist.)
            res = torch.cat([fts_global, fx], dim=1)
        else:
            res = fx

        return res
Exemplo n.º 2
0
    def forward(self, x, coords):
        """
        Downsample with furthest point sampling, then max-pool neighbor features.

        Parameters
        ----------
        x : torch.tensor
            Per-point features, gathered by neighbor index.
        coords : torch.tensor
            Point coordinates used for sampling and for the neighbor search.

        Returns
        -------
        (y, fps_coords)
            ``y`` is the max over the last axis of ``self.mlp`` applied to the
            features of the ``self.k`` nearest neighbors of each sampled point.
            ``fps_coords`` are the ``self.num_samples`` sampled coordinates.
        """
        sample_idx = furthest_point_sampling(coords, self.num_samples)
        sampled_coords = index2points(coords, sample_idx)

        # Neighbors of every sampled point, searched among all points.
        neighbor_idx, _distances = k_nearest_neighbors(sampled_coords, coords,
                                                        self.k)
        neighbor_feats = self.mlp(index2points(x, neighbor_idx))

        # Local max pooling over the neighbor axis.
        pooled, _argmax = torch.max(neighbor_feats, dim=-1)
        return pooled, sampled_coords
Exemplo n.º 3
0
 def subsampling(self, coords, num_samples):
     """
     Randomly pick ``num_samples`` points, using the same indices for every batch item.

     Parameters
     ----------
     coords : torch.tensor (B, C, N)
         Point coordinates.
     num_samples : int
         Number of points to keep.

     Returns
     -------
     torch.tensor
         ``coords`` gathered at the sampled indices.
     """
     batch_size, _channels, num_points = coords.shape
     # One index set drawn from N points; shared across the batch below.
     picked = torch.tensor(random_sampling(num_points, num_samples),
                           device=coords.device)
     picked = picked.view(1, num_samples).contiguous()
     batched_idx = picked.repeat(batch_size, 1).contiguous()
     return index2points(coords, batched_idx)
Exemplo n.º 4
0
def get_graph_feature(x, k=20, memory_saving=False):
    """
    Build per-edge features from each point and its k nearest neighbors.

    Parameters
    ----------
    x : torch.tensor (B, C, N)
        Point features; neighbors are searched in this same feature space.
    k : int
        Number of neighbors per point.
    memory_saving : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    torch.tensor (B, 2*C, N, k)
        Concatenation along dim 1 of ``neighbor - center`` and ``center``.
    """
    batch, channels, num_points = x.shape
    neighbor_idx, _distances = k_nearest_neighbors(x, x, k)
    neighbors = index2points(x, neighbor_idx)
    # Repeat each center point across its k neighbor slots.
    centers = x.view(batch, channels, num_points, 1).repeat(1, 1, 1, k)
    return torch.cat((neighbors - centers, centers), dim=1)
Exemplo n.º 5
0
    def forward(self, x, coords):
        """
        Max-pool MLP features over each point's k nearest neighbors.

        Parameters
        ----------
        x : torch.tensor
            Per-point features.
        coords : torch.tensor
            Point coordinates used for the neighbor search.

        Returns
        -------
        (y, knn_mlp_x, local_coords)
            ``y`` is the max over the last axis of ``knn_mlp_x``.
            ``knn_mlp_x`` is ``self.mlp`` applied to the neighbor features.
            ``local_coords`` is the negated output of ``localize`` on each
            point's neighbor coordinates.
        """
        neighbor_idx, _distances = k_nearest_neighbors(coords, coords, self.k)

        # Neighbor positions relative to their query point (negated localize).
        neighbor_coords = index2points(coords, neighbor_idx)
        relative_coords = localize(coords, neighbor_coords) * -1

        # Gather neighbor features and transform them.
        neighbor_feats = self.mlp(index2points(x, neighbor_idx))

        # Local max pooling over the neighbor axis.
        pooled, _argmax = torch.max(neighbor_feats, dim=-1)

        return pooled, neighbor_feats, relative_coords
Exemplo n.º 6
0
    def forward(self, x, coords):
        """
        Max-pool MLP features over each point's k nearest neighbors (no downsampling).

        Parameters
        ----------
        x : torch.tensor
            Per-point features.
        coords : torch.tensor
            Point coordinates used for the neighbor search.

        Returns
        -------
        (y, coords)
            ``y`` is the pooled feature tensor; ``coords`` is returned unchanged.
        """
        neighbor_idx, _distances = k_nearest_neighbors(coords, coords, self.k)
        neighbor_feats = self.mlp(index2points(x, neighbor_idx))
        pooled, _argmax = torch.max(neighbor_feats, dim=-1)
        return pooled, coords
Exemplo n.º 7
0
def group_layer(coords, center_coords, num_samples, radius, points=None):
    """
    Group layer in PointNet++.

    Gathers up to ``num_samples`` points within ``radius`` of each center via
    ball query. The grouped coordinates are expressed relative to their center.

    Parameters
    ----------
    coords : torch.tensor [B, 3, N]
        xyz tensor
    center_coords : torch.tensor [B, 3, N']
        xyz tensor of ball query centers
    num_samples : int
        maximum number of samples for ball query
    radius : float
        radius of ball query
    points : torch.tensor [B, C, N], optional
        Per-point features to gather and concatenate to the result.

    Return
    ------
    new_points : torch.tensor [B, 3, N', num_samples] or [B, 3+C, N', num_samples]
        If points is not None, new_points shape is [B, 3+C, N', num_samples].
    """
    idx = ball_query(center_coords, coords, radius, num_samples).type(torch.long)

    # Grouped coordinates minus their (broadcast) ball-query center.
    relative_coords = index2points(coords, idx) - center_coords.unsqueeze(3)

    if points is None:
        return relative_coords

    # NOTE: PointNetSetAbstractionMsg uses a different concatenation order.
    # https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/2d08fa40635cc5eafd14d19d18e3dc646171910d/models/pointnet_util.py#L253
    return torch.cat([relative_coords, index2points(points, idx)], dim=1)
Exemplo n.º 8
0
    def forward(self, features, coords):
        """
        Aggregate neighbor features with learned per-channel weights.

        Parameters
        ----------
        features: torch.tensor (B, C, N)
        coords: torch.tensor (B, 3, N)

        Returns
        -------
        torch.tensor
            Sum over the neighbor axis of ``rho * (alpha_j + delta)``.
        """
        neighbor_idx, _distances = k_nearest_neighbors(coords, coords,
                                                        self.k)  # (B, N, k)

        # Positional encoding for every (point, neighbor) pair.
        delta = self.pe_delta(coords, neighbor_idx)  # (B, C, N, k)

        # Split the projected features into three equal chunks along dim 1.
        phi, psi, alpha = torch.chunk(self.input_linear(features),
                                      chunks=3, dim=1)

        # Weights: gamma applied to (negated localize of phi vs. neighbor psi)
        # plus delta, followed by the rho normalization.
        psi_neighbors = index2points(psi, neighbor_idx)  # (B, C, N, k)
        gamma_input = localize(phi, psi_neighbors) * -1 + delta
        rho = self.normalization_rho(self.mlp_gamma(gamma_input))

        # Neighbor values alpha(x_j) shifted by the positional encoding.
        values = index2points(alpha, neighbor_idx) + delta

        # Hadamard product with the weights, then sum over the neighbor axis.
        return torch.sum(rho * values, dim=-1)  # (B, C, N)
Exemplo n.º 9
0
    def forward(self, coords, knn_indices):
        """
        Encode neighbor offsets with ``self.mlp_theta``.

        Parameters
        ----------
        coords: torch.tensor (B, C, N)
            Point coordinates.
        knn_indices: torch.tensor
            Neighbor indices, as returned by ``k_nearest_neighbors``.

        Returns
        -------
        torch.tensor
            ``self.mlp_theta`` applied to the negated ``localize`` offsets.
        """
        neighbor_coords = index2points(coords, knn_indices)
        offsets = localize(coords, neighbor_coords) * -1
        return self.mlp_theta(offsets)
Exemplo n.º 10
0
    def forward(self, f_sem, f_ins):
        """
        Fuse semantic and instance features into predictions and embeddings.

        Parameters
        ----------
        f_sem : torch.tensor
            Semantic feature map.
        f_ins : torch.tensor
            Instance feature map.

        Returns
        -------
        (p_sem, e_ins)
            Semantic predictions from ``self.sem_pred_fc`` and instance
            embeddings from ``self.ins_emb_fc``.
        """
        # Instance branch: add adapted semantic features to instance features.
        e_ins = self.ins_emb_fc(f_ins + self.adaptation(f_sem))

        # Semantic branch: max-pool semantic features over the k neighbors
        # found in the instance-embedding space.
        neighbor_idx, _distances = py_k_nearest_neighbors(
            e_ins, e_ins, self.k, memory_saving=True)
        neighbor_sem = index2points(f_sem, neighbor_idx)
        pooled = torch.squeeze(torch.max(neighbor_sem, dim=3, keepdim=True)[0],
                               dim=3)

        return self.sem_pred_fc(pooled), e_ins
Exemplo n.º 11
0
    def forward(self, xyz1, xyz2, points1, points2):
        """
        Propagate features from a sparse point set onto a denser one.

        Each point of ``xyz1`` receives an inverse-distance-weighted average of
        the ``points2`` features of its (up to) 3 nearest neighbors in ``xyz2``.
        The result is optionally concatenated with ``points1`` along dim 1 and
        passed through ``self.mlp``.

        Parameters
        ----------
        xyz1: torch.tensor (B, C, N)
            xyz of all points (the points features are interpolated onto)
        xyz2: torch.tensor (B, C, S)
            xyz of center points (the points whose features are known)
        points1: torch.tensor or None
            features of all points; concatenated before the MLP when given
        points2: torch.tensor
            features of center points

        Returns
        -------
        new_points : torch.tensor
            Output of ``self.mlp``.

        Note
        ----
        Typically N > S (xyz1 is denser than xyz2).
        """
        B, C, N = xyz1.shape
        _, _, S = xyz2.shape

        if S == 1:
            # A single source point: every target point receives its features.
            interpolated_points = points2.repeat(1, 1, N)
        else:
            # Use at most 3 neighbors so S == 2 is handled; presumably the kNN
            # search cannot return more neighbors than there are source points.
            num_neighbors = min(3, S)
            idxs, dists = k_nearest_neighbors(xyz1, xyz2, num_neighbors)

            # Inverse-distance weights, normalized to sum to 1 per target point.
            # The epsilon avoids division by zero for coincident points.
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            neighbor_points = index2points(points2, idxs)
            weight = weight.view(B, 1, N, num_neighbors)
            interpolated_points = torch.sum(neighbor_points * weight, dim=3)

        if points1 is not None:
            new_points = torch.cat([points1, interpolated_points], dim=1)
        else:
            new_points = interpolated_points

        new_points = self.mlp(new_points)

        return new_points
Exemplo n.º 12
0
def sampling_layer(coords, num_samples):
    """
    Sampling layer in PointNet++.

    Selects ``num_samples`` points from ``coords`` with furthest point sampling.

    Parameters
    ----------
    coords : torch.tensor [B, 3, N]
        xyz tensor
    num_samples : int
        number of samples for furthest point sample

    Return
    ------
    sampled_coords : torch.tensor [B, 3, num_samples]
        sampled xyz using furthest point sample
    """
    # Indices are cast to long before being used for gathering.
    indices = furthest_point_sampling(coords, num_samples).type(torch.long)
    return index2points(coords, indices)
Exemplo n.º 13
0
    def forward(self, xyz1, xyz2, points1, points2):
        """
        Propagate features from the center points onto all points.

        Each point of ``xyz1`` receives an inverse-distance-weighted average of
        the ``points2`` features of its (up to) 3 nearest neighbors in ``xyz2``.
        The result is optionally concatenated with ``points1`` along dim 1 and
        passed through ``self.mlp``.

        Parameters
        ----------
        xyz1: torch.tensor (B, C, N)
            xyz of all points
        xyz2: torch.tensor (B, C, S)
            xyz of center points
        points1: torch.tensor or None
            features of all points; concatenated before the MLP when given
        points2: torch.tensor
            features of center points

        Returns
        -------
        new_points : torch.tensor
            Output of ``self.mlp``.
        """
        B, C, N = xyz1.shape
        _, _, S = xyz2.shape

        if S == 1:
            # A single source point: every target point receives its features.
            interpolated_points = points2.repeat(1, 1, N)
        else:
            # Use at most 3 neighbors so S == 2 is handled; presumably the kNN
            # search cannot return more neighbors than there are source points.
            num_neighbors = min(3, S)
            idxs, dists = py_k_nearest_neighbors(xyz1, xyz2, num_neighbors,
                                                 memory_saving=True)

            # Inverse-distance weights, normalized to sum to 1 per target point.
            # The epsilon avoids division by zero for coincident points.
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            neighbor_points = index2points(points2, idxs)
            weight = weight.view(B, 1, N, num_neighbors)
            interpolated_points = torch.sum(neighbor_points * weight, dim=3)

        if points1 is not None:
            new_points = torch.cat([points1, interpolated_points], dim=1)
        else:
            new_points = interpolated_points

        new_points = self.mlp(new_points)

        return new_points
Exemplo n.º 14
0
#     gt_outs = other.index2points(point_clouds, center_idxs)
#     t = timecheck(t, "index2points:")
#     acc = outs == gt_outs
#     print(False in (acc))
#     # print(acc)
#     # print(torch.sum(acc))
#     # print(outs.shape)
#     # print(outs)
#     exit()

# Sanity/benchmark check: on the first batch from `loader`, gather kNN points
# with `other.gather` and with `other.index2points`, time both via
# `timecheck`, print whether they disagree, then terminate the process.
k = 1  # number of nearest neighbors gathered per center point
for data in loader:
    point_clouds, sem_label, ins_label = data
    # Keep the first 3 channels of the last axis and move them to dim 1
    # (presumably xyz, giving (B, 3, N) — TODO confirm against the dataset).
    point_clouds = point_clouds[:, :, :3].transpose(1, 2).to(device)
    # 1024 center points chosen by furthest point sampling.
    center_idxs = furthest_point_sampling(point_clouds, 1024)
    center_points = other.index2points(point_clouds, center_idxs)
    print(point_clouds.shape, center_points.shape)
    knn_idxs, _ = k_nearest_neighbors(center_points, point_clouds, k)
    # NOTE(review): if `device` is CUDA, kernels run asynchronously; without a
    # torch.cuda.synchronize() around each call these timings may be misleading.
    t = timecheck()
    outs = other.gather(point_clouds, knn_idxs)
    t = timecheck(t, "gather:")
    gt_outs = other.index2points(point_clouds, knn_idxs)
    t = timecheck(t, "index2points:")
    acc = outs == gt_outs
    # Prints True if any element differs between the two implementations
    # (assuming `acc` is an elementwise-comparison tensor).
    print(False in (acc))
    # print(acc)
    # print(torch.sum(acc))
    # print(outs.shape)
    # print(outs)
    # Stops the whole script after the first batch.
    exit()