示例#1
0
def get_uniform_loss(pcd, percentages=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """Point-distribution uniformity loss over several local patch sizes.

    For every fraction ``p`` in ``percentages``:
      * ``int(N * 0.05)`` seed points are chosen by farthest point sampling.
      * A ball query of radius ``sqrt(p * radius)`` gathers ``int(N * p)`` points
        around each seed.
      * Inside every patch, the mean distance of each point to its nearest other
        point is compared with the expected spacing
        ``sqrt(pi * radius**2 * p / nsample)``.
      * The normalized squared deviation is averaged and weighted by ``(100 * p)**2``.

    Args:
        pcd: point cloud tensor unpacked as (B, N, C). The grouped patches are
            reshaped to 3 channels, so C is expected to be 3.
        percentages: patch sizes as fractions of N. Every ``int(N * p)`` must be
            at least 1, otherwise the expected-spacing computation divides by zero.
        radius: scale used for both the ball-query radius and the expected spacing.

    Returns:
        Scalar tensor: the weighted per-percentage losses averaged over
        ``len(percentages)``.
    """
    B, N, C = pcd.size()
    npoint = int(N * 0.05)

    # The seed points do not depend on p, so sample them once rather than per iteration.
    pcd_flipped = pcd.transpose(1, 2).contiguous()  # (B, C, N)
    new_xyz = pn2.gather_operation(
        pcd_flipped, pn2.furthest_point_sample(pcd, npoint)
    ).transpose(1, 2).contiguous()  # (B, npoint, C)

    loss = 0
    for p in percentages:
        nsample = int(N * p)
        r = math.sqrt(p * radius)
        disk_area = math.pi * (radius ** 2) * p / nsample
        expect_len = math.sqrt(disk_area)

        idx = pn2.ball_query(r, nsample, pcd, new_xyz)
        grouped_pcd = pn2.grouping_operation(pcd_flipped, idx)
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3)

        # k=2: column 0 is presumably each point itself, column 1 its nearest other point.
        # NOTE(review): assumes knn_point returns negated squared distances (hence the
        # minus sign); the abs() below keeps the result non-negative either way.
        var, _ = knn_point(2, grouped_pcd, grouped_pcd)
        uniform_dis = -var[:, :, 1:]

        uniform_dis = torch.sqrt(torch.abs(uniform_dis + 1e-8))
        uniform_dis = torch.mean(uniform_dis, dim=-1)
        uniform_dis = (uniform_dis - expect_len) ** 2 / (expect_len + 1e-8)

        # Larger patches get a larger weight.
        loss += torch.mean(uniform_dis) * math.pow(p * 100, 2)
    return loss / len(percentages)
示例#2
0
def edge_preserve_sampling(feature_input, point_input, num_samples, k=10):
    """Downsample with farthest point sampling and attach a neighbourhood feature.

    Each sampled point gets two features concatenated along dim 1:
      * its own feature, gathered at the FPS index;
      * the channel-wise max over the features of its ``min(k, N)`` nearest
        neighbours in ``point_input``.

    Args:
        feature_input: per-point features, indexed as (B, C, N).
        point_input: coordinates used for FPS and the k-NN search.
        num_samples: number of points M kept by farthest point sampling.
        k: neighbourhood size, capped at N.

    Returns:
        Tuple ``(net, p_idx, pn_idx, point_output)``:
          * ``net``: center and neighbour-max features concatenated on dim 1.
          * ``p_idx``: the FPS indices.
          * ``pn_idx``: detached int neighbour indices, (B, M, pk).
          * ``point_output``: the sampled coordinates, (B, M, 3).
    """
    batch_size, feature_size, num_points = feature_input.size()[:3]

    sample_idx = pn2.furthest_point_sample(point_input, num_samples)
    points_flipped = point_input.transpose(1, 2).contiguous()
    sampled_points = pn2.gather_operation(points_flipped, sample_idx)
    sampled_points = sampled_points.transpose(1, 2).contiguous()  # B M 3

    neighbours = int(min(k, num_points))
    _, knn_idx = knn_point(neighbours, point_input, sampled_points)
    knn_idx = knn_idx.detach().int()  # B M pk

    # gather_operation takes a flat (B, M*pk) index; reshape back to (B, C, M, pk)
    # and max-pool over the neighbour axis.
    flat_idx = knn_idx.view(batch_size, num_samples * neighbours)
    gathered = pn2.gather_operation(feature_input, flat_idx)
    gathered = gathered.view(batch_size, feature_size, num_samples, neighbours)
    neighbour_feature = gathered.max(dim=3)[0]

    center_idx = sample_idx.unsqueeze(2)
    center_feature = pn2.grouping_operation(feature_input, center_idx)
    center_feature = center_feature.view(batch_size, -1, num_samples)

    net = torch.cat((center_feature, neighbour_feature), 1)
    return net, sample_idx, knn_idx, sampled_points
