예제 #1
0
    def forward(self, x):
        """Encode ``x`` through four levels, then decode back with skip links.

        The first three channels of ``x`` are read as point coordinates.
        Each encoder level applies a conv, a dense conv and a concatenation
        with the level's input; levels are separated by
        ``edge_preserve_sampling``. A max-pooled global feature is tiled
        ``self.hierarchy[2]`` times and fused at the deepest level. The
        decoder then alternates a conv with ``three_nn_upsampling`` /
        ``pn2.three_interpolate`` and concatenates the encoder output of the
        matching level.

        :param x: input tensor, indexed as ``x[:, 0:3, :]`` for coordinates
            (presumably (B, C, N) with C >= 3 -- confirm against the caller)
        :return: output of ``self.conv8`` (no activation is applied to it)
        """
        # Coordinates of the input points with dims 1 and 2 swapped.
        xyz_l0 = x[:, 0:3, :].transpose(1, 2).contiguous()

        # ---- encoder, level 0 ----
        stem = F.relu(self.conv1(x))
        enc_l0 = torch.cat((F.relu(self.dense_conv1(stem)), stem), 1)
        down_l1, _, _, xyz_l1 = edge_preserve_sampling(
            enc_l0, xyz_l0, self.hierarchy[0], self.k)

        # ---- encoder, level 1 ----
        enc_l1 = F.relu(self.dense_conv2(F.relu(self.conv2(down_l1))))
        enc_l1 = torch.cat((enc_l1, down_l1), 1)
        down_l2, _, _, xyz_l2 = edge_preserve_sampling(
            enc_l1, xyz_l1, self.hierarchy[1], self.k)

        # ---- encoder, level 2 ----
        enc_l2 = F.relu(self.dense_conv3(F.relu(self.conv3(down_l2))))
        enc_l2 = torch.cat((enc_l2, down_l2), 1)
        down_l3, _, _, xyz_l3 = edge_preserve_sampling(
            enc_l2, xyz_l2, self.hierarchy[2], self.k)

        # ---- encoder, level 3 (deepest) ----
        enc_l3 = F.relu(self.dense_conv4(F.relu(self.conv4(down_l3))))
        enc_l3 = torch.cat((enc_l3, down_l3), 1)

        # ---- global feature: max over the last dim, two FC layers, tile ----
        pooled = torch.max(self.gf_conv(enc_l3), -1)[0]
        pooled = F.relu(self.fc2(F.relu(self.fc1(pooled))))
        tiled = pooled.unsqueeze(2).repeat(1, 1, self.hierarchy[2])

        # ---- decoder: level 3 -> level 2 ----
        dec = F.relu(self.conv5(torch.cat((tiled, enc_l3), 1)))
        nn_idx, nn_weight = three_nn_upsampling(xyz_l2, xyz_l3)
        dec = pn2.three_interpolate(dec, nn_idx, nn_weight)

        # ---- decoder: level 2 -> level 1 ----
        dec = F.relu(self.conv6(torch.cat((enc_l2, dec), 1)))
        nn_idx, nn_weight = three_nn_upsampling(xyz_l1, xyz_l2)
        dec = pn2.three_interpolate(dec, nn_idx, nn_weight)

        # ---- decoder: level 1 -> level 0 ----
        dec = F.relu(self.conv7(torch.cat((enc_l1, dec), 1)))
        nn_idx, nn_weight = three_nn_upsampling(xyz_l0, xyz_l1)
        dec = pn2.three_interpolate(dec, nn_idx, nn_weight)

        # Final head on the fused level-0 features.
        return self.conv8(torch.cat((enc_l0, dec), 1))
예제 #2
0
 def interpolate_func(inputs):
     """Call ``three_interpolate`` on ``inputs`` with fixed indices and weights.

     The index tensor (int32) and weight tensor (float32) are constants of
     shape (1, 2, 3) that are moved to the current CUDA device.
     """
     neighbor_idx = torch.tensor(
         [[[0, 1, 2], [1, 2, 3]]], dtype=torch.int32).cuda()
     neighbor_weight = torch.tensor(
         [[[1, 1, 1], [2, 2, 2]]], dtype=torch.float32).cuda()
     return pointnet2_utils.three_interpolate(
         inputs, neighbor_idx, neighbor_weight)
