示例#1
0
def edge_preserve_sampling(feature_input, point_input, num_samples, k=10):
    """Downsample with FPS and build edge-aware features for every sampled point.

    For each of the ``num_samples`` points chosen by furthest point sampling,
    the point's own feature is concatenated (along dim 1) with the channel-wise
    max over the features of its ``pk = min(k, N)`` nearest neighbours.

    Args:
        feature_input: feature tensor indexed as (batch, channels, points).
        point_input: point coordinates; fed to FPS directly and transposed
            before gathering, so presumably (B, N, 3) — TODO confirm.
        num_samples: number of points M to keep.
        k: requested neighbourhood size (clipped to the number of points).

    Returns:
        net: center features and neighbour-max features concatenated on dim 1,
            with ``num_samples`` entries on the last dim.
        p_idx: FPS indices of the sampled points.
        pn_idx: detached int32 neighbour indices, (B, M, pk).
        point_output: sampled coordinates, (B, M, 3).
    """
    batch, channels, n_points = feature_input.size()[:3]
    pk = int(min(k, n_points))

    p_idx = pn2.furthest_point_sample(point_input, num_samples)
    coords_t = point_input.transpose(1, 2).contiguous()
    point_output = pn2.gather_operation(coords_t, p_idx).transpose(1, 2).contiguous()  # B M 3

    _, pn_idx = knn_point(pk, point_input, point_output)
    pn_idx = pn_idx.detach().int()  # B M pk

    # Gather all neighbours in one flat call, then restore the (M, pk) layout.
    flat_idx = pn_idx.view(batch, num_samples * pk)
    gathered = pn2.gather_operation(feature_input, flat_idx)
    neighbor_feature = gathered.view(batch, channels, num_samples, pk).max(dim=3)[0]

    center_idx = p_idx.unsqueeze(2)
    center_feature = pn2.grouping_operation(feature_input, center_idx).view(batch, -1, num_samples)

    net = torch.cat([center_feature, neighbor_feature], dim=1)
    return net, p_idx, pn_idx, point_output
示例#2
0
    def forward(self, xyz, features, sample_inds):
        """Gather a pre-selected subset of points together with their features.

        Args:
            xyz: (B,K,3) point coordinates.
            features: (B,C,K) per-point features.
            sample_inds: indices of the points to keep, passed straight to
                ``pointnet2_utils.gather_operation``.

        Returns:
            (new_xyz, new_features, sample_inds): the gathered coordinates
            transposed back to point-major layout, the gathered features, and
            ``sample_inds`` unchanged.
        """
        gather = pointnet2_utils.gather_operation
        coords_cf = xyz.transpose(1, 2).contiguous()  # channel-first for gather
        picked_xyz = gather(coords_cf, sample_inds)
        new_xyz = picked_xyz.transpose(1, 2).contiguous()
        new_features = gather(features, sample_inds).contiguous()
        return new_xyz, new_features, sample_inds
示例#3
0
    def forward(self, global_feat, point_input):
        """Decode a coarse cloud from a global code, merge it with input points and refine.

        Args:
            global_feat: global feature, batch on dim 0, fed to ``self.fc1``.
            point_input: input points, channel-first (B, 3, N): the result is
                concatenated with ``coarse`` (B, 3, num_coarse) along dim 2.

        Returns:
            coarse: (B, 3, self.num_coarse) coarse points.
            fine: refined points from ``self.conv2``, trimmed with FPS to
                ``self.num_fine`` points when more were produced.
        """
        batch_size = global_feat.size()[0]
        coarse = F.relu(self.fc1(global_feat))
        coarse = F.relu(self.fc2(coarse))
        coarse = self.fc3(coarse).view(batch_size, 3, self.num_coarse)

        if self.downsample_im:
            # Keep 2048 - num_coarse input points so that coarse + input points
            # total 2048 (up to rounding in the mirrored branch).
            num_keep = 2048 - self.num_coarse
            if self.mirror_im:
                # symmetric_sample returns 2 * k points (samples + z-mirrored copies).
                org_points_input = symmetric_sample(
                    point_input.transpose(1, 2).contiguous(), int(num_keep / 2))
                org_points_input = org_points_input.transpose(1, 2).contiguous()
            else:
                org_points_input = pn2.gather_operation(
                    point_input,
                    pn2.furthest_point_sample(
                        point_input.transpose(1, 2).contiguous(), int(num_keep)))
        else:
            org_points_input = point_input

        if self.points_label:
            # Append a label channel: 0 for generated coarse points, 1 for points
            # taken from the input. Allocate on the data's own device/dtype
            # rather than a hard-coded ``.cuda()`` float32 tensor.
            id0 = coarse.new_zeros(coarse.shape[0], 1, coarse.shape[2])
            coarse_input = torch.cat((coarse, id0), 1)
            id1 = org_points_input.new_ones(org_points_input.shape[0], 1,
                                            org_points_input.shape[2])
            org_points_input = torch.cat((org_points_input, id1), 1)
            points = torch.cat((coarse_input, org_points_input), 2)
        else:
            points = torch.cat((coarse, org_points_input), 2)

        dense_feat = self.encoder(points)

        if self.scale >= 2:
            dense_feat = self.expansion(dense_feat)

        point_feat = F.relu(self.conv1(dense_feat))
        fine = self.conv2(point_feat)

        # Trim to exactly num_fine points with FPS if the decoder produced more.
        num_out = fine.size()[2]
        if num_out > self.num_fine:
            fine = pn2.gather_operation(
                fine,
                pn2.furthest_point_sample(
                    fine.transpose(1, 2).contiguous(), self.num_fine))

        return coarse, fine
示例#4
0
def get_uniform_loss(pcd, percentages=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """Penalise non-uniform local point spacing over several neighbourhood sizes.

    Seeds are chosen once with furthest point sampling (``int(N * 0.05)`` of
    them). For each fraction ``p`` in ``percentages``, ``int(N * p)`` points
    are ball-queried around each seed with radius ``sqrt(p * radius)``. The
    nearest-other-neighbour spacing inside each group is compared with the
    expected spacing ``sqrt(pi * radius**2 * p / nsample)``, and the
    normalised squared deviation is averaged and weighted by ``(100 * p) ** 2``.

    Args:
        pcd: (B, N, 3) point cloud.
        percentages: fractions of N used as neighbourhood sizes. Any iterable
            is accepted; the default is a tuple to avoid a shared mutable default.
        radius: scale of the query radius and of the expected disk area.

    Returns:
        Mean of the per-percentage losses (a scalar tensor).

    Raises:
        ValueError: if ``percentages`` is empty, or if ``int(N * p) < 2`` for
            some ``p``. In that case there is no neighbour besides the point
            itself, and the expected-spacing formula divides by ``nsample``.
    """
    percentages = tuple(percentages)
    if not percentages:
        raise ValueError("percentages must contain at least one value")
    N = pcd.size(1)
    for p in percentages:
        if int(N * p) < 2:
            raise ValueError(
                "int(N * p) must be >= 2 for every p in percentages; "
                "got N=%d, p=%g" % (N, p))

    npoint = int(N * 0.05)
    # The FPS seeds do not depend on p, so sample them once outside the loop.
    new_xyz = pn2.gather_operation(pcd.transpose(1, 2).contiguous(),
                                   pn2.furthest_point_sample(pcd, npoint)).transpose(1, 2).contiguous()
    pcd_t = pcd.transpose(1, 2).contiguous()

    loss = 0
    for p in percentages:
        nsample = int(N * p)
        r = math.sqrt(p * radius)
        disk_area = math.pi * (radius ** 2) * p / nsample
        expect_len = math.sqrt(disk_area)

        idx = pn2.ball_query(r, nsample, pcd, new_xyz)
        grouped_pcd = pn2.grouping_operation(pcd_t, idx)
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3)

        # k=2 neighbours; column 0 is assumed to be the point itself and is
        # skipped. abs() tolerates knn_point returning negated distances.
        var, _ = knn_point(2, grouped_pcd, grouped_pcd)
        uniform_dis = -var[:, :, 1:]
        uniform_dis = torch.sqrt(torch.abs(uniform_dis + 1e-8))
        uniform_dis = torch.mean(uniform_dis, dim=-1)
        uniform_dis = (uniform_dis - expect_len) ** 2 / (expect_len + 1e-8)

        loss += torch.mean(uniform_dis) * math.pow(p * 100, 2)
    return loss / len(percentages)