示例#3
0
    def forward(self, P1: torch.Tensor, P2: torch.Tensor, X1: torch.Tensor,
                S2: torch.Tensor) -> (torch.Tensor):
        r"""
        Aggregate the states of P2-neighbours around every point of P1.

        For each point of P1, up to ``self.nsamples`` points of P2 within
        ``self.radius`` are grouped. Their states (S2), the query point's own
        features (X1, only when ``self.in_channels != 0``) and the coordinate
        offsets are concatenated along the channel axis. The result is passed
        through ``self.fc`` and max-pooled over the neighbour axis.

        Parameters
        ----------
        P1:     (B, N, 3)
        P2:     (B, N, 3)
        X1:     (B, C, N)
        S2:     (B, C, N)

        Returns
        -------
        S1:     (B, C, N)
        """
        # Neighbour indices in P2 for every query point of P1
        neighbours = pointnet2_utils.ball_query(
            self.radius, self.nsamples, P2, P1)  # (B, npoint, nsample)

        # Gather neighbour coordinates and neighbour states
        grouped_xyz = pointnet2_utils.grouping_operation(
            P2.transpose(1, 2).contiguous(), neighbours)  # (B, 3, npoint, nsample)
        grouped_states = pointnet2_utils.grouping_operation(
            S2, neighbours)  # (B, C, npoint, nsample)

        # Offset of every neighbour from its query point
        centers = P1.transpose(1, 2).contiguous().unsqueeze(3)  # (B, 3, npoint, 1)
        offsets = grouped_xyz - centers  # (B, 3, npoint, nsample)

        # Channel-wise concatenation: [S2 states, (X1 broadcast over neighbours), offsets]
        parts = [grouped_states]
        if self.in_channels != 0:
            parts.append(X1.unsqueeze(3).repeat(1, 1, 1, self.nsamples))
        parts.append(offsets)
        correlation = torch.cat(tensors=parts, dim=1)

        # self.fc holds the only learnable parameters; pool over the neighbour axis
        return self.fc(correlation).max(dim=-1, keepdim=False)[0]
示例#4
0
    def get_repulsion_loss(self, pred):
        """Repulsion loss pushing each point away from its nearest neighbours.

        For the ``self.nn_size`` nearest neighbours of every point (minus the
        first, presumably the point itself), it computes:
          * the distance ``dist``, floored at ``sqrt(self.eps)``;
          * the weight ``exp(-dist**2 / self.h**2)``.
        It then averages ``(self.radius - dist) * weight``.

        Note: the PU-Net variant of this loss uses ``self.radius - dist * weight``
        instead.

        Args:
            pred: point cloud; transposed to (B, 3, N) below, so presumably (B, N, 3).

        Returns:
            Scalar tensor with the mean repulsion term.
        """
        _, idx = knn_point(self.nn_size, pred, pred, transpose_mode=True)
        # k-NN of pred against itself: drop the first neighbour column.
        idx = idx[:, :, 1:].to(torch.int32).contiguous()  # B, N, nn

        pred = pred.transpose(1, 2).contiguous()  # B, 3, N
        grouped_points = grouping_operation(
            pred, idx)  # (B, 3, N), (B, N, nn) => (B, 3, N, nn)
        grouped_points = grouped_points - pred.unsqueeze(-1)

        dist2 = torch.sum(grouped_points ** 2, dim=1)
        # Floor the squared distance so sqrt stays finite and differentiable.
        # clamp works on whatever device/dtype dist2 already lives on, instead
        # of allocating a fresh CUDA tensor every call (which also broke CPU use).
        dist2 = torch.clamp(dist2, min=self.eps)
        dist = torch.sqrt(dist2)
        weight = torch.exp(-dist2 / self.h ** 2)

        return torch.mean((self.radius - dist) * weight)
