Example #1
import math

import torch

# `pn2` and `knn_point` are assumed to come from a compiled PointNet++ ops
# library and a PU-GAN-style utility module, e.g.:
#   import pointnet2_ops.pointnet2_utils as pn2
#   from knn_utils import knn_point


def get_uniform_loss(pcd, percentages=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """PU-GAN-style uniformity loss: penalize the deviation of local point
    spacing from the spacing expected of a uniformly distributed point set.

    pcd: (B, N, 3) point cloud, assumed roughly normalized to a sphere of `radius`.
    """
    B, N, C = pcd.size()
    npoint = int(N * 0.05)  # number of query seeds per cloud
    loss = 0
    for p in percentages:
        nsample = int(N * p)                                # expected points per query ball
        r = math.sqrt(p * radius)                           # query-ball radius for percentage p
        disk_area = math.pi * (radius ** 2) * p / nsample   # expected area per point inside a ball
        # Seed the query balls with farthest point sampling.
        new_xyz = pn2.gather_operation(
            pcd.transpose(1, 2).contiguous(),
            pn2.furthest_point_sample(pcd, npoint)).transpose(1, 2).contiguous()
        idx = pn2.ball_query(r, nsample, pcd, new_xyz)
        expect_len = math.sqrt(disk_area)                   # expected nearest-neighbor spacing

        grouped_pcd = pn2.grouping_operation(pcd.transpose(1, 2).contiguous(), idx)      # (B, 3, npoint, nsample)
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3)

        # knn_point returns negative squared distances; with k=2, index 0 is the
        # point itself and index 1 its nearest neighbor within the ball.
        var, _ = knn_point(2, grouped_pcd, grouped_pcd)
        uniform_dis = -var[:, :, 1:]

        uniform_dis = torch.sqrt(torch.abs(uniform_dis + 1e-8))
        uniform_dis = torch.mean(uniform_dis, dim=-1)
        # Chi-square-style deviation from the expected spacing.
        uniform_dis = (uniform_dis - expect_len) ** 2 / (expect_len + 1e-8)

        mean = torch.mean(uniform_dis)
        mean = mean * math.pow(p * 100, 2)  # weight larger percentages more heavily
        loss += mean
    return loss / len(percentages)
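
Each query ball of radius r is expected to hold nsample points, so a point occupies disk_area = pi * radius^2 * p / nsample and the expected nearest-neighbor spacing is expect_len = sqrt(disk_area). The `knn_point` helper is not shown above; below is a minimal pure-PyTorch sketch, assuming the PU-GAN convention of returning negative squared distances (so the largest value marks the nearest point):

import torch

def knn_point(k, reference, query):
    """Negative squared distances and indices of the k nearest reference
    points for each query point: (B, N, 3), (B, M, 3) -> (B, M, k)."""
    neg_sq_dist = -torch.cdist(query, reference) ** 2  # (B, M, N)
    val, idx = neg_sq_dist.topk(k=k, dim=-1)           # largest = nearest
    return val, idx

With this helper and the compiled pn2 ops available, the loss can be smoke-tested on random data, e.g. get_uniform_loss(torch.rand(4, 4096, 3).cuda()).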
Example #2
    def forward(self, P1: torch.Tensor, P2: torch.Tensor, X1: torch.Tensor,
                S2: torch.Tensor) -> torch.Tensor:
        r"""Propagate per-point states from P2 onto the points of P1.

        Parameters
        ----------
        P1:     (B, N, 3) query points
        P2:     (B, N, 3) reference points carrying the states
        X1:     (B, C, N) features attached to P1 (unused when in_channels == 0)
        S2:     (B, C, N) states attached to P2

        Returns
        -------
        S1:     (B, C, N) updated states at the P1 positions
        """
        # 1. For each point in P1, query the indices of its neighbors in P2
        idx = pointnet2_utils.ball_query(self.radius, self.nsamples, P2,
                                         P1)  # (B, npoint, nsample)

        # 2.1 Group P2 points
        P2_flipped = P2.transpose(1, 2).contiguous()  # (B, 3, npoint)
        P2_grouped = pointnet2_utils.grouping_operation(
            P2_flipped, idx)  # (B, 3, npoint, nsample)
        # 2.2 Group P2 states
        S2_grouped = pointnet2_utils.grouping_operation(
            S2, idx)  # (B, C, npoint, nsample)

        # 3. Calculate displacements
        P1_flipped = P1.transpose(1, 2).contiguous()  # (B, 3, npoint)
        P1_expanded = torch.unsqueeze(P1_flipped, 3)  # (B, 3, npoint, 1)
        displacement = P2_grouped - P1_expanded  # (B, 3, npoint, nsample)
        # 4. Concatenate X1, S2 and displacement
        if self.in_channels != 0:
            X1_expanded = torch.unsqueeze(X1, 3)  # (B, C, npoint, 1)
            X1_repeated = X1_expanded.repeat(1, 1, 1, self.nsamples)
            correlation = torch.cat(tensors=(S2_grouped, X1_repeated,
                                             displacement),
                                    dim=1)
        else:
            correlation = torch.cat(tensors=(S2_grouped, displacement), dim=1)

        # 5. Fully-connected layer (the only parameters)
        S1 = self.fc(correlation)

        # 6. Pooling
        S1 = torch.max(input=S1, dim=-1, keepdim=False)[0]

        return S1
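
For intuition, here is a self-contained pure-PyTorch sketch of the same grouping/correlation/pooling pipeline. A kNN query stands in for the CUDA ball query and a plain nn.Linear for self.fc, so it approximates the module above rather than reproducing it; the function name and defaults are made up for illustration.

import torch

def correlate_states(P1, P2, X1, S2, k=16, fc=None):
    """Group S2 around each P1 point, concatenate with repeated X1 features
    and displacements, apply a pointwise layer, and max-pool over neighbors."""
    B, N, _ = P1.shape
    idx = torch.cdist(P1, P2).topk(k, dim=-1, largest=False).indices   # (B, N, k)
    batch = torch.arange(B, device=P1.device).view(B, 1, 1)
    P2_grouped = P2[batch, idx]                                        # (B, N, k, 3)
    S2_grouped = S2.transpose(1, 2)[batch, idx]                        # (B, N, k, C)
    displacement = P2_grouped - P1.unsqueeze(2)                        # (B, N, k, 3)
    X1_repeated = X1.transpose(1, 2).unsqueeze(2).expand(-1, -1, k, -1)
    correlation = torch.cat((S2_grouped, X1_repeated, displacement), dim=-1)
    if fc is not None:
        correlation = fc(correlation)                                  # (B, N, k, C')
    return correlation.max(dim=2).values.transpose(1, 2)               # (B, C', N)

B, N, C = 2, 1024, 64
fc = torch.nn.Linear(C + C + 3, C)
S1 = correlate_states(torch.rand(B, N, 3), torch.rand(B, N, 3),
                      torch.rand(B, C, N), torch.rand(B, C, N), k=16, fc=fc)
print(S1.shape)  # torch.Size([2, 64, 1024])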
Example #3
    def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        """
        Args:
            xyzs: torch.Tensor
                 (B, L, N, 3) tensor of the sequence of xyz coordinates
            features: torch.Tensor
                 (B, L, C, N) tensor of the sequence of features
        Returns:
            new_xyzs: (B, L', N//spatial_stride, 3) anchor coordinates of the output frames
            new_features: (B, L', C', N//spatial_stride) spatio-temporal features
        """
        device = xyzs.device  # `.device` also works on CPU, unlike `.get_device()`

        nframes = xyzs.size(1)  # L
        npoints = xyzs.size(2)  # N

        if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
            assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!"

        xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)
        xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs]

        if self.in_planes != 0:
            features = torch.split(tensor=features, split_size_or_sections=1, dim=1)
            features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in features]

        if self.padding_mode == "zeros":
            xyz_padding = torch.zeros(xyzs[0].size(), dtype=torch.float32, device=device)
            for _ in range(self.temporal_padding[0]):
                xyzs = [xyz_padding] + xyzs
            for _ in range(self.temporal_padding[1]):
                xyzs = xyzs + [xyz_padding]

            if self.in_planes != 0:
                feature_padding = torch.zeros(features[0].size(), dtype=torch.float32, device=device)
                for _ in range(self.temporal_padding[0]):
                    features = [feature_padding] + features
                for _ in range(self.temporal_padding[1]):
                    features = features + [feature_padding]
        else:   # "replicate"
            for _ in range(self.temporal_padding[0]):
                xyzs = [xyzs[0]] + xyzs
            for _ in range(self.temporal_padding[1]):
                xyzs = xyzs + [xyzs[-1]]

            if self.in_planes != 0:
                for _ in range(self.temporal_padding[0]):
                    features = [features[0]] + features
                for _ in range(self.temporal_padding[1]):
                    features = features + [features[-1]]

        new_xyzs = []
        new_features = []
        for t in range(self.temporal_radius, len(xyzs)-self.temporal_radius, self.temporal_stride):                                 # temporal anchor frames
            # spatial anchor point subsampling by FPS
            anchor_idx = pointnet2_utils.furthest_point_sample(xyzs[t], npoints//self.spatial_stride)                               # (B, N//self.spatial_stride)
            anchor_xyz_flipped = pointnet2_utils.gather_operation(xyzs[t].transpose(1, 2).contiguous(), anchor_idx)                 # (B, 3, N//self.spatial_stride)
            anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3)                                                            # (B, 3, N//spatial_stride, 1)
            anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous()                                                            # (B, N//spatial_stride, 3)

            # spatial convolution
            spatial_features = []
            for i in range(t-self.temporal_radius, t+self.temporal_radius+1):
                neighbor_xyz = xyzs[i]

                idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz)

                neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous()                                                    # (B, 3, N)
                neighbor_xyz_grouped = pointnet2_utils.grouping_operation(neighbor_xyz_flipped, idx)                                # (B, 3, N//spatial_stride, k)

                displacement = neighbor_xyz_grouped - anchor_xyz_expanded                                                           # (B, 3, N//spatial_stride, k)
                displacement = self.spatial_conv_d(displacement)                                                                    # (B, mid_planes, N//spatial_stride, k)

                if self.in_planes != 0:
                    neighbor_feature_grouped = pointnet2_utils.grouping_operation(features[i], idx)                                 # (B, in_planes, N//spatial_stride, k)
                    feature = self.spatial_conv_f(neighbor_feature_grouped)                                                         # (B, mid_planes, N//spatial_stride, k)

                    if self.spatial_aggregation == "addition":
                        spatial_feature = feature + displacement
                    else:
                        spatial_feature = feature * displacement

                else:
                    spatial_feature = displacement

                if self.spatial_pooling == 'max':
                    spatial_feature, _ = torch.max(input=spatial_feature, dim=-1, keepdim=False)                                    # (B, mid_planes, N//spatial_stride)
                elif self.spatial_pooling == 'sum':
                    spatial_feature = torch.sum(input=spatial_feature, dim=-1, keepdim=False)                                       # (B, mid_planes, N//spatial_stride)
                else:
                    spatial_feature = torch.mean(input=spatial_feature, dim=-1, keepdim=False)                                      # (B, mid_planes, N//spatial_stride)

                spatial_features.append(spatial_feature)

            spatial_features = torch.cat(tensors=spatial_features, dim=1)                                                           # (B, temporal_kernel_size*mid_planes, N//spatial_stride)

            # batch norm and relu
            if self.batch_norm:
                spatial_features = self.batch_norm(spatial_features)

            spatial_features = self.relu(spatial_features)

            # temporal convolution
            spatio_temporal_feature = self.temporal(spatial_features)

            new_xyzs.append(anchor_xyz)
            new_features.append(spatio_temporal_feature)

        new_xyzs = torch.stack(tensors=new_xyzs, dim=1)                                                                             # (B, L', N//spatial_stride, 3)
        new_features = torch.stack(tensors=new_features, dim=1)                                                                     # (B, L', C', N//spatial_stride)

        return new_xyzs, new_features
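
The number of temporal anchor frames produced by the loop over t follows standard 1-D convolution arithmetic, and the assertion at the top of forward enforces exact divisibility when both kernel and stride exceed 1. A small sanity-check helper (the name is made up; it assumes temporal_radius = (temporal_kernel_size - 1) // 2, consistent with the loop bounds above):

def pstconv_out_frames(nframes, temporal_kernel_size, temporal_stride, temporal_padding):
    # Anchors run from temporal_radius to len(xyzs) - temporal_radius in steps
    # of temporal_stride over the padded sequence.
    padded = nframes + sum(temporal_padding)
    return (padded - temporal_kernel_size) // temporal_stride + 1

# e.g. 9 input frames, kernel 3, stride 2, padding (1, 1) -> 5 anchor frames
assert pstconv_out_frames(9, 3, 2, (1, 1)) == 5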