示例#5
0
    def forward(self, xyz, features):
        """Select ``self.num_proposal`` points by furthest point sampling.

        Args:
            xyz: (B,K,3) point coordinates.
            features: (B,C,K) per-point features.

        Returns:
            (new_xyz, new_features, sample_inds): the sampled coordinates
            transposed back to point-major layout, the features gathered at
            the same indices, and the FPS indices.
        """
        fps = pointnet2_utils.furthest_point_sample
        gather = pointnet2_utils.gather_operation

        # Farthest point sampling (FPS)
        sample_inds = fps(xyz, self.num_proposal)
        picked = gather(xyz.transpose(1, 2).contiguous(), sample_inds)
        new_xyz = picked.transpose(1, 2).contiguous()
        new_features = gather(features, sample_inds).contiguous()
        return new_xyz, new_features, sample_inds
示例#6
0
def symmetric_sample(points, num=512):
    """Sample ``num`` points with FPS and append their mirror images (z negated).

    Args:
        points: point cloud; passed to FPS directly and transposed before
            gathering, so presumably (B, N, 3) — TODO confirm.
        num: number of points drawn with furthest point sampling.

    Returns:
        (B, 2 * num, 3) tensor: the first three channels of the sampled points,
        followed by copies whose third (z) channel is negated.
    """
    fps_idx = pn2.furthest_point_sample(points, num)
    sampled = pn2.gather_operation(points.transpose(1, 2).contiguous(), fps_idx)
    sampled = sampled.transpose(1, 2).contiguous()
    # Keep x and y, flip z.
    mirrored = torch.cat([sampled[:, :, :2], -sampled[:, :, 2:3]], dim=2)
    return torch.cat([sampled, mirrored], dim=1)
示例#7
0
    def forward(self, xyz, normal, features=None):
        r"""
        Parameters
        ----------
        xyz : (B, N, 3) tensor of the xyz coordinates of the points
        normal : (B, N, 3) tensor of the normal vectors of the points
        features : (B, N, C) tensor of the descriptors of the points

        Returns
        -------
        new_xyz : (B, npoint, 3) tensor of the new points' xyz; a single
            all-zero centre per batch element when ``self.npoint`` is None
        new_normal : (B, npoint, 3) tensor of the new points' normal, or None
            when ``self.npoint`` is None
        new_features : outputs of ``self.mlps`` concatenated along dim 1
        """
        new_features_list = []

        if self.npoint is not None:
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            normal_flipped = normal.transpose(1, 2).contiguous()
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped,
                                                       fps_idx).transpose(
                                                           1, 2).contiguous()
            new_normal = pointnet2_utils.gather_operation(
                normal_flipped, fps_idx).transpose(1, 2).contiguous()
            fps_idx = fps_idx.data
        else:
            # Global convolution: one centre at the origin per batch element.
            # Allocated on xyz's device/dtype instead of a hard-coded .cuda().
            new_xyz = xyz.new_zeros(xyz.shape[0], 1, 3)
            new_normal = None
            fps_idx = None

        for i in range(len(self.groupers)):
            if self.npoint is not None:
                new_features = self.groupers[i](xyz, new_xyz, normal, features,
                                                fps_idx)
            else:
                new_features = self.groupers[i](xyz, new_xyz, features)
            new_features = self.mlps[i]((new_features, new_normal))
            new_features_list.append(new_features)
        return new_xyz, new_normal, torch.cat(new_features_list, dim=1)
示例#8
0
    def forward(self, xyz, features=None):
        # type: (_PointnetSAModuleBase, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""Set-abstraction step: choose centroids, group, apply MLPs, max-pool.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) coordinates of the input points
        features : torch.Tensor
            descriptors of the input points, forwarded unchanged to each grouper

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) centroid coordinates; ``xyz`` itself when npoint == N,
            None when ``self.npoint`` is None
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) pooled centroid descriptors
        """
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        use_centroids = self.npoint is not None

        if not use_centroids:
            new_xyz, fps_idx = None, None
        elif self.npoint == xyz.size(1):
            # Every point is kept: identity index map, no FPS needed.
            batch = xyz.size(0)
            fps_idx = torch.arange(self.npoint, device=xyz.device, dtype=torch.int)
            fps_idx = fps_idx.unsqueeze(0).expand(batch, self.npoint).contiguous()
            new_xyz = xyz
        else:
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
            sampled = pointnet2_utils.gather_operation(xyz_flipped, fps_idx)
            new_xyz = sampled.transpose(1, 2).contiguous()
            fps_idx = fps_idx.data

        pooled = []
        for i, grouper in enumerate(self.groupers):
            if use_centroids:
                grouped = grouper(xyz, new_xyz, features, fps_idx)
            else:
                grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = self.mlps[i](grouped)  # (B, mlp[-1], npoint, nsample)
            # Max over the neighbourhood axis, then drop it.
            grouped = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])
            pooled.append(self.out_mlps[i](grouped.squeeze(-1)))

        return new_xyz, torch.cat(pooled, dim=1)
示例#9
0
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""Set-abstraction step whose grouping scales are fused by element-wise max.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) coordinates of the input points
        features : torch.Tensor
            descriptors of the input points, forwarded unchanged to each grouper

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) centroid coordinates; ``xyz`` itself when npoint == N,
            None when ``self.npoint`` is None
        new_features : torch.Tensor
            with several groupers, the element-wise max of their MLP outputs;
            with one grouper, that grouper's MLP output
        """
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        use_centroids = self.npoint is not None

        if not use_centroids:
            new_xyz, fps_idx = None, None
        elif self.npoint == xyz.size(1):
            # Every point is kept: identity index map, no FPS needed.
            batch = xyz.size(0)
            fps_idx = torch.arange(self.npoint, device=xyz.device, dtype=torch.int)
            fps_idx = fps_idx.unsqueeze(0).expand(batch, self.npoint).contiguous()
            new_xyz = xyz
        else:
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
            sampled = pointnet2_utils.gather_operation(xyz_flipped, fps_idx)
            new_xyz = sampled.transpose(1, 2).contiguous()
            fps_idx = fps_idx.data

        per_scale = []
        for i, grouper in enumerate(self.groupers):
            if use_centroids:
                grouped = grouper(xyz, new_xyz, features, fps_idx)
            else:
                grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            per_scale.append(self.mlps[i](grouped))  # (B, mlp[-1], npoint)

        if len(per_scale) > 1:
            fused, _ = torch.max(torch.stack(per_scale, dim=0), dim=0)
            return new_xyz, fused
        return new_xyz, torch.cat(per_scale, dim=1)
示例#10
0
def fps_sample(numpy_x, num_point):
    """Downsample a numpy point array to ``num_point`` rows with FPS on the GPU.

    Only the first three columns drive the sampling; all six columns of the
    selected rows are returned.

    Args:
        numpy_x: array reshapeable to (1, -1, 6); presumably xyz plus three
            extra channels per point — TODO confirm with callers.
        num_point: number of points to keep.

    Returns:
        float32 numpy array of shape (num_point, 6).
    """
    cloud = torch.from_numpy(numpy_x).float().view(1, -1, 6).cuda()  # (1, N, 6)
    coords = cloud[:, :, :3].contiguous()
    idx = pointnet2_utils.furthest_point_sample(coords, num_point)  # (1, num_point)
    picked = pointnet2_utils.gather_operation(cloud.transpose(1, 2).contiguous(), idx)
    picked = picked.transpose(1, 2).contiguous()  # (1, num_point, 6)
    return picked.data.cpu().numpy()[0]
