Esempio n. 1
0
    def forward(self, center_coords, coords, features):
        """
        Aggregate neighbourhood features around each center point.

        Parameters
        ----------
        center_coords:
            coordinates of the center points
        coords:
            coordinates of the points that neighbours are selected from
        features:
            features of the points in coords; only read when
            self.point_feature_size != 0

        Returns
        -------
        res:
            output of self.conv1 with dim 3 squeezed; when self.with_global
            is set, self.linear1(center_coords) is concatenated in front of
            it along dim 1
        """
        # Query k * dilation neighbours, then keep every dilation-th index
        # along the last axis (dilated neighbourhood).
        knn_indexes, _ = k_nearest_neighbors(center_coords, coords,
                                             self.k * self.dilation)
        knn_indexes = knn_indexes[:, :, ::self.dilation]
        knn_coords = index2points(coords, knn_indexes)
        knn_local_coords = localize(center_coords, knn_coords)  # (B,C,N,k)

        # Lift the local coordinates with the MLP.
        feature_d = self.mlp_d(knn_local_coords)

        if self.point_feature_size == 0:
            # No per-point input features: use the lifted coordinates only.
            feature_a = feature_d
        else:
            knn_features = index2points(features, knn_indexes)
            feature_a = torch.cat([feature_d, knn_features],
                                  dim=1)  # [B, C+add_C, N, k]

        if self.use_x_transformation:
            # Predict a transformation from the features and apply it to them.
            trans = self.x_trans(feature_a)
            fx = self.transform(feature_a, trans)
        else:
            fx = feature_a

        fx = self.conv1(fx)
        # Drop dim 3 (torch.squeeze only removes it when its size is 1).
        fx = torch.squeeze(fx, dim=3).contiguous()

        if self.with_global:
            fts_global = self.linear1(center_coords)
            # Fixed: the original called the non-existent `torch.can`, which
            # raised AttributeError whenever with_global was enabled.
            res = torch.cat([fts_global, fx], dim=1)
        else:
            res = fx

        return res
Esempio n. 2
0
def get_graph_feature(x, k=20, memory_saving=False):
    """
    Build edge features from every point and its k nearest neighbours.

    Parameters
    ----------
        x:torch.tensor (3-D, unpacked as (B, C, N))
            point features; also used as the coordinates for the kNN search
        k:int
            number of neighbours per point
        memory_saving:bool
            accepted for interface compatibility; not read by this function

    Returns
    -------
        torch.tensor
            (neighbour - center) stacked on top of the repeated center
            values along dim 1
    """
    batch_size, num_channels, num_points = x.shape

    neighbor_idx, _ = k_nearest_neighbors(x, x, k)
    neighbors = index2points(x, neighbor_idx)

    # Repeat each center value k times so it lines up with its neighbours.
    centers = x.view(batch_size, num_channels, num_points, 1)
    centers = centers.repeat(1, 1, 1, k)

    return torch.cat((neighbors - centers, centers), dim=1)
Esempio n. 3
0
    def forward(self, x, coords):
        """
        Apply the MLP to each point's kNN neighbourhood and max-pool it.

        Parameters
        ----------
        x:
            per-point features
        coords:
            per-point coordinates, used both as kNN queries and candidates

        Returns
        -------
        y:
            maximum of the MLP output over the last dim
        coords:
            the input coordinates, returned unchanged
        """
        neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)

        # Gather neighbour features and transform them with the MLP.
        neighborhood = self.mlp(index2points(x, neighbor_idx))

        # Local max pooling over the last dim.
        pooled = torch.max(neighborhood, dim=-1).values

        return pooled, coords