예제 #3
0
    def forward(self, unknown: torch.Tensor, known: torch.Tensor,
                unknow_feats: torch.Tensor,
                known_feats: torch.Tensor) -> torch.Tensor:
        """
        Propagate ``known_feats`` onto the ``unknown`` points and run the MLP.

        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
        :param known: (B, m, 3) tensor of the xyz positions of the known features,
            or None to broadcast ``known_feats`` along its last dimension instead
            of interpolating
        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to,
            or None to skip the concatenation
        :param known_feats: (B, C2, m) tensor of features to be propagated
        :return:
            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is None:
            # No source coordinates: stretch the features to n points.
            batch, channels = known_feats.size()[0:2]
            propagated = known_feats.expand(batch, channels, unknown.size(1))
        else:
            # Inverse-distance weights over the three neighbours returned by
            # three_nn, normalised to sum to one per unknown point.
            nn_dist, nn_idx = pointnet2_utils.three_nn(unknown, known)
            inv_dist = 1.0 / (nn_dist + 1e-8)
            nn_weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            propagated = pointnet2_utils.three_interpolate(
                known_feats, nn_idx, nn_weight)

        if unknow_feats is None:
            fused = propagated
        else:
            fused = torch.cat([propagated, unknow_feats],
                              dim=1)  # (B, C2 + C1, n)

        # self.mlp is fed a trailing singleton dim, which is removed again.
        return self.mlp(fused.unsqueeze(-1)).squeeze(-1)
예제 #4
0
    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor,
            unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Propagate ``known_feats`` onto the ``unknown`` points and run the MLP.

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor or None
            (B, m, 3) tensor of the xyz positions of the known features.
            When ``None``, ``known_feats`` is broadcast along its last
            dimension to ``n`` points instead of being interpolated (it must
            then be expandable, e.g. (B, C2, 1)).
        unknow_feats : torch.Tensor or None
            (B, C1, n) tensor of the features to be propagated to; ``None``
            skips the concatenation
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            # Inverse-distance weights over the three neighbours returned by
            # three_nn, normalised to sum to one per unknown point.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # No source coordinates to search: stretch the features to n
            # points rather than failing inside three_nn.
            interpolated_feats = known_feats.expand(
                *known_feats.size()[0:2], unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # self.mlp is fed a trailing singleton dim, which is removed again.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
    def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor,
                query_features: torch.Tensor,
                support_features: torch.Tensor) -> torch.Tensor:
        """
        Propagate ``support_features`` onto the query points and run the MLP.

        :param query_xyz: (B, n, 3) tensor of the xyz positions of the unknown features
        :param support_xyz: (B, m, 3) tensor of the xyz positions of the known features,
            or None to broadcast ``support_features`` along its last dimension
            instead of interpolating
        :param query_features: (B, C1, n) tensor of the features to be propagated to,
            or None to skip the concatenation
        :param support_features: (B, C2, m) tensor of features to be propagated

        :return:
            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if support_xyz is None:
            # No support coordinates: stretch the features to n points.
            batch, channels = support_features.size()[0:2]
            propagated = support_features.expand(
                batch, channels, query_xyz.size(1))
        else:
            # Three-neighbour interpolation with inverse-distance weights
            # that are normalised to sum to one per query point.
            nn_dist, nn_idx = pointnet2_utils.three_nn(query_xyz, support_xyz)
            inv_dist = 1.0 / (nn_dist + 1e-8)
            nn_weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            propagated = pointnet2_utils.three_interpolate(
                support_features, nn_idx, nn_weight)  # (B, C2, n)

        if query_features is None:
            fused = propagated
        else:
            fused = torch.cat([propagated, query_features],
                              dim=1)  # (B, C2 + C1, n)

        # (B, C2 + C1, n, 1) -> mlp -> (B, mlp[-1], n, 1) -> drop last dim
        return self.mlp(fused.unsqueeze(-1)).squeeze(-1)
    def forward(self, x_a: ME.SparseTensor, x_b: ME.SparseTensor):
        """
        Fuse the features of ``x_a`` into ``x_b`` and return a sparse tensor.

        Three modes, selected by flags on ``self``:

        * ``POINT_TR_LIKE``: both inputs are linearly projected
          (``linear_a`` / ``linear_b``) and scattered into dense, zero-padded
          per-batch buffers. The ``x_a`` features are interpolated onto the
          ``x_b`` coordinates from the three neighbours returned by
          ``three_nn`` with normalised inverse-distance weights, added to the
          projected ``x_b`` features, and gathered back into a sparse tensor
          that shares ``x_b``'s coordinate map.
        * ``SUM_FEATURE``: ``conv_a(x_a) + conv_b(x_b)``.
        * otherwise: ``x_a`` goes through conv/bn/relu, is concatenated with
          ``x_b`` via ``me.cat`` and passed through out_conv/out_bn/out_relu.

        TODO: For POINT_TR_LIKE, add support for no x_b is fed, simply upsample the x_a

        :param x_a: sparse tensor whose features are propagated
        :param x_b: sparse tensor providing the output coordinates; in
            POINT_TR_LIKE mode its feature width must equal ``self.out_dim``
        :return: the fused sparse tensor
        """

        if self.POINT_TR_LIKE:

            dim = x_b.F.shape[1]
            assert dim == self.out_dim

            # Dense buffer for x_a. It is allocated directly on the device of
            # the input features instead of on the host followed by .cuda().
            x_ac, _, idx_a = separate_batch(x_a.C)
            B_a = x_ac.shape[0]
            N_a = x_ac.shape[1]
            x_af = torch.zeros(B_a * N_a, dim, device=x_a.F.device)
            idx_a = idx_a.reshape(-1, 1).repeat(1, dim)
            x_af.scatter_(dim=0, index=idx_a, src=self.linear_a(x_a.F))
            x_af = x_af.reshape([B_a, N_a, dim])

            # Same densification for x_b.
            x_bc, _, idx_b = separate_batch(x_b.C)
            B_b = x_bc.shape[0]
            N_b = x_bc.shape[1]
            x_bf = torch.zeros(B_b * N_b, dim, device=x_b.F.device)
            idx_b = idx_b.reshape(-1, 1).repeat(1, dim)
            x_bf.scatter_(dim=0, index=idx_b, src=self.linear_b(x_b.F))
            x_bf = x_bf.reshape([B_b, N_b, dim])

            dists, idx = three_nn(x_bc.float(), x_ac.float())

            # Rows whose three distances sum to zero get zero weight
            # (presumably the zero-padded slots -- confirm against
            # separate_batch).
            mask = (dists.sum(dim=-1) > 0).unsqueeze(-1).repeat(1, 1, 3)

            dist_recip = 1.0 / (dists + 1e-1)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            weight = weight * mask  # mask the zeros part

            interpolated_points = three_interpolate(
                x_af.transpose(1, 2).contiguous(), idx,
                weight).transpose(1, 2)  # [B, N_b, dim]
            out = interpolated_points + x_bf

            # gather returns a tensor shaped like idx_b, i.e. one row per
            # index that was used to scatter the x_b features above.
            out = torch.gather(
                out.reshape(B_b * N_b, dim), dim=0,
                index=idx_b)
            x = ME.SparseTensor(features=out,
                                coordinate_map_key=x_b.coordinate_map_key,
                                coordinate_manager=x_b.coordinate_manager)

        else:
            if self.SUM_FEATURE:
                x_a = self.conv_a(x_a)
                x_b = self.conv_b(x_b)
                x = x_a + x_b
            else:
                x_a = self.conv(x_a)
                x_a = self.bn(x_a)
                x_a = self.relu(x_a)
                x = me.cat(x_a, x_b)
                x = self.out_conv(x)
                x = self.out_bn(x)
                x = self.out_relu(x)

        return x