示例#5
0
def get_repulsion_loss(pred, nsample=20, radius=0.07, h=0.03):
    """Repulsion loss over the 4 nearest neighbours of every point.

    The ``nsample`` nearest neighbours are found with ``knn``. The 5 closest are
    kept and the first (presumably the point itself) is dropped. Each remaining
    squared distance is floored at 1e-12, and the loss is the mean of
    ``radius - dist * exp(-dist**2 / h**2)``.

    Args:
        pred: point cloud, (batch_size, npoint, 3).
        nsample: neighbourhood size passed to ``knn``; must be at least 5 for
            the top-5 selection below.
        radius: target spacing term of the loss.
        h: bandwidth of the Gaussian distance weight (previously hard-coded 0.03).

    Returns:
        Scalar tensor with the mean repulsion term.
    """
    pred_flipped = pred.transpose(1, 2).contiguous()  # (B, 3, npoint)
    idx = knn(pred_flipped, nsample).int()
    grouped_pred = pn2.grouping_operation(pred_flipped, idx)  # (B, C, npoint, nsample)
    # Out-of-place: don't mutate the output of a custom autograd op in place.
    grouped_pred = grouped_pred - pred_flipped.unsqueeze(-1)

    dist_square = torch.sum(grouped_pred ** 2, dim=1)  # (B, npoint, nsample)
    neg_closest, _ = torch.topk(-dist_square, 5)
    dist_square = -neg_closest[:, :, 1:]  # remove the first one
    # Device/dtype-agnostic floor; replaces a per-call CUDA FloatTensor allocation.
    dist_square = torch.clamp(dist_square, min=1e-12)
    dist = torch.sqrt(dist_square)
    weight = torch.exp(-dist_square / h ** 2)
    return torch.mean(radius - dist * weight)