示例#11
0
    def forward(self, xyz: torch.Tensor,
                features: torch.Tensor = None,
                inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features
        inds : torch.Tensor
            (B, npoint) tensor that stores index to the xyz points (values in 0-N-1).
            Sampled with FPS when None. Ignored when ``self.same_idx`` is set
            or ``self.npoint`` is None.

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz; ``xyz`` itself when
            ``self.same_idx`` is set; None when ``self.npoint`` is None
        new_features : torch.Tensor
            outputs of the ``self.split`` MLP heads stacked along dim 1
        idx : torch.Tensor or None
            third output of ``self.grouper``; None when ``self.ret_unique_cnt``
            is set (that grouper's third output is a unique-count tensor)
        grouped_features : torch.Tensor
            (B, C, npoint, nsample) grouped features fed to the MLP heads
        """
        if self.same_idx:
            new_xyz = xyz
        elif self.npoint is None:
            # No centroid sampling; previously FPS was still called with None.
            new_xyz = None
        else:
            if inds is None:
                inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            else:
                assert inds.shape[1] == self.npoint, \
                    "inds must have npoint=%d columns" % self.npoint
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            new_xyz = pointnet2_utils.gather_operation(
                xyz_flipped, inds
            ).transpose(1, 2).contiguous()

        if not self.ret_unique_cnt:
            grouped_features, grouped_xyz, idx = self.grouper(
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample)
        else:
            # This grouper variant returns a unique count instead of indices.
            # ``idx`` used to be left unbound here, so the return statement
            # raised UnboundLocalError.
            grouped_features, grouped_xyz, unique_cnt = self.grouper(
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint)
            idx = None

        # One output per MLP head, each (B, mlp[-1], npoint, nsample).
        new_features = torch.stack(
            [self.mlp_module[i](grouped_features) for i in range(self.split)], 1)
        return new_xyz, new_features, idx, grouped_features
示例#12
0
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the points
        features : torch.Tensor
            descriptors of the points, forwarded to the groupers (and to the
            MLPs when ``self.pool`` is false)

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new points' xyz, or None when
            ``self.npoint`` is None
        new_features : torch.Tensor
            sum (not concatenation) of the MLP outputs over all groupers
        """
        all_features = 0

        if self.npoint is not None:
            batch, num_pts = xyz.shape[0], xyz.shape[1]
            if self.pool:
                # Random sampling with replacement. numpy's ``high`` is
                # exclusive, so it must be num_pts (not num_pts - 1), or the
                # last point can never be selected.
                fps_idx = np.random.randint(0, num_pts,
                                            size=[batch, self.npoint])
                fps_idx = torch.from_numpy(fps_idx).int().to(xyz.device)
            else:
                # Keep every point: identity index map.
                fps_idx = torch.arange(num_pts, dtype=torch.int,
                                       device=xyz.device).repeat(batch, 1)

            xyz_flipped = xyz.transpose(1, 2).contiguous()
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped,
                                                       fps_idx).transpose(
                                                           1, 2).contiguous()
        else:
            new_xyz = None

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features)  # (B, C, npoint, nsample)
            if not self.pool and self.npoint is not None:
                # Without pooling the MLP also receives the ungrouped input features.
                new_features = [new_features, features]
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint)
            all_features += new_features

        return new_xyz, all_features
示例#13
0
def train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch):
    """Train ``model`` for ``args.epochs`` epochs with periodic validation.

    Each batch is randomly subsampled to ``args.num_points`` points, augmented
    with ``d_utils.PointcloudScaleAndTranslate`` and used for one optimizer
    step. When ``args.evaluate`` is set, ``validate`` runs every
    ``int(args.val_freq_epoch * num_batch)`` batches (at least every batch).

    Args:
        train_dataloader: yields ``(points, target)`` batches.
        test_dataloader: passed to ``validate``.
        model, criterion, optimizer: the usual training triple.
        lr_scheduler: optional; stepped with the epoch index each batch.
        bnm_scheduler: optional; stepped with ``epoch - 1`` each batch.
        args: needs epochs, num_points, print_freq_iter, evaluate and, when
            evaluate is set, val_freq_epoch.
        num_batch: number of batches per epoch (used for logging and the
            validation period).
    """
    augment = d_utils.PointcloudScaleAndTranslate()  # initialize augmentation
    global g_acc
    g_acc = 0.91  # only save the model whose acc > 0.91
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, target = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points), Variable(target)

            # Random sampling: draw 1200 indices with replacement from all N
            # points (numpy's ``high`` is exclusive, so pass N, not N - 1),
            # then keep args.num_points of them without replacement.
            fps_idx = np.random.randint(0,
                                        points.shape[1],
                                        size=[points.shape[0], 1200])
            fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

            fps_idx = fps_idx[:,
                              np.random.choice(1200, args.num_points, False)]
            points = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, num_points, C)

            # augmentation
            points.data = augment(points.data)

            optimizer.zero_grad()

            pred = model(points)
            target = target.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                # lr_scheduler is optional; fall back to the optimizer's lr.
                lr = (lr_scheduler.get_lr()[0] if lr_scheduler is not None
                      else optimizer.param_groups[0]['lr'])
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.data.clone(), lr))
            batch_count += 1

            # validation in between an epoch; max(1, ...) avoids a zero modulus
            if args.evaluate and batch_count % max(1, int(
                    args.val_freq_epoch * num_batch)) == 0:
                validate(test_dataloader, model, criterion, args, batch_count)
示例#14
0
    def forward(self, xyz: torch.Tensor, normal: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""Sample centroids (with normals), group around them and run the MLPs.

        Parameters
        ----------
        xyz : (B, N, 3) tensor of the xyz coordinates of the points
        normal : (B, N, 3) tensor of the normal vectors of the points
        features : descriptors of the points, forwarded to the groupers

        Returns
        -------
        new_xyz : (B, npoint, 3) centroid coordinates, or None when ``self.npoint`` is None
        new_normal : (B, npoint, 3) centroid normals, or None when ``self.npoint`` is None
        new_features : outputs of ``self.mlps`` concatenated along dim 1
        """
        sampling = self.npoint is not None
        xyz_flipped = xyz.transpose(1, 2).contiguous()

        if sampling:
            gather = pointnet2_utils.gather_operation
            normal_flipped = normal.transpose(1, 2).contiguous()
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
            new_xyz = gather(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
            new_normal = gather(normal_flipped, fps_idx).transpose(1, 2).contiguous()
            fps_idx = fps_idx.data
        else:
            new_xyz = new_normal = fps_idx = None

        outputs = []
        for i, grouper in enumerate(self.groupers):
            if sampling:
                grouped = grouper(xyz, new_xyz, normal, features, fps_idx)
            else:
                grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            outputs.append(self.mlps[i]((grouped, new_normal)))  # (B, mlp[-1], npoint)

        return new_xyz, new_normal, torch.cat(outputs, dim=1)
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None,
                new_xyz=None) -> (torch.Tensor, torch.Tensor):
        r"""Set abstraction with a configurable pooling over each neighbourhood.

        :param xyz: (B, N, 3) tensor of the xyz coordinates
        :param features: (B, C, N) tensor of point features
        :param new_xyz: optional centroids; when None they are chosen with
            furthest point sampling, or left None if ``self.npoint`` is None
        :return:
            new_xyz: (B, npoint, 3) tensor of the centroid coordinates
            new_features: (B, \sum_k(mlps[k][-1]), npoint) pooled features; the
                channel count is the sum of each MLP's last width, e.g.
                mlps=[[16, 32], [32, 64]] gives 32 + 64
        :raises NotImplementedError: if ``self.pool_method`` is neither
            'max_pool' nor 'avg_pool'
        """
        xyz_flipped = xyz.transpose(1, 2).contiguous()  # (B, 3, N)
        if new_xyz is None and self.npoint is not None:
            # FPS picks npoint of the N points (assumes npoint < N).
            sample_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
            new_xyz = pointnet2_utils.gather_operation(
                features=xyz_flipped, idx=sample_idx
            ).transpose(1, 2).contiguous()  # (B, npoint, 3)

        poolers = {'max_pool': F.max_pool2d, 'avg_pool': F.avg_pool2d}
        pooled_list = []
        for i, grouper in enumerate(self.groupers):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = self.mlps[i](grouped)  # (B, mlp[i][-1], npoint, nsample)

            pool = poolers.get(self.pool_method)
            if pool is None:
                raise NotImplementedError
            # Reduce over the nsample axis, then drop it.
            pooled = pool(grouped, kernel_size=[1, grouped.size(3)])
            pooled_list.append(pooled.squeeze(-1))  # (B, mlp[i][-1], npoint)

        return new_xyz, torch.cat(pooled_list, dim=1)
示例#16
0
    def forward(
        self,
        xyz: torch.Tensor,
        features: torch.Tensor = None,
        inds: torch.Tensor = None,
    ) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features
        inds : torch.Tensor
            (B, npoint) tensor that stores index to the xyz points (values in 0-N-1).
            Sampled with FPS when None and ``self.npoint`` is set.

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz, or None when
            ``self.npoint`` is None (global abstraction)
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        inds: torch.Tensor
            (B, npoint) tensor of the inds (None if it was not given and no
            sampling happened)
        """
        new_features_list = []

        if self.npoint is not None:
            if inds is None:
                inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            new_xyz = pointnet2_utils.gather_operation(
                xyz_flipped, inds).transpose(1, 2).contiguous()
        else:
            # Global abstraction: no centroids. FPS used to be called here
            # with npoint=None, which failed before the grouping even ran.
            new_xyz = None

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features)  # (B, C, npoint, nsample)
            new_features = self.mlps[i](
                new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(new_features,
                                        kernel_size=[
                                            1, new_features.size(3)
                                        ])  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), inds