예제 #7
0
 def _edge_unpooling(self, features, src_pts, tgt_pts):
     """Interpolate ``features`` using weights from (tgt_pts, src_pts).

     Dimension 2 of ``features`` is squeezed before the call to
     ``pn2.three_interpolate`` and re-inserted on the result.
     """
     squeezed = features.squeeze(2)
     nn_idx, nn_weight = three_nn_upsampling(tgt_pts, src_pts)
     upsampled = pn2.three_interpolate(squeezed, nn_idx, nn_weight)
     return upsampled.unsqueeze(2)
예제 #8
0
    def forward(self, xyzs: torch.Tensor, original_xyzs: torch.Tensor, features: torch.Tensor, original_features: torch.Tensor = None) -> torch.Tensor:
        r"""
        Propagate per-frame features back onto the original frames and points.

        Each input frame is passed through ``self.temporal_conv`` and split
        into chunks of ``self.mid_planes`` channels. For every original frame
        the chunks of the input frames whose temporal window covers it are
        collected, then interpolated onto the original points from the three
        neighbours returned by ``three_nn`` with normalised inverse-distance
        weights.

        Parameters
        ----------
        xyzs : torch.Tensor
            (B, L', N', 3) tensor of the xyz positions of the convolved features
        original_xyzs : torch.Tensor
            (B, L,  N,  3) tensor of the xyz positions of the original points

        features : torch.Tensor
            (B, L', C', N') tensor of the features to be propagated to
        original_features : torch.Tensor
            (B, L,  C,  N) tensor of original point features for skip connection;
            ``None`` disables the skip concatenation

        Returns
        -------
        new_xyzs : torch.Tensor
            ``original_xyzs`` as passed in, returned unchanged
        new_features : torch.Tensor
            (B, L,  C", N) tensor of the features of the unknown features

        Note: a 2-tuple is returned even though the signature is annotated
        ``-> torch.Tensor``.
        """

        # Frame and point counts of the target (original) sequence.
        L1 = original_xyzs.size(1)
        N1 = original_xyzs.size(2)

        # Frame and point counts of the source (convolved) sequence.
        L2 = xyzs.size(1)
        N2 = xyzs.size(2)

        # Check that the temporal parameters map the L2 source frames onto
        # exactly L1 target frames.
        if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
            assert ((L2 - 1) * self.temporal_stride + sum(self.temporal_padding) + self.temporal_kernel_size == L1), "PSTConvTranspose: Temporal parameter error!"

        # Break every input into a list with one contiguous tensor per frame
        # (split along dim 1, then drop that dim).
        xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)
        xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs]

        features = torch.split(tensor=features, split_size_or_sections=1, dim=1)
        features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in features]

        # Keep a reference to the unsplit tensor; it is returned as-is.
        new_xyzs = original_xyzs

        original_xyzs = torch.split(tensor=original_xyzs, split_size_or_sections=1, dim=1)
        original_xyzs = [torch.squeeze(input=original_xyz, dim=1).contiguous() for original_xyz in original_xyzs]

        if original_features is not None:
            original_features = torch.split(tensor=original_features, split_size_or_sections=1, dim=1)
            original_features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in original_features]

        # temporal transposed convolution
        # Each frame's conv output is split along dim 1 into chunks of
        # self.mid_planes channels; a chunk is later selected by temporal offset.
        temporal_trans_features = []
        for feature in features:
            feature = self.temporal_conv(feature)
            feature = torch.split(tensor=feature, split_size_or_sections=self.mid_planes, dim=1)
            temporal_trans_features.append(feature)

        # temporal interpolation
        temporal_interpolated_xyzs = []
        temporal_interpolated_features = []

        # For each source frame (1-based t2): `middle` is its centre position
        # on the target time axis and `delta` the range of target positions
        # (1-based) it contributes to.
        middles = []
        deltas = []
        for t2 in range(1, L2+1):
            middle = t2 + (t2-1)*(self.temporal_stride-1) + self.temporal_radius + self.temporal_padding[0]
            middles.append(middle)
            delta = range(middle - self.temporal_radius, middle + self.temporal_radius + self.temporal_padding[1] + 1)
            deltas.append(delta)

        # For each target frame (1-based t1), gather the points and the
        # matching feature chunk of every source frame that covers it.
        for t1 in range(1, L1+1):
            seed_xyzs = []
            seed_features = []
            for t2 in range(L2):
                delta = deltas[t2]
                if t1 in delta:
                    seed_xyzs.append(xyzs[t2])
                    # Chunk index = offset of t1 from the start of the window.
                    # NOTE(review): with temporal_padding[1] > 0 this index can
                    # exceed 2 * temporal_radius -- confirm temporal_conv emits
                    # enough chunks.
                    seed_feature = temporal_trans_features[t2][t1-middles[t2]+self.temporal_radius]
                    if self.batch_norm:
                        seed_feature = self.batch_norm(seed_feature)
                    if self.activation:
                        seed_feature = self.activation(seed_feature)
                    seed_features.append(seed_feature)
            # Points are concatenated along dim 1 and features along dim 2.
            # NOTE(review): torch.cat raises on an empty list, so every t1 must
            # be covered by at least one source frame.
            seed_xyzs = torch.cat(seed_xyzs, dim=1)
            seed_features = torch.cat(seed_features, dim=2)
            temporal_interpolated_xyzs.append(seed_xyzs)
            temporal_interpolated_features.append(seed_features)

        # spatial interpolation
        new_features = []
        for t1 in range(L1):
            neighbor_xyz = temporal_interpolated_xyzs[t1]                                                               # [B, N', 3]
            anchor_xyz = original_xyzs[t1]                                                                              # [B, N,  3]

            dist, idx = pointnet2_utils.three_nn(anchor_xyz, neighbor_xyz)

            # Inverse-distance weights, normalised to sum to one along dim 2.
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(temporal_interpolated_features[t1], idx, weight)

            # Optional skip connection with the original features (channel dim).
            if original_features is not None:
                new_feature = torch.cat([interpolated_feats, original_features[t1]], dim=1)
            else:
                new_feature = interpolated_feats

            new_feature = self.spatial_conv(new_feature)

            new_features.append(new_feature)

        # Re-assemble the per-frame results along a new frame dim (dim 1).
        new_features = torch.stack(tensors=new_features, dim=1)

        return new_xyzs, new_features