示例#6
0
    def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        """Point spatio-temporal convolution over a sequence of point clouds.

        Frames are padded in time by ``self.temporal_padding``, either with zero
        frames (``padding_mode == "zeros"``) or by repeating the first/last frame.
        Then, for every anchor frame ``t`` (stepping by ``self.temporal_stride``):
          * anchor points are chosen by farthest point sampling;
          * a spatial convolution over ball-query neighbourhoods is applied to each
            frame in ``[t - temporal_radius, t + temporal_radius]``;
          * the per-frame results are concatenated on the channel axis,
            optionally batch-normalized, passed through ReLU, and given to
            ``self.temporal``.

        Args:
            xyzs: torch.Tensor
                 (B, L, N, 3) tensor of sequence of the xyz coordinates
            features: torch.Tensor
                 (B, L, C, N) tensor of sequence of the features; only read when
                 ``self.in_planes != 0``

        Returns:
            new_xyzs: anchor coordinates stacked on dim 1,
                (B, L_out, N//spatial_stride, 3).
            new_features: outputs of ``self.temporal`` stacked on dim 1.
        """
        nframes = xyzs.size(1)  # L
        npoints = xyzs.size(2)  # N

        if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
            assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!"

        def pad_frames(frames):
            """Pad a list of per-frame tensors in time according to padding_mode."""
            front, back = self.temporal_padding[0], self.temporal_padding[1]
            if self.padding_mode == "zeros":
                # zeros_like inherits device and dtype from the real frames.
                # (xyzs.get_device() returned -1 for CPU tensors, which torch.zeros rejects.)
                filler = torch.zeros_like(frames[0])
                head, tail = filler, filler
            else:   # "replicate"
                head, tail = frames[0], frames[-1]
            return [head] * front + frames + [tail] * back

        xyzs = [torch.squeeze(input=xyz, dim=1).contiguous()
                for xyz in torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)]
        xyzs = pad_frames(xyzs)

        if self.in_planes != 0:
            features = [torch.squeeze(input=feature, dim=1).contiguous()
                        for feature in torch.split(tensor=features, split_size_or_sections=1, dim=1)]
            features = pad_frames(features)

        new_xyzs = []
        new_features = []
        for t in range(self.temporal_radius, len(xyzs) - self.temporal_radius, self.temporal_stride):  # temporal anchor frames
            # spatial anchor point subsampling by FPS
            anchor_idx = pointnet2_utils.furthest_point_sample(xyzs[t], npoints // self.spatial_stride)      # (B, N//spatial_stride)
            anchor_xyz_flipped = pointnet2_utils.gather_operation(xyzs[t].transpose(1, 2).contiguous(), anchor_idx)  # (B, 3, N//spatial_stride)
            anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3)                                    # (B, 3, N//spatial_stride, 1)
            anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous()                                    # (B, N//spatial_stride, 3)

            # spatial convolution over every frame in the temporal window around t
            spatial_features = []
            for i in range(t - self.temporal_radius, t + self.temporal_radius + 1):
                neighbor_xyz = xyzs[i]

                idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz)

                neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous()                            # (B, 3, N)
                neighbor_xyz_grouped = pointnet2_utils.grouping_operation(neighbor_xyz_flipped, idx)        # (B, 3, N//spatial_stride, k)

                displacement = neighbor_xyz_grouped - anchor_xyz_expanded                                   # (B, 3, N//spatial_stride, k)
                displacement = self.spatial_conv_d(displacement)                                            # (B, mid_planes, N//spatial_stride, k)

                if self.in_planes != 0:
                    neighbor_feature_grouped = pointnet2_utils.grouping_operation(features[i], idx)         # (B, in_planes, N//spatial_stride, k)
                    feature = self.spatial_conv_f(neighbor_feature_grouped)                                 # (B, mid_planes, N//spatial_stride, k)

                    if self.spatial_aggregation == "addition":
                        spatial_feature = feature + displacement
                    else:
                        spatial_feature = feature * displacement
                else:
                    spatial_feature = displacement

                # pool over the k neighbours -> (B, mid_planes, N//spatial_stride)
                if self.spatial_pooling == 'max':
                    spatial_feature, _ = torch.max(input=spatial_feature, dim=-1, keepdim=False)
                elif self.spatial_pooling == 'sum':
                    spatial_feature = torch.sum(input=spatial_feature, dim=-1, keepdim=False)
                else:
                    spatial_feature = torch.mean(input=spatial_feature, dim=-1, keepdim=False)

                spatial_features.append(spatial_feature)

            spatial_features = torch.cat(spatial_features, dim=1)  # (B, temporal_kernel_size*mid_planes, N//spatial_stride)

            # batch norm and relu
            if self.batch_norm:
                spatial_features = self.batch_norm(spatial_features)

            spatial_features = self.relu(spatial_features)

            # temporal convolution
            spatio_temporal_feature = self.temporal(spatial_features)

            new_xyzs.append(anchor_xyz)
            new_features.append(spatio_temporal_feature)

        new_xyzs = torch.stack(tensors=new_xyzs, dim=1)
        new_features = torch.stack(tensors=new_features, dim=1)

        return new_xyzs, new_features