Esempio n. 4
0
def label_and_features_to_label_boundary(features,
                                         k,
                                         label,
                                         radius=0.05,
                                         method="knn"):
    """
    Create label boundary information.

    For every point, its neighbours are looked up (kNN or ball query) and
    each (point, neighbour) pair is flagged when the two labels differ.

    Parameters
    ----------
        features:torch.tensor (shape:(batch, num_points, num_dims))
            coordinates
        k:int
            k of kNN
        label:torch.tensor (shape:(batch, num_points))
            labels
        radius:float
            ball query radius (method="ball_query")
        method:str
            method to create is_transition (knn or ball_query)

    Returns
    -------
        source:torch.tensor (shape: (batch*num_points*k))
        target:torch.tensor (shape: (batch*num_points*k))
        is_transition:torch.tensor (shape: (batch*num_points*k))

    Raises
    ------
        NotImplementedError
            if method is neither "knn" nor "ball_query", or if
            method="ball_query" is used with num_dims != 3
    """
    # Fail fast on an unknown method. The original built the exception
    # without raising it and then crashed on an unbound variable.
    if method not in ("knn", "ball_query"):
        raise NotImplementedError("knn or ball_query")

    # (batch, num_points, num_dims) -> (batch, num_dims, num_points)
    features = torch.transpose(features, 1, 2).contiguous()

    if method == "knn":
        idx_w_centerpoint = k_nearest_neighbors(features, features, k + 1)
    else:
        # After the transpose the coordinate dimension is axis 1; the
        # original checked axis 2, which holds num_points.
        if features.shape[1] != 3:
            raise NotImplementedError(
                "The ball_query is only for features num_dims == 3")
        idx_w_centerpoint = ball_query(features, features, radius, k + 1)

    # The other call sites of k_nearest_neighbors unpack an
    # (indices, distances) pair; keep only the indices when a tuple comes
    # back so the slicing below works on the index tensor.
    if isinstance(idx_w_centerpoint, tuple):
        idx_w_centerpoint = idx_w_centerpoint[0]

    # k + 1 neighbours were requested: the first entry is treated as the
    # center point itself, the remaining k as its neighbours.
    source = idx_w_centerpoint[:, :, 0:1].repeat(1, 1,
                                                 k)  # get center points of kNN
    target = idx_w_centerpoint[:, :, 1:]  # remove center points of kNN

    source = index2row(source)
    target = index2row(target)

    label = label.view(-1)
    source_label = label[source]
    target_label = label[target]

    # True where a point and its neighbour carry different labels.
    is_transition = source_label != target_label

    return source, target, is_transition
Esempio n. 5
0
    def forward(self, x, coords):
        """
        Sub-sample center points, then pool MLP features of their neighbours.

        Parameters
        ----------
        x:
            per-point features
        coords:
            per-point coordinates

        Returns
        -------
        y:
            maximum of the MLP output over the last dim
        fps_coords:
            coordinates selected by furthest_point_sampling
        """
        # Pick self.num_samples center indices and look up their coordinates.
        sampled_idx = furthest_point_sampling(coords, self.num_samples)
        fps_coords = index2points(coords, sampled_idx)

        # Neighbours of every sampled center among all points.
        neighbor_idx, _ = k_nearest_neighbors(fps_coords, coords, self.k)

        # Gather neighbour features and transform them with the MLP.
        neighborhood = self.mlp(index2points(x, neighbor_idx))

        # Local max pooling over the last dim.
        pooled = torch.max(neighborhood, dim=-1).values

        return pooled, fps_coords
    def forward(self, xyz1, xyz2, points1, points2):
        """
        Interpolate the features attached to xyz2 onto the points of xyz1.

        When xyz2 holds more than one point, each xyz1 point receives the
        inverse-distance weighted sum of the features of its 3 nearest xyz2
        points; otherwise the single feature column is repeated.

        Parameters
        ----------
        xyz1:
            xyz of center points
        xyz2:
            xyz of all points
        points1:
            features of center points, or None
        points2:
            features of all points

        Returns
        -------
        new_points:
            self.mlp applied to the interpolated features, with points1
            concatenated in front of them along dim 1 when it is not None

        Note
        ----
        The original author notes "xyz1 > xyz2" (presumably xyz1 holds more
        points than xyz2 -- TODO confirm).
        """
        batch_size, _, num_targets = xyz1.shape
        _, _, num_sources = xyz2.shape

        if num_sources != 1:
            neighbor_idx, neighbor_dist = k_nearest_neighbors(xyz1, xyz2, 3)

            # Inverse-distance weights, normalised over the 3 neighbours;
            # the epsilon avoids division by zero for coincident points.
            inv_dist = 1.0 / (neighbor_dist + 1e-8)
            weights = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            weights = weights.view(batch_size, 1, num_targets, 3)

            gathered = index2points(points2, neighbor_idx)
            interpolated = torch.sum(gathered * weights, dim=3)
        else:
            # Only one source point: copy its features to every target.
            interpolated = points2.repeat(1, 1, num_targets)

        if points1 is None:
            fused = interpolated
        else:
            fused = torch.cat([points1, interpolated], dim=1)

        return self.mlp(fused)
    def forward(self, x, coords):
        """
        Pool MLP features over each point's k nearest neighbours.

        Every point of coords is used as its own center; no sub-sampling is
        performed.

        Parameters
        ----------
        x:
            per-point features
        coords:
            per-point coordinates

        Returns
        -------
        y:
            maximum of the MLP output over the last dim
        knn_mlp_x:
            MLP output before pooling
        local_coods:
            localize(coords, knn_coords) multiplied by -1
        """
        neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)

        # Neighbour coordinates relative to their center, sign-flipped.
        neighbor_coords = index2points(coords, neighbor_idx)
        flipped_local_coords = localize(coords, neighbor_coords) * -1

        # Gather neighbour features and transform them with the MLP.
        neighbor_features = index2points(x, neighbor_idx)
        transformed = self.mlp(neighbor_features)

        # Local max pooling over the last dim.
        pooled = torch.max(transformed, dim=-1).values

        return pooled, transformed, flipped_local_coords
