Example #1
    def compute_loss(self, outputs1, outputs2, gt_points, normals, img=None):
        r''' Returns the complete loss.

        The full loss is adopted from the authors' implementation and
            consists of
                a.) Chamfer distance loss
                b.) edge length loss
                c.) normal loss
                d.) Laplacian loss
                e.) move loss

        Arguments:
            outputs1 (list): first outputs of model
            outputs2 (list): second outputs of model
            gt_points (tensor): ground truth point cloud locations
            normals (tensor): normals of the ground truth point cloud
            img (tensor): input images
        '''
        pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1

        # Chamfer Distance Loss
        lc11, lc12, id11, id12 = chamfer_distance(pred_vertices_1,
                                                  gt_points,
                                                  give_id=True)
        lc21, lc22, id21, id22 = chamfer_distance(pred_vertices_2,
                                                  gt_points,
                                                  give_id=True)
        lc31, lc32, id31, id32 = chamfer_distance(pred_vertices_3,
                                                  gt_points,
                                                  give_id=True)
        l_c = lc11.mean() + lc21.mean() + lc31.mean()
        l_c2 = lc12.mean() + lc22.mean() + lc32.mean()
        l_c = (l_c2 + self.param_chamfer_rel * l_c) * self.param_chamfer_w

        # Edge Length Loss
        l_e = (self.edge_length_loss(pred_vertices_1, 1) +
               self.edge_length_loss(pred_vertices_2, 2) +
               self.edge_length_loss(pred_vertices_3, 3)) * self.param_edge

        # Normal Loss
        l_n = (
            self.normal_loss(pred_vertices_1, normals, id11, 1) +
            self.normal_loss(pred_vertices_2, normals, id21, 2) +
            self.normal_loss(pred_vertices_3, normals, id31, 3)) * self.param_n

        # Laplacian Loss and move loss
        l_l1, _ = self.laplacian_loss(pred_vertices_1, outputs2[0], block_id=1)
        l_l2, move_loss1 = self.laplacian_loss(pred_vertices_2,
                                               outputs2[1],
                                               block_id=2)
        l_l3, move_loss2 = self.laplacian_loss(pred_vertices_3,
                                               outputs2[2],
                                               block_id=3)
        l_l = (self.param_lap_rel * l_l1 + l_l2 + l_l3) * self.param_lap
        l_m = (move_loss1 + move_loss2) * self.param_move

        # Final loss
        loss = l_c + l_e + l_n + l_l + l_m
        return loss
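
All of the examples here lean on a chamfer_distance helper. As a point of
reference, the following is a minimal sketch of a batched, symmetric Chamfer
distance in plain PyTorch; it is a hypothetical stand-in, not the exact helper
used above, which additionally supports give_id=True to return the
nearest-neighbour index tensors (id11, id12, ...) consumed by the normal loss.

    import torch

    def chamfer_distance_naive(a, b):
        # a: (B, N, 3) predicted points, b: (B, M, 3) target points.
        d = torch.cdist(a, b) ** 2                    # (B, N, M) squared distances
        loss_a2b = d.min(dim=2).values.mean(dim=1)    # pred -> target, (B,)
        loss_b2a = d.min(dim=1).values.mean(dim=1)    # target -> pred, (B,)
        return loss_a2b + loss_b2a                    # one value per batch element
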
Example #2
    def eval_step(self, data):
        r''' Performs an evaluation step.

        The chamfer loss is calculated and returned in a dictionary.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        points = data.get('pointcloud_chamfer').to(device)
        inputs = data.get('inputs').to(device)

        with torch.no_grad():
            points_out = self.model(inputs)

        loss = chamfer_distance(points, points_out).mean()
        loss = loss.item()
        eval_dict = {
            'loss': loss,
            'chamfer': loss,
        }

        return eval_dict
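
A typical way to consume such an eval_step is to average the returned
dictionaries over a validation loader; a minimal sketch, with trainer and
val_loader assumed to exist:

    from collections import defaultdict

    sums = defaultdict(float)
    n_batches = 0
    for batch in val_loader:
        for k, v in trainer.eval_step(batch).items():
            sums[k] += v
        n_batches += 1
    metrics = {k: v / n_batches for k, v in sums.items()}
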
Example #3
    def eval_step(self, data):
        r''' Performs an evaluation step.

        The chamfer loss is calculated and returned in a dictionary.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        points = data.get('pointcloud_chamfer').to(device)
        inputs = data.get('inputs').to(device)

        with torch.no_grad():
            points_out, _ = self.model(inputs)

        loss = chamfer_distance(points, points_out).mean()
        loss = loss.item()
        eval_dict = {
            'loss': loss,
        }

        return eval_dict
Example #4
    def compute_loss(self, points, inputs):
        r''' Computes the loss.

        The Point Set Generation Network is trained on the Chamfer distance.

        Args:
            points (tensor): GT point cloud data
            inputs (tensor): input data for the model
        '''
        points_out = self.model(inputs)
        loss = chamfer_distance(points, points_out).mean()
        return loss
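
This compute_loss is the piece a training step would wrap. A minimal sketch of
such a step, assuming the trainer carries a self.optimizer and the same
'pointcloud'/'inputs' data keys used in the surrounding examples:

    def train_step(self, data):
        r''' Performs a single training step (sketch). '''
        self.model.train()
        points = data.get('pointcloud').to(self.device)
        inputs = data.get('inputs').to(self.device)
        self.optimizer.zero_grad()
        loss = self.compute_loss(points, inputs)
        loss.backward()
        self.optimizer.step()
        return loss.item()
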
Example #5
    def eval_step(self, data):
        r''' Performs an evaluation step.

        Arguments:
            data (dict): data dictionary
        '''
        self.model.eval()
        points = data.get('pointcloud').to(self.device)
        img = data.get('inputs').to(self.device)
        normals = data.get('pointcloud.normals').to(self.device)

        # Transform GT points to camera coordinates
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        points_transformed = common.transform_points(points, world_mat)
        # Transform GT normals to camera coordinates
        world_normal_mat = world_mat[:, :, :3]
        normals = common.transform_points(normals, world_normal_mat)

        with torch.no_grad():
            outputs1, outputs2 = self.model(img, camera_mat)

        pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1

        loss = self.compute_loss(outputs1, outputs2, points_transformed,
                                 normals, img)
        lc1, lc2, id31, id32 = chamfer_distance(pred_vertices_3,
                                                points_transformed,
                                                give_id=True)
        l_c = (lc1 + lc2).mean()
        l_e = self.edge_length_loss(pred_vertices_3, 3)
        l_n = self.normal_loss(pred_vertices_3, normals, id31, 3)
        l_l, move_loss = self.laplacian_loss(pred_vertices_3,
                                             outputs2[2],
                                             block_id=3)

        eval_dict = {
            'loss': loss.item(),
            'chamfer': l_c.item(),
            'edge': l_e.item(),
            'normal': l_n.item(),
            'laplace': l_l.item(),
            'move': move_loss.item()
        }
        return eval_dict
Example #6
    def compute_loss(self, points, inputs):
        r''' Computes the loss.

        The Point Set Generation Network is trained on the Chamfer distance.

        Args:
            points (tensor): GT point cloud data
            inputs (tensor): input data for the model
        '''
        batch_size, n_steps, n_pts, dim = points.shape

        points_out = self.model(inputs)
        if self.loss_corr:
            n_pts_pred = points_out.shape[2]
            points_out = points_out.transpose(1, 2).contiguous().view(
                batch_size, n_pts_pred, -1)
            points = points.transpose(1, 2).contiguous().view(
                batch_size, n_pts, -1)
        else:
            points_out = points_out.contiguous().view(batch_size * n_steps, -1,
                                                      dim)
            points = points.contiguous().view(-1, n_pts, dim)
        loss = chamfer_distance(points, points_out).mean()
        return loss
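
The two branches above trade off what the Chamfer distance is computed over:
with loss_corr the time axis is folded into the coordinate dimension, so whole
point trajectories are matched against each other; without it, every time step
becomes its own batch entry. A sketch of the resulting shapes on dummy tensors
(sizes chosen arbitrarily):

    import torch

    B, T, N, D = 2, 4, 100, 3                 # batch, steps, points, coords
    points = torch.rand(B, T, N, D)

    # loss_corr=True: one cloud per sequence; each "point" is a trajectory.
    corr = points.transpose(1, 2).contiguous().view(B, N, -1)   # (B, N, T*D)

    # loss_corr=False: one independent cloud per time step.
    per_step = points.contiguous().view(B * T, N, D)            # (B*T, N, D)
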
Example #7
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )

        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)

        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        loss = 0
        if isinstance(out, tuple):
            out, trans_feat = out

            if isinstance(self.model.encoder,
                          (PointNetEncoder, PointNetResEncoder)):
                loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

        # chamfer distance loss
        if self.loss_type == 'cd':
            loss = loss + chamfer_distance(out, gt_pc).mean()
        else:
            out_pts_count = out.size(1)
            loss = loss + (emd.earth_mover_distance(
                out, gt_pc, transpose=False) / out_pts_count).mean()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # project using the world and camera matrices
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:]
                out_pts = out + t  # shift the prediction by the camera translation
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]
                out_pts = out_pts + t
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - out_mask_flow, inplace=True).mean()
            loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

            # tolerate up to depth_test_eps before penalizing points behind
            # the observed depth
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = loss + self.loss_depth_test_ratio * loss_depth_test

        return loss
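
The view penalty hinges on F.grid_sample: the projected 2D point locations are
treated as a sampling grid over the mask image, so a point that lands off the
mask reads back a value below one and gets pushed toward the silhouette. A
self-contained sketch of that lookup; project_to_camera is assumed to emit
coordinates already normalized to [-1, 1], which is the grid_sample convention:

    import torch
    import torch.nn.functional as F

    B, H, W, n_pts = 1, 8, 8, 5
    mask = torch.zeros(B, 1, H, W)
    mask[:, :, 2:6, 2:6] = 1.0                    # square foreground region

    # Projected points in normalized image coordinates, range [-1, 1].
    pts_img = torch.empty(B, n_pts, 2).uniform_(-1, 1)
    grid = pts_img.unsqueeze(1)                   # B * 1 * n_pts * 2

    sampled = F.grid_sample(mask, grid, align_corners=False)  # B * 1 * 1 * n_pts
    loss_mask = F.relu(1. - sampled).mean()       # penalize off-mask points
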
Example #8
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer)
        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

            if isinstance(out, tuple):
                out, trans_feat = out

            eval_dict = {}
            if batch_size == 1:
                pointcloud_hat = out.cpu().squeeze(0).numpy()
                pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()

                eval_dict = self.mesh_evaluator.eval_pointcloud(
                    pointcloud_hat, pointcloud_gt)

            # chamfer distance loss
            if self.loss_type == 'cd':
                loss = chamfer_distance(out, gt_pc)
            else:
                loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

            if self.gt_pointcloud_transfer in ('world_scale_model',
                                               'view_scale_model', 'view'):
                pointcloud_scale = data.get('pointcloud.scale').to(
                    device).view(batch_size, 1, 1)
                loss = loss / (pointcloud_scale**2)
                if self.gt_pointcloud_transfer == 'view':
                    if world_mat is None:
                        world_mat = get_world_mat(data, device=device)
                    t_scale = world_mat[:, 2:, 3:]
                    loss = loss * (t_scale**2)

            if self.loss_type == 'cd':
                loss = loss.mean()
                eval_dict['chamfer'] = loss.item()
            else:
                out_pts_count = out.size(1)
                loss = (loss / out_pts_count).mean()
                eval_dict['emd'] = loss.item()

            # view penalty loss
            if self.view_penalty:
                gt_mask_flow = data.get('inputs.mask_flow').to(
                    device)  # B * 1 * H * W
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                camera_mat = get_camera_mat(data, device=device)

                # project using the world and camera matrices
                if self.gt_pointcloud_transfer == 'world_scale_model':
                    out_pts = transform_points(out, world_mat)
                elif self.gt_pointcloud_transfer == 'view_scale_model':
                    t = world_mat[:, :, 3:]
                    out_pts = out + t
                elif self.gt_pointcloud_transfer == 'view':
                    t = world_mat[:, :, 3:]
                    out_pts = out * t[:, 2:, :]
                    out_pts = out_pts + t
                else:
                    raise NotImplementedError

                out_pts_img = project_to_camera(out_pts, camera_mat)
                out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

                out_mask_flow = F.grid_sample(gt_mask_flow,
                                              out_pts_img)  # B * 1 * 1 * n_pts
                loss_mask_flow = F.relu(1. - out_mask_flow,
                                        inplace=True).mean()
                loss = self.loss_mask_flow_ratio * loss_mask_flow
                eval_dict['loss_mask_flow'] = loss.item()

                if self.view_penalty == 'mask_flow_and_depth':
                    # depth test loss
                    t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1,
                                                      1)
                    gt_mask = data.get('inputs.mask').byte().to(device)
                    depth_pred = data.get('inputs.depth_pred').to(
                        device) * t_scale

                    background_setting(depth_pred, gt_mask)
                    depth_z = out_pts[:, :, 2:].transpose(1, 2)
                    corresponding_z = F.grid_sample(
                        depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                    corresponding_z = corresponding_z.squeeze(1)

                    # tolerate up to depth_test_eps (e.g. 0.05) before penalizing
                    loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                             corresponding_z,
                                             inplace=True).mean()
                    loss = self.loss_depth_test_ratio * loss_depth_test
                    eval_dict['loss_depth_test'] = loss.item()

        return eval_dict
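
The rescaling of the loss in this example relies on the squared Chamfer
distance growing quadratically with the coordinate scale, so dividing by the
pointcloud scale squared reports the metric in canonical units. A quick
numerical check of that property, reusing the chamfer_distance_naive sketch
from the note after Example #1:

    import torch

    a, b = torch.rand(1, 64, 3), torch.rand(1, 64, 3)
    s = 3.0
    assert torch.allclose(chamfer_distance_naive(s * a, s * b),
                          s ** 2 * chamfer_distance_naive(a, b), rtol=1e-4)
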
Example #9
    def eval_step_full(self, data):
        r''' Performs an evaluation step.

        The chamfer loss is calculated and returned in a dictionary.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        evaluator = MeshEvaluator(n_points=100000)

        points = data.get('pointcloud_chamfer').to(device)
        normals = data.get('pointcloud_chamfer.normals')
        inputs = data.get('inputs').to(device)

        with torch.no_grad():
            points_out, _ = self.model(inputs)

        batch_size = points.shape[0]
        points_np = points.cpu().numpy()
        points_out_np = points_out.cpu().numpy()

        completeness_list = []
        accuracy_list = []
        completeness2_list = []
        accuracy2_list = []
        chamfer_L1_list = []
        chamfer_L2_list = []
        edge_chamferL1_list = []
        edge_chamferL2_list = []
        emd_list = []
        fscore_list = []

        for idx in range(batch_size):
            eval_dict_pcl = evaluator.eval_pointcloud(points_out_np[idx],
                                                      points_np[idx],
                                                      normals_tgt=normals[idx])
            completeness_list.append(eval_dict_pcl['completeness'])
            accuracy_list.append(eval_dict_pcl['accuracy'])
            completeness2_list.append(eval_dict_pcl['completeness2'])
            accuracy2_list.append(eval_dict_pcl['accuracy2'])
            chamfer_L1_list.append(eval_dict_pcl['chamfer-L1'])
            chamfer_L2_list.append(eval_dict_pcl['chamfer-L2'])
            edge_chamferL1_list.append(eval_dict_pcl['edge-chamfer-L1'])
            edge_chamferL2_list.append(eval_dict_pcl['edge-chamfer-L2'])
            emd_list.append(eval_dict_pcl['emd'])
            fscore_list.append(eval_dict_pcl['fscore'])

        completeness = mean(completeness_list).item()
        accuracy = mean(accuracy_list).item()
        completeness2 = mean(completeness2_list).item()
        accuracy2 = mean(accuracy2_list).item()
        chamfer_L1 = mean(chamfer_L1_list).item()
        chamfer_L2 = mean(chamfer_L2_list).item()
        edge_chamferL1 = mean(edge_chamferL1_list).item()
        edge_chamferL2 = mean(edge_chamferL2_list).item()
        emd = mean(emd_list).item()
        fscore = mean(fscore_list)

        loss = chamfer_distance(points, points_out).mean()
        loss = loss.item()
        eval_dict = {
            'loss': loss,
            'chamfer': loss,
            'completeness': completeness,
            'completeness2': completeness2,
            'accuracy': accuracy,
            'accuracy2': accuracy2,
            'chamfer-L1': chamfer_L1,
            'chamfer-L2': chamfer_L2,
            'edge-chamfer-L2': edge_chamferL2,
            'edge-chamfer-L1': edge_chamferL1,
            'emd': emd,
            'fscore': fscore
        }

        return eval_dict
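
eval_pointcloud is assumed to report, among other metrics, an F-score at some
distance threshold. A minimal sketch of how such an F-score is commonly
computed from the accuracy (pred -> GT) and completeness (GT -> pred)
distances; the threshold tau is arbitrary here:

    import torch

    def fscore(pred, gt, tau=0.01):
        # pred: (N, 3), gt: (M, 3) single point clouds, no batch dimension.
        d = torch.cdist(pred, gt)                               # (N, M)
        precision = (d.min(dim=1).values < tau).float().mean()  # accuracy side
        recall = (d.min(dim=0).values < tau).float().mean()     # completeness side
        return 2 * precision * recall / (precision + recall + 1e-8)
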
Example #10
    def plot_dot_step(self, data, epoch, m):
        r''' Collects feature distances, dot products and Chamfer distances
        between anchor, positive and negative pairs for plotting.

        Args:
            data (dict): data dictionary
            epoch (int): current epoch
            m (float): margin used to select semi-hard negatives
        '''
        self.model.train()
        anc_points = data.get('pointcloud').to(self.device)
        anc_inputs = data.get('inputs').to(self.device)

        pos_inputs = data.get('inputs.Bias').to(self.device)

        anc_points_out, anc_feature = self.model(anc_inputs)
        anc_loss = chamfer_distance(anc_points, anc_points_out).mean()

        _, pos_feature = self.model(pos_inputs)
        batch_size = anc_feature.shape[0]

        pdist = nn.PairwiseDistance(p=2)
        anc_pos_dist = pdist(anc_feature, pos_feature)
        anc_pos_dot = []
        # Use k as the loop index so the margin argument m is not shadowed.
        for k in range(batch_size):
            anc_pos = torch.dot(anc_feature[k], pos_feature[k])
            anc_pos_dot.append(anc_pos)
        anc_pos_cham = chamfer_distance(anc_points, anc_points)

        anc_neg_dist_list_pos = []
        mix_dist_list_pos = []

        anc_neg_cham_list_pos = []
        mix_cham_list_pos = []

        for n in range(batch_size):
            anc_list = torch.stack([anc_feature[n]] * (2 * batch_size - 2),
                                   dim=0)
            neg_list = torch.cat([
                anc_feature[0:n], anc_feature[n + 1:], pos_feature[0:n],
                pos_feature[n + 1:]
            ],
                                 dim=0)
            anc_neg_dist = pdist(anc_list, neg_list)
            neg_indexes = [
                idx for idx, dist in enumerate(anc_neg_dist)
                if (anc_pos_dist[n] < dist) and (dist < anc_pos_dist[n] + m)
            ]
            # neg_num = len(neg_indexes)

            anc_neg_dot = []
            for k in range(2 * batch_size - 2):
                anc_neg = torch.dot(anc_list[k], neg_list[k])
                anc_neg_dot.append(anc_neg)

            anc_neg_list = torch.Tensor(anc_neg_dot)[neg_indexes]

            anc_neg_dist_list_pos.extend(anc_neg_list)
            mix_dist_list_pos.extend(torch.Tensor(anc_neg_dot))

            anc_points_list = torch.stack([anc_points[n]] *
                                          (2 * batch_size - 2),
                                          dim=0)
            neg_points = torch.cat([
                anc_points[0:n], anc_points[n + 1:], anc_points[0:n],
                anc_points[n + 1:]
            ],
                                   dim=0)
            # neg_points_list = neg_points[neg_indexes]
            anc_mix_cham = chamfer_distance(anc_points_list, neg_points)
            anc_neg_cham = anc_mix_cham[neg_indexes]

            anc_neg_cham_list_pos.extend(anc_neg_cham)
            mix_cham_list_pos.extend(anc_mix_cham)

        anc_neg_dist_list = []
        mix_dist_list = []

        anc_neg_cham_list = []
        mix_cham_list = []

        for n in range(batch_size):
            anc_list = torch.stack([anc_feature[n]] * (batch_size - 1), dim=0)
            neg_list = torch.cat([anc_feature[0:n], anc_feature[n + 1:]],
                                 dim=0)
            anc_neg_dist = pdist(anc_list, neg_list)
            neg_indexes = [
                idx for idx, dist in enumerate(anc_neg_dist)
                if (anc_pos_dist[n] < dist) and (dist < anc_pos_dist[n] + m)
            ]
            # neg_num = len(neg_indexes)
            # anc_neg_list = anc_neg_dist[neg_indexes]

            anc_neg_dot = []
            for k in range(batch_size - 1):
                anc_neg = torch.dot(anc_list[k], neg_list[k])
                anc_neg_dot.append(anc_neg)

            anc_neg_list = torch.Tensor(anc_neg_dot)[neg_indexes]

            anc_neg_dist_list.extend(anc_neg_list)
            mix_dist_list.extend(torch.Tensor(anc_neg_dot))

            anc_points_list = torch.stack([anc_points[n]] * (batch_size - 1),
                                          dim=0)
            neg_points = torch.cat([anc_points[0:n], anc_points[n + 1:]],
                                   dim=0)
            # neg_points_list = neg_points[neg_indexes]
            anc_mix_cham = chamfer_distance(anc_points_list, neg_points)
            anc_neg_cham = anc_mix_cham[neg_indexes]

            anc_neg_cham_list.extend(anc_neg_cham)
            mix_cham_list.extend(anc_mix_cham)

        anc_neg_dist_list_pos = torch.Tensor(
            anc_neg_dist_list_pos).cpu().detach().numpy().tolist()
        mix_dist_list_pos = torch.Tensor(
            mix_dist_list_pos).cpu().detach().numpy().tolist()

        anc_neg_cham_list_pos = torch.Tensor(
            anc_neg_cham_list_pos).cpu().detach().numpy().tolist()
        mix_cham_list_pos = torch.Tensor(
            mix_cham_list_pos).cpu().detach().numpy().tolist()

        anc_neg_dist_list = torch.Tensor(
            anc_neg_dist_list).cpu().detach().numpy().tolist()
        mix_dist_list = torch.Tensor(
            mix_dist_list).cpu().detach().numpy().tolist()

        anc_neg_cham_list = torch.Tensor(
            anc_neg_cham_list).cpu().detach().numpy().tolist()
        mix_cham_list = torch.Tensor(
            mix_cham_list).cpu().detach().numpy().tolist()

        anc_pos_dot = torch.Tensor(anc_pos_dot).cpu().detach().numpy().tolist()
        anc_pos_cham = anc_pos_cham.cpu().detach().numpy().tolist()

        return (anc_loss.item(), anc_neg_dist_list_pos, mix_dist_list_pos,
                anc_neg_cham_list_pos, mix_cham_list_pos, anc_neg_dist_list,
                mix_dist_list, anc_neg_cham_list, mix_cham_list, anc_pos_dot,
                anc_pos_cham)
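
The inner loops over torch.dot in this example (and the einsum calls in the
next one) compute the same row-wise inner products; torch.einsum removes the
Python loop. A sketch of the equivalence on dummy data:

    import torch

    a, b = torch.rand(6, 128), torch.rand(6, 128)
    loop = torch.stack([torch.dot(a[i], b[i]) for i in range(a.shape[0])])
    vec = torch.einsum('bl,bl->b', a, b)      # same values, fully vectorized
    assert torch.allclose(loop, vec)
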
Example #11
    def compute_loss(self, epoch, anc_points, anc_inputs, m, reg):
        r''' Computes the loss.

        The Point Set Generation Network is trained on the Chamfer distance
        plus a feature-space regularizer.

        Args:
            epoch (int): current epoch
            anc_points (tensor): GT point cloud data
            anc_inputs (tensor): input data for the model
            m (float): margin / kernel width parameter
            reg (str): regularizer variant ('gau' for the Gaussian kernel)
        '''
        anc_points_out, anc_feature = self.model(anc_inputs)
        anc_loss = chamfer_distance(anc_points, anc_points_out).mean()

        # _, pos_feature = self.model(pos_inputs)

        # pdist = nn.PairwiseDistance(p=2)
        # anc_pos_dist = pdist(anc_feature, pos_feature)

        batch_size = anc_feature.shape[0]

        trip_list = []

        reg_list = []
        vals = 0.11  # kernel normalizer for the non-'gau' branch below
        sigma = m * 0.997 / 3
        # l1 = nn.L1Loss()

        # for n in range(batch_size):
        #     anc_list = torch.stack([anc_feature[n]]*(2*batch_size-2), dim=0)
        #     neg_list = torch.cat([anc_feature[0:n], anc_feature[n+1:], pos_feature[0:n], pos_feature[n+1:]], dim=0)
        #     anc_neg_dist = pdist(anc_list, neg_list)
        #     anc_neg_list = anc_neg_dist[(anc_pos_dist[n] < anc_neg_dist) * (anc_neg_dist < anc_pos_dist[n] + m)]
        #     neg_num = anc_neg_list.shape[0]
        #     if neg_num !=0:
        #         triploss_list = [torch.max(torch.tensor([anc_pos_dist[n] - anc_neg_list[n_neg] + m, 0])) for n_neg in range(neg_num)]
        #         triplet_loss = torch.mean(torch.tensor(triploss_list))
        #         trip_list.append(triplet_loss)

        if reg == 'gau':
            for n in range(batch_size):
                t_feature = anc_feature[n]
                t_feature_list = torch.stack([anc_feature[n]] *
                                             (batch_size - 1),
                                             dim=0)
                s_feature_list = torch.cat(
                    [anc_feature[0:n], anc_feature[n + 1:]], dim=0)
                t_points_list = torch.stack([anc_points[n]] * (batch_size - 1),
                                            dim=0)
                t_points = anc_points[n]
                s_points_list = torch.cat(
                    [anc_points[0:n], anc_points[n + 1:]], dim=0)
                s_num = batch_size - 1
                # t_points_list = t_points[None, :, :]
                # t_points_list = t_points_list.repeat(s_points_list.shape[0],1,1)
                p_deno = torch.sum(
                    torch.exp(-1 *
                              chamfer_distance(t_points_list, s_points_list) /
                              (2 * sigma)))
                p_hat_deno = torch.sum(
                    torch.einsum('bl, bl -> b', t_feature_list,
                                 s_feature_list))
                # p_hat_deno = sum([torch.dot(t_feature, t_feature_list[s_idx]) for s_idx in range(s_num)])
                # p_hat_nume =
                # t_reg_list = []
                # reg_t = 0.
                p = torch.exp(
                    -1 * chamfer_distance(t_points_list, s_points_list) /
                    (2 * sigma)) / p_deno
                p_hat = torch.einsum('bl, bl -> b', t_feature_list,
                                     s_feature_list) / p_hat_deno
                reg_t = torch.mean(torch.abs(p_hat - p))
                # torch.sum(F.l1_loss(p_hat, p))

                # for s_idx in range(s_num):
                #     reg_t += F.l1_loss((torch.dot(t_feature, s_feature_list[s_idx]) / p_hat_deno),
                #         (torch.exp(-1 * chamfer_distance(t_points[None,...], s_points_list[s_idx][None,...])/ (2 * sigma)) / p_deno))
                #     # t_reg_list.append(reg)
                # reg_t = reg_t / s_num
                # t_reg = torch.mean(torch.stack(t_reg_list), dim=0)
                reg_list.append(reg_t)
        else:
            for n in range(batch_size):
                t_feature = anc_feature[n]
                t_feature_list = torch.stack([anc_feature[n]] *
                                             (batch_size - 1),
                                             dim=0)
                s_feature_list = torch.cat(
                    [anc_feature[0:n], anc_feature[n + 1:]], dim=0)
                t_points_list = torch.stack([anc_points[n]] * (batch_size - 1),
                                            dim=0)
                # t_points = anc_points[n]
                s_points_list = torch.cat(
                    [anc_points[0:n], anc_points[n + 1:]], dim=0)
                s_num = batch_size - 1
                # p_deno = sum([1 - (chamfer_distance(t_points[None,...], s_points_list[s_idx][None,...]) / vals) for s_idx in range(s_num)])
                p_deno = torch.sum(
                    1 - chamfer_distance(t_points_list, s_points_list) / vals)
                p_hat_deno = torch.sum(
                    torch.einsum('bl, bl -> b', t_feature_list,
                                 s_feature_list))
                # p_hat_nume =
                # t_reg_list = []
                p = (1 - chamfer_distance(t_points_list, s_points_list) /
                     vals) / p_deno
                p_hat = torch.einsum('bl, bl -> b', t_feature_list,
                                     s_feature_list) / p_hat_deno
                reg_t = torch.mean(torch.abs(p_hat - p))
                # reg_t = 0.
                # for s_idx in range(s_num):
                #     reg_t += F.l1_loss((torch.dot(t_feature, s_feature_list[s_idx]) / p_hat_deno),
                #         (1 - (chamfer_distance(t_points[None,...], s_points_list[s_idx][None,...]) / vals) / p_deno))
                #     # t_reg_list.append(reg)
                # reg_t = reg_t / s_num
                # t_reg = torch.mean(torch.stack(t_reg_list), dim=0)
                reg_list.append(reg_t)

        reg_loss = torch.mean(torch.stack(reg_list), dim=0)

        # triplet_output = torch.mean(torch.tensor(trip_list))

        # triplet_loss = nn.TripletMarginLoss(margin=m, p=2)
        # triplet_output = triplet_loss(anc_feature, pos_feature, neg_feature)
        w_a = 1
        # w_t = 1
        w_r = 1
        # loss = w_a * anc_loss + w_t * triplet_output + w_r * reg_loss
        loss = w_a * anc_loss + w_r * reg_loss

        return loss, anc_loss, reg_loss
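
The 'gau' branch builds, for each anchor, a target distribution p over the
other batch members from a Gaussian kernel on the Chamfer distance, an
estimated distribution p_hat from feature inner products, and penalizes their
L1 gap. The core of one anchor's computation, isolated on dummy data and
reusing the chamfer_distance_naive sketch from the note after Example #1
(sigma is arbitrary here):

    import torch

    B, L, N = 5, 128, 64
    feats = torch.rand(B, L)                 # per-shape feature vectors
    clouds = torch.rand(B, N, 3)             # per-shape point clouds
    sigma, n = 0.1, 0                        # kernel width, anchor index

    t_feat = feats[n].expand(B - 1, L)
    s_feat = torch.cat([feats[:n], feats[n + 1:]], dim=0)
    t_pts = clouds[n].expand(B - 1, N, 3)
    s_pts = torch.cat([clouds[:n], clouds[n + 1:]], dim=0)

    k = torch.exp(-chamfer_distance_naive(t_pts, s_pts) / (2 * sigma))
    p = k / k.sum()                          # geometric target distribution
    sim = torch.einsum('bl,bl->b', t_feat, s_feat)
    p_hat = sim / sim.sum()                  # feature-space estimate
    reg_t = (p_hat - p).abs().mean()         # L1 gap, as in the loop above
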