示例#17
0
    def forward(self, code, inputs, step_ratio, num_extract=512, mean_feature=None):
        '''Decode a coarse cloud, merge it with a symmetric input sample and upsample.

        :param code: B * C global feature
        :param inputs: B * C * N input points (channel-first; transposed before sampling)
        :param step_ratio: int upsampling ratio; ``int(log2(step_ratio))``
            doubling stages are run (expected to be a power of two)
        :param num_extract: int, number of input points kept by ``symmetric_sample``
            (half sampled with FPS, half their z-mirrored copies)
        :param mean_feature: optional B * C feature concatenated into every stage
        :return: coarse (B * 512 * 3) and fine (B * (1024 * step_ratio) * 3).
            When step_ratio == 1 no upsampling stage runs and fine is the merged
            coarse + input level.
        '''
        coarse = torch.tanh(self.coarse_mlp(code))  # (B, 1536)
        coarse = coarse.view(-1, 512, 3)  # (B, 512, 3)
        coarse = coarse.transpose(2, 1).contiguous()  # (B, 3, 512)

        inputs_new = inputs.transpose(2, 1).contiguous()  # (B, N, 3)
        input_fps = symmetric_sample(inputs_new, int(num_extract / 2))  # (B, 2*int(num_extract/2), 3)
        input_fps = input_fps.transpose(2, 1).contiguous()  # (B, 3, 2*int(num_extract/2))
        level0 = torch.cat([input_fps, coarse], 2)  # (B, 3, 1024) when num_extract == 512
        if num_extract > 512:
            # Bring the merged level back to 1024 points with FPS.
            level0_flipped = level0.transpose(2, 1).contiguous()
            level0 = pn2.gather_operation(level0, pn2.furthest_point_sample(level0_flipped, 1024))

        # With step_ratio == 1 the loop below does not run; previously ``fine``
        # was then unbound at the return statement.
        fine = level0
        for i in range(int(math.log2(step_ratio))):
            num_fine = 2 ** (i + 1) * 1024
            # Allocate the grid on level0's device instead of a hard-coded .cuda().
            grid = gen_grid_up(2 ** (i + 1)).to(level0.device).contiguous()
            grid = torch.unsqueeze(grid, 0)  # (1, 2, 2 ** (i + 1)) — assumed gen_grid_up layout
            grid_feat = grid.repeat(level0.shape[0], 1, 1024)  # (B, 2, num_fine)
            point_feat = torch.unsqueeze(level0, 3).repeat(1, 1, 1, 2)  # (B, 3, num_fine / 2, 2)
            point_feat = point_feat.view(-1, 3, num_fine)  # (B, 3, num_fine)
            global_feat = torch.unsqueeze(code, 2).repeat(1, 1, num_fine)  # (B, C, num_fine)

            if mean_feature is not None:
                mean_feature_use = F.relu(self.mean_fc(mean_feature))
                mean_feature_use = torch.unsqueeze(mean_feature_use, 2).repeat(1, 1, num_fine)
                feat = torch.cat([grid_feat, point_feat, global_feat, mean_feature_use], dim=1)
                feat1 = F.relu(self.up_branch_mlp_conv_mf(feat))
            else:
                feat = torch.cat([grid_feat, point_feat, global_feat], dim=1)
                feat1 = F.relu(self.up_branch_mlp_conv_nomf(feat))

            # Residual refinement of the per-point features.
            feat2 = self.contract_expand(feat1)
            feat = feat1 + feat2

            # Predict offsets on top of the duplicated points.
            fine = self.fine_mlp_conv(feat) + point_feat  # (B, 3, num_fine)
            level0 = fine

        return coarse.transpose(1, 2).contiguous(), fine.transpose(1, 2).contiguous()
示例#18
0
def validate(test_dataloader, model, criterion, args, iter):
    """Evaluate ``model`` on ``test_dataloader``; checkpoint it on a new best accuracy.

    Every batch is randomly subsampled (with replacement) to ``args.num_points``
    points before being fed to the model.

    Args:
        test_dataloader: iterable yielding ``(points, target)`` batches.
        model: classifier; put in eval mode for the pass and back in train
            mode afterwards.
        criterion: loss called as ``criterion(pred, target)``.
        args: namespace providing ``num_points`` and ``save_path``.
        iter: current training iteration, used only in the checkpoint file
            name (shadows the builtin; kept for caller compatibility).

    Side effects:
        Updates the module-level ``g_acc`` and saves ``model.state_dict()`` to
        ``args.save_path`` whenever the accuracy exceeds it.
    """
    global g_acc
    model.eval()
    losses, preds, labels = [], [], []
    # torch.no_grad() replaces the removed Variable(..., volatile=True) API.
    with torch.no_grad():
        for points, target in test_dataloader:
            points, target = points.cuda(), target.cuda()

            # farthest point sampling
            # fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points)  # (B, npoint)

            # Random sampling (with replacement). randint's upper bound is
            # exclusive, so pass N (not N - 1) to make the last point reachable.
            fps_idx = np.random.randint(0,
                                        points.shape[1],
                                        size=[points.shape[0], args.num_points])
            fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

            points = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()

            pred = model(points)
            target = target.view(-1)
            loss = criterion(pred, target)
            # Keep a Python float: np.array() over a list of CUDA tensors fails.
            losses.append(loss.item())
            _, pred_choice = torch.max(pred, -1)

            preds.append(pred_choice)
            labels.append(target)

    preds = torch.cat(preds, 0)
    labels = torch.cat(labels, 0)
    # Convert the count to a Python int first so acc is a plain float,
    # independent of PyTorch's integer-tensor division rules.
    acc = (preds == labels).sum().item() / labels.numel()
    print('\nval loss: %0.6f \t acc: %0.6f\n' % (np.mean(losses), acc))
    if acc > g_acc:
        g_acc = acc
        torch.save(
            model.state_dict(),
            '%s/cls_iter_%d_acc_%0.6f.pth' % (args.save_path, iter, acc))
        print('saved model with accuracy ', acc)
    model.train()