Esempio n. 8
0
    def forward(self, features, coords):
        """
        Aggregate neighbour values with learned per-channel weights.

        Parameters
        ----------
        features: torch.tensor (B, C, N)
        coords: torch.tensor (B, 3, N)

        Returns
        -------
        torch.tensor
            weighted neighbour values summed over the last dim, (B, C, N)
        """
        # Neighbour indices, (B, N, k).
        neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)

        # Positional term delta, (B, C, N, k).
        delta = self.pe_delta(coords, neighbor_idx)

        # One linear projection split into phi / psi / alpha, (B, C, N) each.
        projected = self.input_linear(features)
        phi, psi, alpha = torch.chunk(projected, chunks=3, dim=1)

        # Weights: sign-flipped localize(phi, psi) plus delta, passed through
        # the gamma MLP and the rho normalisation, (B, C, N, k).
        psi_neighbors = index2points(psi, neighbor_idx)
        gamma_in = localize(phi, psi_neighbors) * -1 + delta
        rho = self.normalization_rho(self.mlp_gamma(gamma_in))

        # Values: alpha(x_j) + delta, (B, C, N, k).
        alpha_neighbors = index2points(alpha, neighbor_idx)
        values = alpha_neighbors + delta

        # Hadamard product followed by aggregation over the neighbours.
        return torch.sum(rho * values, dim=-1)
Esempio n. 9
0
#     gt_outs = other.index2points(point_clouds, center_idxs)
#     t = timecheck(t, "index2points:")
#     acc = outs == gt_outs
#     print(False in (acc))
#     # print(acc)
#     # print(torch.sum(acc))
#     # print(outs.shape)
#     # print(outs)
#     exit()

# Sanity check / timing script: compares other.gather against
# other.index2points on the first batch yielded by loader, then exits.
k = 1  # number of neighbours requested from k_nearest_neighbors
for data in loader:
    point_clouds, sem_label, ins_label = data
    # Keep the first three entries of the last axis, swap axes 1 and 2, and
    # move the result to device.
    point_clouds = point_clouds[:, :, :3].transpose(1, 2).to(device)
    # Presumably selects 1024 center indices per cloud -- TODO confirm against
    # furthest_point_sampling.
    center_idxs = furthest_point_sampling(point_clouds, 1024)
    center_points = other.index2points(point_clouds, center_idxs)
    print(point_clouds.shape, center_points.shape)
    knn_idxs, _ = k_nearest_neighbors(center_points, point_clouds, k)
    # NOTE(review): timecheck looks like it returns a timestamp and, given a
    # previous one plus a label, reports the elapsed time -- confirm.
    t = timecheck()
    outs = other.gather(point_clouds, knn_idxs)
    t = timecheck(t, "gather:")
    # index2points serves as the reference result for gather.
    gt_outs = other.index2points(point_clouds, knn_idxs)
    t = timecheck(t, "index2points:")
    # Elementwise comparison of the two results.
    acc = outs == gt_outs
    # Prints True when at least one element differs (assuming acc is a
    # tensor, so `in` tests for any element equal to False).
    print(False in (acc))
    # print(acc)
    # print(torch.sum(acc))
    # print(outs.shape)
    # print(outs)
    # Stop after the first batch; the loop body never runs twice.
    exit()