def forward(self, pc_down, pc_up, feat_down, feat_up):
        '''
        pc_down : B x 3 x N_small
        pc_up : B x 3 x N_large

        feat_down : B x C1 x N_small
        feat_up : B x C2 x N_large

        return : B x mlp[-1] x N_large

        '''
        idx, dist = md_utils._knn_indices(feat=pc_down,
                                          k=3,
                                          centroid=pc_up,
                                          dist=True)  # B x N_large x k

        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        grouped_feat = md_utils._indices_group(feat_down,
                                               idx)  # B x C1 x N_large x k
        weight = weight.unsqueeze(1)  # B x 1 x N_large x k

        interpolated_feats = grouped_feat * weight
        interpolated_feats = torch.max(interpolated_feats,
                                       dim=-1)[0]  # B x C1 x N_large

        interpolated_feats = torch.cat([interpolated_feats, feat_up],
                                       dim=1)  # B x C1+C2 x N_large
        interpolated_feats = interpolated_feats.unsqueeze(-1)

        interpolated_feats = self.mlp(interpolated_feats)

        return interpolated_feats.squeeze(-1)  # B x out_channel x N_large
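
# A minimal standalone sketch of the inverse-distance weighting used above,
# assuming dist is B x N_large x k and grouped_feat is B x C1 x N_large x k
# (random stand-ins here, no dependency on md_utils).
import torch

def _idw_weighting_sketch(B=2, C1=4, N_large=8, k=3):
    dist = torch.rand(B, N_large, k)               # kNN distances from _knn_indices
    grouped_feat = torch.randn(B, C1, N_large, k)  # features of the k nearest coarse points

    dist_recip = 1.0 / (dist + 1e-8)                   # closer neighbours get larger weights
    norm = torch.sum(dist_recip, dim=2, keepdim=True)  # B x N_large x 1
    weight = (dist_recip / norm).unsqueeze(1)          # B x 1 x N_large x k, rows sum to 1

    # the module above max-pools the weighted neighbours; a weighted sum would
    # give the classical inverse-distance interpolation
    return torch.max(grouped_feat * weight, dim=-1)[0]  # B x C1 x N_large
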
    def forward(self, pc):
        '''
        pc (with normals) : B x N x 6
        or pc (without normals) : B x N x 3

        return : B x 960
        '''
        assert pc.size()[2] == self.inchannel, \
            'illegal input pc size: {}'.format(pc.size())
        B, N, _ = pc.size()
        pc = pc.permute(0, 2, 1)
        pc_xyz = pc[:, 0:3, :]
        idx = md_utils._knn_indices(pc_xyz, k=self.K_knn)
        grouped_xyz = md_utils._indices_group(pc_xyz, idx)  #B x 3 x N x k

        grouped_pc = pc_xyz.unsqueeze(-1).expand(B, 3, N, self.K_knn)
        grouped_pc = grouped_xyz - grouped_pc

        feat_1 = self.spiderconv1(pc, idx, grouped_pc)  # B x 64 x N
        feat_2 = self.spiderconv2(feat_1, idx, grouped_pc)
        feat_3 = self.spiderconv3(feat_2, idx, grouped_pc)
        feat_4 = self.spiderconv4(feat_3, idx, grouped_pc)

        cat_feat = torch.cat([feat_1, feat_2, feat_3, feat_4],
                             dim=1)  # B x 480 x N
        cat_feat = torch.topk(cat_feat, 2, dim=2)[0]  # B x 480 x 2
        cat_feat = cat_feat.view(B, -1)  # B x 960

        return cat_feat
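
# A rough sketch of what md_utils._knn_indices and md_utils._indices_group are
# assumed to do for the grouping above: find the K nearest neighbours of every
# point, gather their coordinates, and subtract the centre to get local offsets.
# The helper semantics are an assumption; only plain torch ops are used.
import torch

def _knn_group_sketch(pc_xyz, K):
    # pc_xyz : B x 3 x N
    B, _, N = pc_xyz.size()
    pts = pc_xyz.permute(0, 2, 1)                                       # B x N x 3
    idx = torch.cdist(pts, pts).topk(K, dim=-1, largest=False).indices  # B x N x K
    idx_flat = idx.reshape(B, 1, N * K).expand(B, 3, N * K)
    grouped_xyz = torch.gather(pc_xyz, 2, idx_flat).view(B, 3, N, K)    # B x 3 x N x K
    return idx, grouped_xyz - pc_xyz.unsqueeze(-1)                      # offsets to each centre
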
    def forward(self, pts, fts):
        '''
        pts : B x 3 x N
        fts : B x C_in x N, or None in the first layer

        return : qrs (B x 3 x P) and the convolved features
        '''
        assert pts.size()[1] == 3, \
            'illegal pointcloud size: {}'.format(pts.size())
        B, _, N = pts.size()

        # P == -1 covers two cases: 1) the input layer, 2) the point count is
        # unchanged between the two layers (no subsampling)
        if self.P == -1:
            qrs = pts
            self.P = N
        else:
            sample_indices = self.sample(pts.permute(0, 2, 1).contiguous())
            sample_indices = sample_indices.unsqueeze(1).expand(B, 3, self.P)
            qrs = torch.gather(pts, 2, sample_indices) # B x 3 x P
        indices_dilated = md_utils._knn_indices(pts, k=self.K*self.D, centroid=qrs) #B x P x K*D
        indices = indices_dilated[:, :, ::self.D] # B x P x K
        nn_pts = md_utils._indices_group(pts, indices) # B x 3 x P x K
        nn_pts_center = qrs.unsqueeze(-1).expand_as(nn_pts)
        nn_pts_local = nn_pts - nn_pts_center # B x 3 x P x K

        nn_fts_from_pts = self.mlp_delta(nn_pts_local) # B x C_delta x P x K
        if fts is None: # in the first layer
            nn_fts_input = nn_fts_from_pts
        else:
            nn_fts_input = md_utils._indices_group(fts, indices)
            nn_fts_input = torch.cat([nn_fts_from_pts, nn_fts_input], dim=1) # B x C_delta+C_in x P x K

        X = self.X_transform0(nn_pts_local) # B x K*K x P x 1
        X = X.view(B, self.K, self.P, self.K)
        X = self.X_transform1(X)
        X = X.view(B, self.K, self.P, self.K)
        X = self.X_transform2(X)
        X = X.view(B*self.P, self.K, self.K)
        fts_X = torch.bmm(nn_fts_input.permute(0,2,1,3).contiguous().view(B*self.P, -1, self.K), X)
        fts_X = fts_X.view(B, self.P, -1, self.K).permute(0, 2, 1, 3) # B x C_delta+C_in x P x K

        fts_conv = self.conv(fts_X).squeeze(-1) # B x C_out x P 

        if self.with_global:
            fts_global = self.conv_global(qrs.unsqueeze(-1)).squeeze(-1) # B x C_out//4 x P
            return qrs, torch.cat([fts_global, fts_conv], dim=1)
        else:
            return qrs, fts_conv  # B x C_out x P
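
# A shape-level sketch of the X-transform step above: every one of the B*P
# neighbourhoods gets its own K x K matrix X, and the K neighbour feature
# vectors are mixed by a batched matrix multiply before the final convolution.
# Random tensors stand in for nn_fts_input and the X_transform outputs.
import torch

def _x_transform_sketch(B=2, P=5, K=4, C=16):
    nn_fts_input = torch.randn(B, C, P, K)  # per-neighbour features
    X = torch.randn(B * P, K, K)            # one learned K x K transform per query point

    fts = nn_fts_input.permute(0, 2, 1, 3).contiguous().view(B * P, C, K)
    fts_X = torch.bmm(fts, X)                          # (B*P) x C x K
    return fts_X.view(B, P, C, K).permute(0, 2, 1, 3)  # back to B x C x P x K
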
    def forward(self, pc, feat):
        '''
        input
        ---------------
        pc : B x 3 x N
        feat : B x C x N

        output
        ----------------
        pc_sample : B x 3 x npoint
        feat_sample : B x outchannel x npoint 
        '''
        B, _, N = pc.size()
        idx = self.fps(pc.permute(0, 2, 1).contiguous())  # B x npoint
        idx = idx.unsqueeze(1).expand(B, 3, self.npoint)
        pc_sample = torch.gather(pc, 2, idx)  # B x 3 x npoint
        cat_feat = []

        for i in range(len(self.mlp_layers)):
            indices, _ = self.query_ball_point[i](pc.contiguous(),
                                                  pc_sample.contiguous())
            grouped_pc = md_utils._indices_group(
                pc, indices)  # B x 3 x npoint x nsample
            grouped_pc = grouped_pc - pc_sample.unsqueeze(-1).expand_as(
                grouped_pc)
            out_feat = grouped_pc.contiguous()
            if feat is not None:  # feat will be None in the first layer
                grouped_feat = md_utils._indices_group(
                    feat, indices)  # B x C x npoint x nsample
                out_feat = torch.cat([grouped_pc, grouped_feat],
                                     dim=1)  # B x C+3 x npoint x nsample
            out_feat = self.mlp_layers[i](out_feat)
            out_feat = torch.max(out_feat, -1)[0]  # B x C_out x npoint
            cat_feat.append(out_feat)

        cat_feat = torch.cat(cat_feat, dim=1)  # B x sum(mlp[-1]) x npoint

        return pc_sample, cat_feat
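
# A condensed sketch of one grouping branch above: gather a neighbourhood
# around each sampled centre, shift it into local coordinates, run a shared
# point-wise MLP (a single Conv2d stands in for self.mlp_layers[i]), and
# max-pool over the neighbourhood. The ball query is faked with random indices.
import torch
import torch.nn as nn

def _grouping_branch_sketch(B=2, N=64, npoint=16, nsample=8, C_out=32):
    pc = torch.randn(B, 3, N)
    pc_sample = pc[:, :, :npoint]                        # stand-in for the FPS output
    indices = torch.randint(0, N, (B, npoint, nsample))  # stand-in for a ball query

    idx_flat = indices.reshape(B, 1, npoint * nsample).expand(B, 3, npoint * nsample)
    grouped_pc = torch.gather(pc, 2, idx_flat).view(B, 3, npoint, nsample)
    grouped_pc = grouped_pc - pc_sample.unsqueeze(-1)    # local coordinates

    mlp = nn.Conv2d(3, C_out, kernel_size=1)             # shared point-wise MLP
    return torch.max(mlp(grouped_pc), dim=-1)[0]         # B x C_out x npoint
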
    def forward(self, feat, idx, group_pc):
        '''
        feat : B x in_channel x N
        idx(knn_indices) : B x N x k
        group_pc : B x 3 x N x k

        return:
        feat : B x out_channel x N
        '''
        B, in_channel, N = feat.size()
        _, _, k = idx.size()

        assert k == self.K_knn, 'illegal k'

        group_feat = md_utils._indices_group(feat,
                                             idx)  # B x inchannel x N x k

        X = group_pc[:, 0, :, :].unsqueeze(1)
        Y = group_pc[:, 1, :, :].unsqueeze(1)
        Z = group_pc[:, 2, :, :].unsqueeze(1)

        XX, YY, ZZ = X**2, Y**2, Z**2
        XXX, YYY, ZZZ = XX * X, YY * Y, ZZ * Z
        XY, XZ, YZ = X * Y, X * Z, Y * Z
        XXY, XXZ, YYZ, YYX, ZZX, ZZY, XYZ = (X * XY, X * XZ, Y * YZ, Y * XY,
                                             Z * XZ, Z * YZ, XY * Z)

        group_XYZ = torch.cat([
            X, Y, Z, XX, YY, ZZ, XXX, YYY, ZZZ,
            XY, XZ, YZ, XXY, XXZ, YYZ, YYX, ZZX, ZZY, XYZ
        ], dim=1)  # B x 19 x N x k

        taylor = self.conv1(group_XYZ)  # B x taylor_channel x N x k

        group_feat = group_feat.unsqueeze(2)  #B x inchannel x 1 x N x k
        taylor = taylor.unsqueeze(1)  # B x 1 x taylor_channel x N x k

        group_feat = torch.mul(group_feat, taylor).view(
            B, self.in_channel * self.taylor_channel, N, k)

        group_feat = self.conv2(group_feat)  # B x out_channel x N x 1

        group_feat = group_feat.squeeze(-1)

        return group_feat
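
# A shape-level sketch of the Taylor-term mixing above: the C_in input channels
# and the taylor_channel responses are combined by a broadcast multiply (an
# outer product per point and neighbour) and flattened to C_in*taylor_channel
# channels for conv2. Random tensors stand in for group_feat and conv1's output.
import torch

def _taylor_mix_sketch(B=2, C_in=3, T=9, N=10, k=3):
    group_feat = torch.randn(B, C_in, N, k)  # grouped input features
    taylor = torch.randn(B, T, N, k)         # per-neighbour Taylor responses

    mixed = group_feat.unsqueeze(2) * taylor.unsqueeze(1)  # B x C_in x T x N x k
    return mixed.reshape(B, C_in * T, N, k)                # channel dim for conv2
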


if __name__ == '__main__':
    in_channel = 3
    out_channel = 6
    taylor_channel = 9
    k = 3
    batch_size = 3
    num_points = 10
    model = _BaseSpiderConv(in_channel, out_channel, taylor_channel,
                            batch_size, num_points, k)

    pc = torch.randn(batch_size, 3, num_points)
    feat = torch.randn(batch_size, in_channel, num_points)
    idx = md_utils._knn_indices(pc, k)
    group_pc = md_utils._indices_group(pc, idx)
    pc = pc.unsqueeze(-1).expand(batch_size, 3, num_points, k)
    group_pc = group_pc - pc

    output = model(feat, idx, group_pc)
    print(output.size())