示例#19
0
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None,
                new_xyz=None) -> (torch.Tensor, torch.Tensor):
        r"""
        Multi-scale set abstraction over a point cloud.

        Each grouper in ``self.groupers`` gathers neighbourhoods around the
        centroids ``new_xyz``. The MLP at the same index in ``self.mlps``
        transforms them, and the result is max- or average-pooled over the
        neighbourhood axis. The per-scale outputs are concatenated along the
        channel axis.

        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
        :param features: descriptors forwarded unchanged to every grouper
            (NOTE(review): an earlier docstring said (B, N, C); pointnet2
            groupers usually expect (B, C, N) -- confirm against callers)
        :param new_xyz: optional (B, npoint, 3) centroids; when omitted they are
            chosen by furthest point sampling, or left as None if
            ``self.npoint`` is None
        :return:
            new_xyz: (B, npoint, 3) tensor of the new features' xyz
            new_features: (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        :raises NotImplementedError: if ``self.pool_method`` is neither
            'max_pool' nor 'avg_pool'
        """
        if new_xyz is None and self.npoint is not None:
            centroid_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(),
                centroid_idx).transpose(1, 2).contiguous()

        poolers = {'max_pool': F.max_pool2d, 'avg_pool': F.avg_pool2d}
        per_scale = []
        for scale_idx, grouper in enumerate(self.groupers):
            # grouper -> (B, C, npoint, nsample); mlp -> (B, mlp[-1], npoint, nsample)
            grouped = self.mlps[scale_idx](grouper(xyz, new_xyz, features))
            pool = poolers.get(self.pool_method)
            if pool is None:
                raise NotImplementedError
            # Pool over the nsample axis and drop it: (B, mlp[-1], npoint)
            pooled = pool(grouped, kernel_size=[1, grouped.size(3)])
            per_scale.append(pooled.squeeze(-1))

        return new_xyz, torch.cat(per_scale, dim=1)
示例#20
0
文件: ecg.py 项目: paul007pl/VRCNet
    def forward(self, global_feat, point_input):
        """Decode a coarse cloud from ``global_feat`` and refine it with the input points.

        Args:
            global_feat: per-sample global feature fed to fc1 -> fc2 -> fc3.
            point_input: observed points, concatenated with the coarse cloud
                along dim 2 (presumably channel-first (B, 3, N) -- confirm).

        Returns:
            (coarse, fine): ``coarse`` is reshaped to (B, 3, self.num_coarse).
            ``fine`` is the conv head output. When it has more than
            ``self.num_fine`` points, it is subsampled to ``self.num_fine``
            points with furthest point sampling.
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(global_feat))))
        coarse = self.fc3(hidden).view(global_feat.size()[0], 3, self.num_coarse)

        # Encode the predicted coarse points together with the observed ones.
        dense_feat = self.encoder(torch.cat((coarse, point_input), 2))
        if self.scale >= 2:
            dense_feat = self.expansion(dense_feat)

        fine = self.conv2(F.relu(self.conv1(dense_feat)))

        if fine.size()[2] > self.num_fine:
            keep_idx = pn2.furthest_point_sample(fine.transpose(1, 2).contiguous(), self.num_fine)
            fine = pn2.gather_operation(fine, keep_idx)

        return coarse, fine
示例#21
0
    def forward(self, inputs, gt, eps, iters, EMD=True, CD=True):
        """Run the wrapped model and compute EMD, Chamfer and uniformity losses.

        Args:
            inputs: batch passed unchanged to ``self.model``, which returns two
                outputs (``output1``, ``output2``).
            gt: ground-truth clouds; only the first 3 entries of the last axis
                are used.
            eps, iters: forwarded to ``self.EMD``.
            EMD: if True, compute ``emd1`` (output1 vs. gt subsampled by FPS to
                ``self.model.num_coarse`` points) and ``emd2`` (output2 vs. gt).
            CD: if True, compute the Chamfer terms for both outputs.

        Returns:
            ``(output1, output2, emd1, emd2, cd_p1, cd_p2, cd_t1, cd_t2, u1, u2)``.
            Disabled losses are a shape-(1,) zero tensor.
        """
        output1, output2 = self.model(inputs)
        gt = gt[:, :, :3]

        # Shared zero placeholder for disabled losses. It is created on the inputs'
        # device instead of a hard-coded .cuda(). It is only ever rebound, never
        # mutated, so sharing one tensor is safe.
        zero = torch.tensor([0], dtype=torch.float32, device=inputs.device)
        emd1 = emd2 = cd_p1 = cd_p2 = cd_t1 = cd_t2 = zero

        if EMD:
            num_coarse = self.model.num_coarse
            gt_fps = pn2.gather_operation(
                gt.transpose(1, 2).contiguous(),
                pn2.furthest_point_sample(gt, num_coarse)).transpose(
                    1, 2).contiguous()

            dist1, _ = self.EMD(output1, gt_fps, eps, iters)
            emd1 = torch.sqrt(dist1).mean(1)

            dist2, _ = self.EMD(output2, gt, eps, iters)
            emd2 = torch.sqrt(dist2).mean(1)

        # CD loss
        if CD:
            # cd_p*: average of the two directional mean sqrt-distances.
            # cd_t*: sum of the two directional mean (squared) distances.
            dist11, dist12, _, _ = chamLoss(gt, output1)
            cd_p1 = (torch.sqrt(dist11).mean(1) +
                     torch.sqrt(dist12).mean(1)) / 2
            cd_t1 = (dist11.mean(1) + dist12.mean(1))

            dist21, dist22, _, _ = chamLoss(gt, output2)
            cd_p2 = (torch.sqrt(dist21).mean(1) +
                     torch.sqrt(dist22).mean(1)) / 2
            cd_t2 = (dist21.mean(1) + dist22.mean(1))

        u1 = get_uniform_loss(output1)
        u2 = get_uniform_loss(output2)

        return output1, output2, emd1, emd2, cd_p1, cd_p2, cd_t1, cd_t2, u1, u2
示例#22
0
    def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        """
        Point spatio-temporal convolution over a sequence of point clouds.

        The sequence is first padded in time according to ``self.padding_mode``
        and ``self.temporal_padding``. The code then visits temporal anchor frames
        with stride ``self.temporal_stride``. For each anchor frame:
          - anchor points are chosen by furthest point sampling (N // spatial_stride);
          - for every frame in the window of radius ``self.temporal_radius``,
            neighbours are gathered with a ball query (radius ``self.r``,
            ``self.k`` samples), their displacements and optional features are
            combined, and the result is pooled over the neighbours;
          - the window's features are concatenated, passed through optional
            batch norm and ReLU, and then through ``self.temporal``.

        Args:
            xyzs: torch.Tensor
                 (B, L, N, 3) tensor of sequence of the xyz coordinates
            features: torch.Tensor
                 (B, L, C, N) tensor of sequence of the features; only read when
                 ``self.in_planes != 0``

        Returns:
            new_xyzs: (B, L', N // spatial_stride, 3) anchor coordinates stacked
                over the anchor frames.
            new_features: outputs of ``self.temporal`` stacked along dim 1.
        """
        # Use the tensor's own device: get_device() returns -1 for CPU tensors,
        # which torch.zeros(..., device=-1) rejects.
        device = xyzs.device

        nframes = xyzs.size(1)  # L
        npoints = xyzs.size(2)  # N

        if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
            assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!"

        # Split the sequence into per-frame tensors: L x (B, N, 3) and L x (B, C, N).
        xyzs = [torch.squeeze(input=xyz, dim=1).contiguous()
                for xyz in torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)]
        if self.in_planes != 0:
            features = [torch.squeeze(input=feature, dim=1).contiguous()
                        for feature in torch.split(tensor=features, split_size_or_sections=1, dim=1)]

        pad_front = self.temporal_padding[0]
        pad_back = self.temporal_padding[1]
        if self.padding_mode == "zeros":
            xyz_padding = torch.zeros(xyzs[0].size(), dtype=torch.float32, device=device)
            xyzs = [xyz_padding] * pad_front + xyzs + [xyz_padding] * pad_back
            if self.in_planes != 0:
                feature_padding = torch.zeros(features[0].size(), dtype=torch.float32, device=device)
                features = [feature_padding] * pad_front + features + [feature_padding] * pad_back
        else:   # "replicate": repeat the first / last frame
            xyzs = [xyzs[0]] * pad_front + xyzs + [xyzs[-1]] * pad_back
            if self.in_planes != 0:
                features = [features[0]] * pad_front + features + [features[-1]] * pad_back

        new_xyzs = []
        new_features = []
        # temporal anchor frames
        for t in range(self.temporal_radius, len(xyzs) - self.temporal_radius, self.temporal_stride):
            # spatial anchor point subsampling by FPS: (B, N//spatial_stride)
            anchor_idx = pointnet2_utils.furthest_point_sample(xyzs[t], npoints // self.spatial_stride)
            # (B, 3, N//spatial_stride)
            anchor_xyz_flipped = pointnet2_utils.gather_operation(xyzs[t].transpose(1, 2).contiguous(), anchor_idx)
            # (B, 3, N//spatial_stride, 1)
            anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3)
            # (B, N//spatial_stride, 3)
            anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous()

            # spatial convolution over every frame in the temporal window
            spatial_features = []
            for i in range(t - self.temporal_radius, t + self.temporal_radius + 1):
                neighbor_xyz = xyzs[i]

                idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz)

                # (B, 3, N) -> (B, 3, N//spatial_stride, k)
                neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous()
                neighbor_xyz_grouped = pointnet2_utils.grouping_operation(neighbor_xyz_flipped, idx)

                # (B, 3, N//spatial_stride, k) -> (B, mid_planes, N//spatial_stride, k)
                displacement = neighbor_xyz_grouped - anchor_xyz_expanded
                displacement = self.spatial_conv_d(displacement)

                if self.in_planes != 0:
                    # (B, in_planes, N//spatial_stride, k) -> (B, mid_planes, N//spatial_stride, k)
                    neighbor_feature_grouped = pointnet2_utils.grouping_operation(features[i], idx)
                    feature = self.spatial_conv_f(neighbor_feature_grouped)

                    if self.spatial_aggregation == "addition":
                        spatial_feature = feature + displacement
                    else:
                        spatial_feature = feature * displacement
                else:
                    spatial_feature = displacement

                # pool over the k neighbours: (B, mid_planes, N//spatial_stride)
                if self.spatial_pooling == 'max':
                    spatial_feature, _ = torch.max(input=spatial_feature, dim=-1, keepdim=False)
                elif self.spatial_pooling == 'sum':
                    spatial_feature = torch.sum(input=spatial_feature, dim=-1, keepdim=False)
                else:
                    spatial_feature = torch.mean(input=spatial_feature, dim=-1, keepdim=False)

                spatial_features.append(spatial_feature)

            # (B, temporal_kernel_size*mid_planes, N//spatial_stride)
            spatial_features = torch.cat(spatial_features, dim=1)

            # batch norm and relu
            if self.batch_norm:
                spatial_features = self.batch_norm(spatial_features)

            spatial_features = self.relu(spatial_features)

            # temporal convolution
            spatio_temporal_feature = self.temporal(spatial_features)

            new_xyzs.append(anchor_xyz)
            new_features.append(spatio_temporal_feature)

        new_xyzs = torch.stack(tensors=new_xyzs, dim=1)
        new_features = torch.stack(tensors=new_features, dim=1)

        return new_xyzs, new_features
def main():
    """Evaluate a trained DensePoint classifier on ModelNet40 with multi-vote testing.

    Parses command-line arguments and merges the ``common`` section of the
    YAML file named by ``args.config`` into them. Loads ``args.checkpoint``
    when it is non-empty, then makes ``NUM_REPEAT`` passes over the test set.

    In each pass, every batch is randomly subsampled (with replacement) to 1200
    candidate points. The softmax outputs of ``NUM_VOTE`` random
    ``args.num_points``-point subsets are averaged; every vote after the first
    is also randomly rescaled. The accuracy of each pass and the best accuracy
    are printed.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: yaml.load without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])

    test_dataset = ModelNet40Cls(num_points = args.num_points, root = args.data_root, transforms=test_transforms, train=False)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        pin_memory=True
    )

    model = DensePoint(num_classes = args.num_classes, input_channels = args.input_channels, use_xyz = True)
    model.cuda()

    # Truthiness test instead of `is not ''`, which compared object identity
    # with a string literal (and would also have tried torch.load(None)).
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    PointcloudScale = d_utils.PointcloudScale()   # initialize random scaling
    num_candidates = 1200  # points drawn per cloud before voting
    model.eval()
    global_acc = 0
    # torch.no_grad() replaces the removed Variable(..., volatile=True) API.
    with torch.no_grad():
        for i in range(NUM_REPEAT):
            preds = []
            labels = []

            s = time.time()
            for points, target in test_dataloader:
                points, target = points.cuda(), target.cuda()
                # points [batch_size, num_points, dimensions], e.g., [256, 2048, 3]

                # furthest point sampling
                # fps_idx = pointnet2_utils.furthest_point_sample(points, 1200)  # (B, npoint)

                # Random sampling. randint's upper bound is exclusive, so pass N
                # (not N - 1) to make the last point reachable.
                fps_idx = np.random.randint(0, points.shape[1], size=[points.shape[0], num_candidates])
                fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

                pred = 0
                for v in range(NUM_VOTE):
                    new_fps_idx = fps_idx[:, np.random.choice(num_candidates, args.num_points, False)]
                    new_points = pointnet2_utils.gather_operation(
                        points.transpose(1, 2).contiguous(), new_fps_idx).transpose(1, 2).contiguous()
                    if v > 0:
                        new_points.data = PointcloudScale(new_points.data)
                    pred += F.softmax(model(new_points), dim = 1)
                pred /= NUM_VOTE
                target = target.view(-1)
                _, pred_choice = torch.max(pred, -1)

                preds.append(pred_choice)
                labels.append(target)
            e = time.time()

            preds = torch.cat(preds, 0)
            labels = torch.cat(labels, 0)
            # Python float accuracy, independent of tensor integer-division rules.
            acc = (preds == labels).sum().item() / labels.numel()
            if acc > global_acc:
                global_acc = acc
            print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
            print('time (secs) for 1 epoch: ', (e - s))
    print('\nBest voting acc: %0.6f' % (global_acc))
示例#24
0
    def forward(self, x, gt, is_training=True, mean_feature=None, alpha=None):
        """Probabilistic completion forward pass (training losses or evaluation metrics).

        In training, the partial input ``x`` and an FPS-subsampled copy of ``gt``
        are encoded in one batch. The "q" distribution is inferred from the
        partial-input features and the "p" distribution from the gt features.
        One latent is drawn from each. Both latents are decoded from the
        partial-input features, so the effective batch is doubled.

        Args:
            x: partial input cloud; ``x.size()[2]`` is used as its point count
                (NOTE(review): presumably channel-first (B, 3, N) -- confirm).
            gt: ground-truth cloud, passed to ``furthest_point_sample`` and the
                CD/EMD helpers (presumably (B, M, 3) -- confirm).
            is_training: choose the training branch (losses) or the evaluation
                branch (metrics dict).
            mean_feature: not used in this method.
            alpha: weight of the ``fine`` CD term. It must be a number when
                ``is_training`` is True; ``loss4.mean() * None`` raises TypeError.

        Returns:
            Training: ``(fine, loss4, total_train_loss)``, where ``loss4`` is the
            first value returned by ``calc_cd`` for ``fine``.
            Evaluation: dict with 'out1' (coarse_raw), 'out2' (fine), 'emd',
            'cd_p', 'cd_t' and 'f1'.

        Raises:
            NotImplementedError: (training only) if ``self.distribution_loss`` is
            neither 'MMD' nor 'KLD', or ``self.train_loss`` is not 'cd'.
        """
        # Point count of the partial input; used as the FPS target for gt.
        num_input = x.size()[2]

        if is_training:
            # Subsample gt to the same number of points as x, then stack the
            # partial (x) and complete (y) clouds along the batch axis so a
            # single encoder pass handles both. gt and x are duplicated to
            # match the doubled batch that is decoded later.
            y = pn2.gather_operation(
                gt.transpose(1, 2).contiguous(),
                pn2.furthest_point_sample(gt, num_input))
            gt = torch.cat([gt, gt], dim=0)
            points = torch.cat([x, y], dim=0)
            x = torch.cat([x, x], dim=0)
        else:
            points = x
        feat = self.encoder(points)

        if is_training:
            # First half: features of x; second half: features of y.
            feat_x, feat_y = feat.chunk(2)
            o_x = self.posterior_infer2(self.posterior_infer1(feat_x))
            q_mu, q_std = torch.split(o_x, self.size_z, dim=1)
            o_y = self.prior_infer(feat_y)
            p_mu, p_std = torch.split(o_y, self.size_z, dim=1)
            # softplus keeps the predicted standard deviations positive.
            q_std = F.softplus(q_std)
            p_std = F.softplus(p_std)
            q_distribution = torch.distributions.Normal(q_mu, q_std)
            p_distribution = torch.distributions.Normal(p_mu, p_std)
            # Detached copy of p: the dl_g term below does not backpropagate into prior_infer.
            p_distribution_fix = torch.distributions.Normal(
                p_mu.detach(), p_std.detach())
            # Standard normal N(0, I) with the same shape as p.
            m_distribution = torch.distributions.Normal(
                torch.zeros_like(p_mu), torch.ones_like(p_std))
            # Reparameterised samples; their order matters for RNG reproducibility.
            z_q = q_distribution.rsample()
            z_p = p_distribution.rsample()
            z = torch.cat([z_q, z_p], dim=0)
            # Both latents are decoded from the partial-input features.
            feat = torch.cat([feat_x, feat_x], dim=0)

        else:
            # Only the partial-input branch exists at test time; the other
            # distributions alias q and are not used further in this branch.
            o_x = self.posterior_infer2(self.posterior_infer1(feat))
            q_mu, q_std = torch.split(o_x, self.size_z, dim=1)
            q_std = F.softplus(q_std)
            q_distribution = torch.distributions.Normal(q_mu, q_std)
            p_distribution = q_distribution
            p_distribution_fix = p_distribution
            m_distribution = p_distribution
            z = q_distribution.rsample()

        # In-place: add the generated latent contribution to the features.
        feat += self.generator(z)

        # Decoder outputs are swapped on dims 1 and 2 (presumably to (B, N, 3)
        # for the loss helpers -- confirm).
        coarse_raw, coarse_high, coarse, fine = self.decoder(feat, x)
        coarse_raw = coarse_raw.transpose(1, 2).contiguous()
        coarse_high = coarse_high.transpose(1, 2).contiguous()
        coarse = coarse.transpose(1, 2).contiguous()
        fine = fine.transpose(1, 2).contiguous()

        if is_training:
            # dl_rec compares p with N(0, I); dl_g compares q with the detached p.
            if self.distribution_loss == 'MMD':
                z_m = m_distribution.rsample()
                z_q = q_distribution.rsample()
                z_p = p_distribution.rsample()
                z_p_fix = p_distribution_fix.rsample()
                dl_rec = self.mmd_loss(z_m, z_p)
                dl_g = self.mmd_loss2(z_q, z_p_fix)
            elif self.distribution_loss == 'KLD':
                dl_rec = torch.distributions.kl_divergence(
                    m_distribution, p_distribution)
                dl_g = torch.distributions.kl_divergence(
                    p_distribution_fix, q_distribution)
            else:
                raise NotImplementedError(
                    'Distribution loss is either MMD or KLD')

            if self.train_loss == 'cd':
                loss1, _ = calc_cd(coarse_raw, gt)
                loss2, _ = calc_cd(coarse_high, gt)
                loss3, _ = calc_cd(coarse, gt)
                loss4, _ = calc_cd(fine, gt)
            else:
                raise NotImplementedError('Only CD is supported')

            # Weights: coarse_raw x10, coarse_high x0.5, coarse x1, fine x alpha,
            # plus 20x the two distribution losses.
            total_train_loss = loss1.mean() * 10 + loss2.mean(
            ) * 0.5 + loss3.mean() + loss4.mean() * alpha
            total_train_loss += (dl_rec.mean() + dl_g.mean()) * 20
            return fine, loss4, total_train_loss
        else:
            emd = calc_emd(fine, gt, eps=0.004, iterations=3000)
            cd_p, cd_t, f1 = calc_cd(fine, gt, calc_f1=True)
            return {
                'out1': coarse_raw,
                'out2': fine,
                'emd': emd,
                'cd_p': cd_p,
                'cd_t': cd_t,
                'f1': f1
            }
示例#25
0
    def forward(self, global_feat, point_input):
        """Decode ``global_feat`` into a coarse-to-fine hierarchy of point clouds.

        Stages:
          1. ``coarse_raw``: fc1 -> fc2 -> fc3 regression reshaped to
             (B, 3, self.num_coarse_raw).
          2. ``coarse_high``: ``coarse_raw`` is concatenated with
             ``point_input`` along the point axis and encoded. It is expanded
             when ``self.up_scale >= 2`` and then mapped by conv_cup1/conv_cup2.
          3. ``coarse``: reduced by furthest point sampling to ``self.num_fps``
             points if larger. If it still has more than ``self.num_coarse``
             points, the top-scoring ``self.num_coarse`` points (by a learned
             score) are kept.
          4. ``fine``: upsampled from ``coarse`` to ``self.num_fine`` points,
             either by local folding around repeated centres or by plain
             expansion. It is ``coarse`` itself when ``coarse`` already has
             ``self.num_fine`` points.

        When ``self.points_label`` is set, one extra channel marks generated
        points (0) and input points (1) before encoding.

        Returns:
            (coarse_raw, coarse_high, coarse, fine), with points on dim 2.
        """
        batch_size = global_feat.size()[0]

        coarse_raw = self.fc3(self.af(self.fc2(self.af(
            self.fc1(global_feat))))).view(batch_size, 3, self.num_coarse_raw)

        org_points_input = point_input

        if self.points_label:
            # Label channel: 0 for generated points, 1 for observed points.
            # Allocated like the source tensors (same device/dtype) rather than
            # with a hard-coded .cuda(), so CPU and non-default GPUs work.
            id0 = coarse_raw.new_zeros(coarse_raw.shape[0], 1, coarse_raw.shape[2])
            coarse_input = torch.cat((coarse_raw, id0), 1)
            id1 = org_points_input.new_ones(org_points_input.shape[0], 1,
                                            org_points_input.shape[2])
            org_points_input = torch.cat((org_points_input, id1), 1)
        else:
            coarse_input = coarse_raw

        points = torch.cat((coarse_input, org_points_input), 2)
        dense_feat = self.encoder(points)

        if self.up_scale >= 2:
            dense_feat = self.expansion1(dense_feat)

        coarse_features = self.af(self.conv_cup1(dense_feat))
        coarse_high = self.conv_cup2(coarse_features)

        # Spatially uniform reduction first ...
        if coarse_high.size()[2] > self.num_fps:
            idx_fps = pn2.furthest_point_sample(
                coarse_high.transpose(1, 2).contiguous(), self.num_fps)
            coarse_fps = pn2.gather_operation(coarse_high, idx_fps)
            coarse_features = pn2.gather_operation(coarse_features, idx_fps)
        else:
            coarse_fps = coarse_high

        # ... then keep the top-k points by learned score.
        if coarse_fps.size()[2] > self.num_coarse:
            scores = F.softplus(
                self.conv_s3(
                    self.af(
                        self.conv_s2(self.af(self.conv_s1(coarse_features))))))
            idx_scores = scores.topk(k=self.num_coarse,
                                     dim=2)[1].view(batch_size, -1).int()
            coarse = pn2.gather_operation(coarse_fps, idx_scores)
            coarse_features = pn2.gather_operation(coarse_features, idx_scores)
        else:
            coarse = coarse_fps

        if coarse.size()[2] < self.num_fine:
            if self.local_folding:
                up_features = self.expansion2(coarse_features, global_feat)
                # Repeat each coarse point num_fine // num_coarse times so the
                # predicted offsets are added around their own centre.
                center = coarse.transpose(
                    2, 1).contiguous().unsqueeze(2).repeat(
                        1, 1, self.num_fine // self.num_coarse,
                        1).view(batch_size, self.num_fine,
                                3).transpose(2, 1).contiguous()
                fine = self.conv_f2(self.af(
                    self.conv_f1(up_features))) + center
            else:
                up_features = self.expansion2(coarse_features)
                fine = self.conv_f2(self.af(self.conv_f1(up_features)))
        else:
            assert (coarse.size()[2] == self.num_fine)
            fine = coarse

        return coarse_raw, coarse_high, coarse, fine
示例#26
0
    def forward(self,
                xyz: torch.Tensor,
                features: torch.Tensor = None,
                inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Set abstraction: sample centroids, group neighbours, run the MLP, pool.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the the features
        inds : torch.Tensor
            (B, npoint) tensor that stores index to the xyz points (values in 0-N-1)

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz; None when
            ``self.npoint`` is None
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        inds: torch.Tensor
            (B, npoint) tensor of the inds; None when ``self.npoint`` is None
            and no ``inds`` were given
        unique_cnt : torch.Tensor
            only returned when ``self.ret_unique_cnt`` is set; (B, npoint)

        Raises
        ------
        NotImplementedError
            If ``self.pooling`` is not one of 'max', 'avg' or 'rbf'.
        """
        if inds is None:
            # Sample centroids only when there is a target count. With
            # npoint=None no centroids are used (new_xyz is None below), so
            # calling FPS with npoint=None would fail for nothing.
            if self.npoint is not None:
                inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        else:
            assert (inds.shape[1] == self.npoint)

        if self.npoint is not None:
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(), inds).transpose(1, 2).contiguous()
        else:
            new_xyz = None

        if not self.ret_unique_cnt:
            grouped_features, grouped_xyz = self.grouper(
                xyz, new_xyz, features)  # (B, C, npoint, nsample)
        else:
            grouped_features, grouped_xyz, unique_cnt = self.grouper(
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint)

        new_features = self.mlp_module(
            grouped_features)  # (B, mlp[-1], npoint, nsample)
        if self.pooling == 'max':
            new_features = F.max_pool2d(new_features,
                                        kernel_size=[
                                            1, new_features.size(3)
                                        ])  # (B, mlp[-1], npoint, 1)
        elif self.pooling == 'avg':
            new_features = F.avg_pool2d(new_features,
                                        kernel_size=[
                                            1, new_features.size(3)
                                        ])  # (B, mlp[-1], npoint, 1)
        elif self.pooling == 'rbf':
            # Use radial basis function kernel for weighted sum of features (normalized by nsample and sigma)
            # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel
            rbf = torch.exp(-1 * grouped_xyz.pow(2).sum(1, keepdim=False) /
                            (self.sigma**2) / 2)  # (B, npoint, nsample)
            new_features = torch.sum(
                new_features * rbf.unsqueeze(1), -1, keepdim=True) / float(
                    self.nsample)  # (B, mlp[-1], npoint, 1)
        else:
            # Previously an unknown mode fell through and silently returned an
            # unpooled 4-D tensor.
            raise NotImplementedError
        new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

        if not self.ret_unique_cnt:
            return new_xyz, new_features, inds
        else:
            return new_xyz, new_features, inds, unique_cnt
示例#27
0
    def forward(self,
                partial_cloud,
                gt,
                is_training=True,
                mean_feature=None,
                alpha=None):
        """Gridding-based completion: voxel 3D CNN, gridding reverse, then point refinement.

        The partial cloud is gridded to a 64^3 volume and passed through a 3D
        conv encoder, a fully-connected bottleneck and a deconv decoder with
        additive skip connections. The volume is turned back into a sparse
        cloud and combined with the partial cloud by ``self.point_sampling``.
        Features sampled from three decoder resolutions then drive an MLP
        that predicts 8 offsets per sparse point (2048 * 8 = 16384 points).
        The result is reduced by furthest point sampling to ``self.num_points``
        when that is smaller.

        Args:
            partial_cloud: input cloud for ``self.gridding`` and
                ``self.point_sampling`` (NOTE(review): presumably (B, N, 3) -- confirm).
            gt: ground-truth cloud used for the losses / metrics.
            is_training: return training losses (True) or evaluation metrics.
            mean_feature, alpha: not used in this method.

        Returns:
            Training: ``(dense_cloud, loss2, total_train_loss)``.
            Evaluation: dict with 'out1' and 'out2' (both ``dense_cloud``),
            'emd', 'cd_p', 'cd_t' and 'f1'.

        Raises:
            NotImplementedError: (training only) if ``self.train_loss`` is
            neither 'emd' nor 'cd'.
        """
        pt_features_64_l = self.gridding(partial_cloud).view(-1, 1, 64, 64, 64)
        pt_features_32_l = self.conv1(pt_features_64_l)
        pt_features_16_l = self.conv2(pt_features_32_l)
        pt_features_8_l = self.conv3(pt_features_16_l)
        pt_features_4_l = self.conv4(pt_features_8_l)
        features = self.fc5(pt_features_4_l.view(-1, 16384))
        # Decoder with additive skip connections from the matching encoder level.
        pt_features_4_r = self.fc6(features).view(-1, 256, 4, 4,
                                                  4) + pt_features_4_l
        pt_features_8_r = self.dconv7(pt_features_4_r) + pt_features_8_l
        pt_features_16_r = self.dconv8(pt_features_8_r) + pt_features_16_l
        pt_features_32_r = self.dconv9(pt_features_16_r) + pt_features_32_l
        pt_features_64_r = self.dconv10(pt_features_32_r) + pt_features_64_l
        sparse_cloud = self.gridding_rev(pt_features_64_r.squeeze(dim=1))
        sparse_cloud = self.point_sampling(sparse_cloud, partial_cloud)
        point_features_32 = self.feature_sampling(sparse_cloud,
                                                  pt_features_32_r).view(
                                                      -1, 2048, 256)
        point_features_16 = self.feature_sampling(sparse_cloud,
                                                  pt_features_16_r).view(
                                                      -1, 2048, 512)
        point_features_8 = self.feature_sampling(sparse_cloud,
                                                 pt_features_8_r).view(
                                                     -1, 2048, 1024)
        point_features = torch.cat(
            [point_features_32, point_features_16, point_features_8], dim=2)
        point_features = self.fc11(point_features)
        point_features = self.fc12(point_features)
        point_features = self.fc13(point_features)
        # 8 offsets per sparse point, added to 8 copies of that point.
        point_offset = self.fc14(point_features).view(-1, 16384, 3)
        dense_cloud = sparse_cloud.unsqueeze(dim=2).repeat(1, 1, 8, 1).view(
            -1, 16384, 3) + point_offset
        if self.num_points < 16384:
            # furthest_point_sample takes the channel-last (B, N, 3) cloud, but
            # gather_operation indexes the last axis of a channel-first
            # (B, C, N) tensor. Transpose in and back out; gathering the raw
            # (B, N, 3) tensor indexed the wrong axis.
            idx_fps = pn2.furthest_point_sample(dense_cloud, self.num_points)
            dense_cloud = pn2.gather_operation(
                dense_cloud.transpose(1, 2).contiguous(),
                idx_fps).transpose(1, 2).contiguous()

        if is_training:
            if self.train_loss == 'emd':
                loss1 = calc_emd(sparse_cloud, gt)
                loss2 = calc_emd(dense_cloud, gt)
            elif self.train_loss == 'cd':
                _, loss1 = calc_cd(sparse_cloud, gt)  # cd_t
                _, loss2 = calc_cd(dense_cloud, gt)  # cd_t
            else:
                raise NotImplementedError('Train loss is either CD or EMD!')

            total_train_loss = loss1.mean() + loss2.mean()
            return dense_cloud, loss2, total_train_loss
        else:
            emd = calc_emd(dense_cloud, gt, eps=0.004, iterations=3000)
            cd_p, cd_t, f1 = calc_cd(dense_cloud, gt, calc_f1=True)
            return {
                'out1': dense_cloud,
                'out2': dense_cloud,
                'emd': emd,
                'cd_p': cd_p,
                'cd_t': cd_t,
                'f1': f1
            }