def compute_loss(self, outputs1, outputs2, gt_points, normals, img=None):
    r''' Returns the complete loss.

    The full loss is adopted from the authors' implementation and is the
    sum of five weighted terms:
        a.) Chamfer distance loss
        b.) edge length loss
        c.) normal loss
        d.) Laplacian loss
        e.) move loss

    Arguments:
        outputs1 (list): first outputs of model; unpacked into the
            predicted vertices of the three blocks
        outputs2 (list): second outputs of model; entry i is handed to the
            Laplacian loss of block i + 1
        gt_points (tensor): ground truth point cloud locations
        normals (tensor): normals of the ground truth point cloud
        img (tensor): input images (not read by this method)
    '''
    verts_1, verts_2, verts_3 = outputs1
    # (predicted vertices, block id) for the three blocks, in order.
    blocks = ((verts_1, 1), (verts_2, 2), (verts_3, 3))

    # Chamfer Distance Loss
    first_terms = []
    second_terms = []
    first_ids = []
    for verts, _ in blocks:
        term_a, term_b, ids_a, _ = chamfer_distance(
            verts, gt_points, give_id=True)
        first_terms.append(term_a)
        second_terms.append(term_b)
        first_ids.append(ids_a)
    first_sum = (first_terms[0].mean() + first_terms[1].mean()
                 + first_terms[2].mean())
    second_sum = (second_terms[0].mean() + second_terms[1].mean()
                  + second_terms[2].mean())
    # The first returned distance term is re-weighted relative to the second.
    chamfer_term = ((second_sum + self.param_chamfer_rel * first_sum)
                    * self.param_chamfer_w)

    # Edge Length Loss
    edge_parts = [self.edge_length_loss(verts, block_id)
                  for verts, block_id in blocks]
    edge_term = (edge_parts[0] + edge_parts[1] + edge_parts[2]) \
        * self.param_edge

    # Normal Loss (uses the first index set returned by the chamfer call)
    normal_parts = [
        self.normal_loss(verts, normals, ids_a, block_id)
        for (verts, block_id), ids_a in zip(blocks, first_ids)
    ]
    normal_term = (normal_parts[0] + normal_parts[1] + normal_parts[2]) \
        * self.param_n

    # Laplacian Loss and move loss
    lap_parts = []
    move_parts = []
    for verts, block_id in blocks:
        lap_value, move_value = self.laplacian_loss(
            verts, outputs2[block_id - 1], block_id=block_id)
        lap_parts.append(lap_value)
        move_parts.append(move_value)
    laplace_term = (self.param_lap_rel * lap_parts[0] + lap_parts[1]
                    + lap_parts[2]) * self.param_lap
    # The move loss of the first block is deliberately left out.
    move_term = (move_parts[1] + move_parts[2]) * self.param_move

    # Final loss
    return chamfer_term + edge_term + normal_term + laplace_term + move_term
def eval_step(self, data):
    r''' Performs an evaluation step.

    The model is switched to eval mode, run without gradient tracking, and
    the mean chamfer distance to the ground truth is reported.

    Args:
        data (dict-like): batch providing the entries 'pointcloud_chamfer'
            and 'inputs' through ``data.get``

    Returns:
        dict: the same scalar under the keys 'loss' and 'chamfer'
    '''
    self.model.eval()

    target = data.get('pointcloud_chamfer').to(self.device)
    net_input = data.get('inputs').to(self.device)

    with torch.no_grad():
        prediction = self.model(net_input)

    chamfer = chamfer_distance(target, prediction).mean().item()
    return {
        'loss': chamfer,
        'chamfer': chamfer,
    }
def eval_step(self, data):
    r''' Performs an evaluation step.

    The chamfer loss is calculated and returned in a dictionary. The model
    returns a pair here; only its first element (the predicted points) is
    used.

    Args:
        data (dict-like): batch providing the entries 'pointcloud_chamfer'
            and 'inputs' through ``data.get``

    Returns:
        dict: ``{'loss': <mean chamfer distance as a Python number>}``
    '''
    self.model.eval()
    device = self.device

    points = data.get('pointcloud_chamfer').to(device)
    inputs = data.get('inputs').to(device)

    with torch.no_grad():
        points_out, _ = self.model(inputs)

    # Only the chamfer distance is reported here, so no mesh evaluator or
    # CPU/NumPy copies of the point clouds are needed.
    loss = chamfer_distance(points, points_out).mean().item()

    return {
        'loss': loss,
    }
def compute_loss(self, points, inputs):
    r''' Computes the loss.

    The Point Set Generation Network is trained on the Chamfer distance
    between the ground truth points and the model output.

    Args:
        points (tensor): GT point cloud data
        inputs (tensor): input data for the model

    Returns:
        The mean of ``chamfer_distance(points, model(inputs))``.
    '''
    prediction = self.model(inputs)
    return chamfer_distance(points, prediction).mean()
def eval_step(self, data):
    r''' Performs an evaluation step.

    Reports the complete loss over all blocks plus the individual,
    unweighted loss terms of the third (final) block.

    Arguments:
        data (dict-like): batch providing 'pointcloud', 'inputs' and
            'pointcloud.normals' through ``data.get``, plus whatever
            ``common.get_camera_args`` reads from it

    Returns:
        dict: Python numbers under the keys 'loss', 'chamfer', 'edge',
        'normal', 'laplace' and 'move'
    '''
    self.model.eval()

    gt_points = data.get('pointcloud').to(self.device)
    images = data.get('inputs').to(self.device)
    gt_normals = data.get('pointcloud.normals').to(self.device)

    # Transform GT points to camera coordinates
    cam = common.get_camera_args(
        data, 'pointcloud.loc', 'pointcloud.scale', device=self.device)
    world_mat = cam['Rt']
    camera_mat = cam['K']
    gt_points_cam = common.transform_points(gt_points, world_mat)

    # Transform GT normals to camera coordinates; only the first three
    # columns of the world matrix are applied to them.
    gt_normals_cam = common.transform_points(
        gt_normals, world_mat[:, :, :3])

    with torch.no_grad():
        outputs1, outputs2 = self.model(images, camera_mat)

    # Only the vertices of the last block are evaluated individually.
    _, _, final_verts = outputs1

    total = self.compute_loss(
        outputs1, outputs2, gt_points_cam, gt_normals_cam, images)

    dist_a, dist_b, ids_a, _ = chamfer_distance(
        final_verts, gt_points_cam, give_id=True)
    chamfer = (dist_a + dist_b).mean()
    edge = self.edge_length_loss(final_verts, 3)
    normal = self.normal_loss(final_verts, gt_normals_cam, ids_a, 3)
    laplace, move = self.laplacian_loss(
        final_verts, outputs2[2], block_id=3)

    return {
        'loss': total.item(),
        'chamfer': chamfer.item(),
        'edge': edge.item(),
        'normal': normal.item(),
        'laplace': laplace.item(),
        'move': move.item()
    }
def compute_loss(self, points, inputs):
    r''' Computes the loss.

    The Point Set Generation Network is trained on the Chamfer distance.
    How ground truth and prediction are flattened before the distance is
    taken depends on ``self.loss_corr``:

    * truthy: axes 1 and 2 are swapped and everything after the point axis
      is merged, so each point carries its coordinates of all steps;
    * falsy: the batch and step axes are merged, so every step is compared
      as a separate point set.

    Args:
        points (tensor): GT point cloud data; its shape must unpack into
            (batch_size, n_steps, n_pts, dim)
        inputs (tensor): input data for the model

    Returns:
        The mean chamfer distance between the reshaped tensors.
    '''
    batch_size, n_steps, n_pts, dim = points.shape
    prediction = self.model(inputs)

    if not self.loss_corr:
        prediction = prediction.contiguous().view(
            batch_size * n_steps, -1, dim)
        target = points.contiguous().view(-1, n_pts, dim)
    else:
        n_pts_pred = prediction.shape[2]
        prediction = prediction.transpose(1, 2).contiguous().view(
            batch_size, n_pts_pred, -1)
        target = points.transpose(1, 2).contiguous().view(
            batch_size, n_pts, -1)

    return chamfer_distance(target, prediction).mean()
def compute_loss(self, data):
    ''' Computes the loss.

    The loss is the sum of
        * a feature-transform regulariser (weight 0.001) when the model
          returns a tuple and its encoder is a PointNetEncoder or
          PointNetResEncoder,
        * the chamfer distance (``self.loss_type == 'cd'``) or otherwise
          the earth mover distance divided by the number of output points,
        * an optional view penalty (``self.view_penalty``): a mask-flow
          term and, for ``'mask_flow_and_depth'``, a depth-test term.

    Args:
        data (dict): data dictionary

    Returns:
        The accumulated loss.

    Raises:
        NotImplementedError: if the view penalty is enabled and
            ``self.gt_pointcloud_transfer`` is not one of
            'world_scale_model', 'view_scale_model' or 'view'.
    '''
    device = self.device

    encoder_inputs, raw_data = compose_inputs(
        data, mode='train', device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer,
    )

    # The world matrix is only fetched when the encoder or the chosen
    # ground truth transfer needs it.
    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)

    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)

    if self.model.encoder_world_mat is not None:
        out = self.model(encoder_inputs, world_mat=world_mat)
    else:
        out = self.model(encoder_inputs)

    loss = 0
    if isinstance(out, tuple):
        out, trans_feat = out
        if isinstance(self.model.encoder,
                      (PointNetEncoder, PointNetResEncoder)):
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

    # chamfer distance loss
    if self.loss_type == 'cd':
        loss = loss + chamfer_distance(out, gt_pc).mean()
    else:
        out_pts_count = out.size(1)
        loss = loss + (emd.earth_mover_distance(
            out, gt_pc, transpose=False) / out_pts_count).mean()

    # view penalty loss
    if self.view_penalty:
        gt_mask_flow = data.get('inputs.mask_flow').to(
            device)  # B * 1 * H * W
        if world_mat is None:
            world_mat = get_world_mat(data, device=device)
        camera_mat = get_camera_mat(data, device=device)

        # projection use world mat & camera mat
        if self.gt_pointcloud_transfer == 'world_scale_model':
            out_pts = transform_points(out, world_mat)
        elif self.gt_pointcloud_transfer == 'view_scale_model':
            t = world_mat[:, :, 3:]
            # Start from the network output: `out_pts` does not exist yet
            # in this branch.
            # NOTE(review): `t` is a column slice of world_mat; confirm it
            # broadcasts against `out` (points along dim 1) as intended.
            out_pts = out + t
        elif self.gt_pointcloud_transfer == 'view':
            t = world_mat[:, :, 3:]
            # Same as above: scale the network output first, then shift.
            # NOTE(review): confirm the broadcast of `t` against `out`.
            out_pts = out * t[:, 2:, :]
            out_pts = out_pts + t
        else:
            raise NotImplementedError

        out_pts_img = project_to_camera(out_pts, camera_mat)
        out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

        out_mask_flow = F.grid_sample(gt_mask_flow,
                                      out_pts_img)  # B * 1 * 1 * n_pts
        # Penalise sampled mask-flow values below 1.
        loss_mask_flow = F.relu(1. - out_mask_flow, inplace=True).mean()
        loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

        if self.view_penalty == 'mask_flow_and_depth':
            # depth test loss
            t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
            gt_mask = data.get('inputs.mask').byte().to(device)
            depth_pred = data.get('inputs.depth_pred').to(
                device) * t_scale  # absolute depth from view
            background_setting(depth_pred, gt_mask)

            depth_z = out_pts[:, :, 2:].transpose(1, 2)
            corresponding_z = F.grid_sample(
                depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
            corresponding_z = corresponding_z.squeeze(1)

            # Only points whose z exceeds the sampled depth by more than
            # self.depth_test_eps are penalised.
            loss_depth_test = F.relu(
                depth_z - self.depth_test_eps - corresponding_z,
                inplace=True).mean()
            loss = loss + self.loss_depth_test_ratio * loss_depth_test

    return loss
def eval_step(self, data):
    ''' Performs an evaluation step.

    Runs the model without gradient tracking and collects evaluation
    metrics:
        * for a batch of size 1, everything returned by
          ``self.mesh_evaluator.eval_pointcloud``,
        * 'chamfer' (``self.loss_type == 'cd'``) or otherwise 'emd',
          rescaled by the point cloud scale for the scale-normalised
          transfers,
        * 'loss_mask_flow' and, for ``'mask_flow_and_depth'``,
          'loss_depth_test' when ``self.view_penalty`` is set.

    Args:
        data (dict): data dictionary

    Returns:
        dict: metric name -> Python number (plus whatever
        ``eval_pointcloud`` returned for a batch of size 1)

    Raises:
        NotImplementedError: if the view penalty is enabled and
            ``self.gt_pointcloud_transfer`` is not one of
            'world_scale_model', 'view_scale_model' or 'view'.
    '''
    self.model.eval()
    device = self.device

    encoder_inputs, raw_data = compose_inputs(
        data, mode='train', device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer)

    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)

    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)
    batch_size = gt_pc.size(0)

    with torch.no_grad():
        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

    if isinstance(out, tuple):
        out, trans_feat = out

    eval_dict = {}
    if batch_size == 1:
        pointcloud_hat = out.cpu().squeeze(0).numpy()
        pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()
        eval_dict = self.mesh_evaluator.eval_pointcloud(
            pointcloud_hat, pointcloud_gt)

    # chamfer distance loss
    if self.loss_type == 'cd':
        loss = chamfer_distance(out, gt_pc)
    else:
        loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

    # Undo the scale normalisation so the metric is reported at the
    # original point cloud scale.
    if self.gt_pointcloud_transfer in ('world_scale_model',
                                       'view_scale_model', 'view'):
        pointcloud_scale = data.get('pointcloud.scale').to(
            device).view(batch_size, 1, 1)
        loss = loss / (pointcloud_scale**2)
        if self.gt_pointcloud_transfer == 'view':
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            t_scale = world_mat[:, 2:, 3:]
            loss = loss * (t_scale**2)

    if self.loss_type == 'cd':
        loss = loss.mean()
        eval_dict['chamfer'] = loss.item()
    else:
        out_pts_count = out.size(1)
        loss = (loss / out_pts_count).mean()
        eval_dict['emd'] = loss.item()

    # view penalty loss
    if self.view_penalty:
        gt_mask_flow = data.get('inputs.mask_flow').to(
            device)  # B * 1 * H * W
        if world_mat is None:
            world_mat = get_world_mat(data, device=device)
        camera_mat = get_camera_mat(data, device=device)

        # projection use world mat & camera mat
        if self.gt_pointcloud_transfer == 'world_scale_model':
            out_pts = transform_points(out, world_mat)
        elif self.gt_pointcloud_transfer == 'view_scale_model':
            t = world_mat[:, :, 3:]
            # Start from the network output: `out_pts` does not exist yet
            # in this branch.
            # NOTE(review): `t` is a column slice of world_mat; confirm it
            # broadcasts against `out` (points along dim 1) as intended.
            out_pts = out + t
        elif self.gt_pointcloud_transfer == 'view':
            t = world_mat[:, :, 3:]
            # Same as above: scale the network output first, then shift.
            # NOTE(review): confirm the broadcast of `t` against `out`.
            out_pts = out * t[:, 2:, :]
            out_pts = out_pts + t
        else:
            raise NotImplementedError

        out_pts_img = project_to_camera(out_pts, camera_mat)
        out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

        out_mask_flow = F.grid_sample(gt_mask_flow,
                                      out_pts_img)  # B * 1 * 1 * n_pts
        loss_mask_flow = F.relu(1. - out_mask_flow, inplace=True).mean()
        loss = self.loss_mask_flow_ratio * loss_mask_flow
        eval_dict['loss_mask_flow'] = loss.item()

        if self.view_penalty == 'mask_flow_and_depth':
            # depth test loss
            t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
            gt_mask = data.get('inputs.mask').byte().to(device)
            depth_pred = data.get('inputs.depth_pred').to(
                device) * t_scale
            background_setting(depth_pred, gt_mask)

            depth_z = out_pts[:, :, 2:].transpose(1, 2)
            corresponding_z = F.grid_sample(
                depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
            corresponding_z = corresponding_z.squeeze(1)

            # The tolerance of the depth test is self.depth_test_eps.
            loss_depth_test = F.relu(
                depth_z - self.depth_test_eps - corresponding_z,
                inplace=True).mean()
            loss = self.loss_depth_test_ratio * loss_depth_test
            # Store a Python number like every other entry, not a tensor.
            eval_dict['loss_depth_test'] = loss.item()

    return eval_dict
def eval_step_full(self, data):
    r''' Performs a full evaluation step.

    Besides the mean chamfer loss of the batch, every sample is scored
    with ``MeshEvaluator.eval_pointcloud`` and the per-sample metrics are
    averaged over the batch.

    Args:
        data (dict-like): batch providing 'pointcloud_chamfer',
            'pointcloud_chamfer.normals' and 'inputs' through ``data.get``

    Returns:
        dict: 'loss' and 'chamfer' (the same value) plus the batch means
        of 'completeness', 'completeness2', 'accuracy', 'accuracy2',
        'chamfer-L1', 'chamfer-L2', 'edge-chamfer-L2', 'edge-chamfer-L1',
        'emd' and 'fscore'
    '''
    self.model.eval()
    device = self.device
    evaluator = MeshEvaluator(n_points=100000)

    points = data.get('pointcloud_chamfer').to(device)
    normals = data.get('pointcloud_chamfer.normals')
    inputs = data.get('inputs').to(device)

    with torch.no_grad():
        points_out, _ = self.model(inputs)

    batch_size = points.shape[0]
    # NOTE(review): the ground truth stays a CPU torch tensor while the
    # prediction is converted to NumPy; confirm the evaluator accepts both.
    points_cpu = points.cpu()
    points_out_np = points_out.cpu().numpy()

    # Metrics whose batch mean is reduced to a Python number with .item().
    scalar_keys = (
        'completeness', 'accuracy', 'completeness2', 'accuracy2',
        'chamfer-L1', 'chamfer-L2', 'edge-chamfer-L1', 'edge-chamfer-L2',
        'emd',
    )
    all_keys = scalar_keys + ('fscore',)

    per_sample = {key: [] for key in all_keys}
    for idx in range(batch_size):
        sample_metrics = evaluator.eval_pointcloud(
            points_out_np[idx], points_cpu[idx], normals_tgt=normals[idx])
        for key in all_keys:
            per_sample[key].append(sample_metrics[key])

    averaged = {}
    for key in scalar_keys:
        averaged[key] = mean(per_sample[key]).item()
    # The fscore mean is passed on as returned by `mean`, without .item().
    averaged['fscore'] = mean(per_sample['fscore'])

    loss = chamfer_distance(points, points_out).mean().item()

    return {
        'loss': loss,
        'chamfer': loss,
        'completeness': averaged['completeness'],
        'completeness2': averaged['completeness2'],
        'accuracy': averaged['accuracy'],
        'accuracy2': averaged['accuracy2'],
        'chamfer-L1': averaged['chamfer-L1'],
        'chamfer-L2': averaged['chamfer-L2'],
        'edge-chamfer-L2': averaged['edge-chamfer-L2'],
        'edge-chamfer-L1': averaged['edge-chamfer-L1'],
        'emd': averaged['emd'],
        'fscore': averaged['fscore']
    }
def plot_dot_step(self, data, epoch, m):
    r''' Collects feature dot products and chamfer distances for plotting.

    For every anchor sample the remaining samples of the batch act as
    negatives. A negative is "selected" when its feature distance ``d`` to
    the anchor satisfies ``anc_pos_dist < d < anc_pos_dist + m``, where
    ``anc_pos_dist`` is the distance between the anchor feature and the
    feature of its 'inputs.Bias' counterpart.

    Two negative pools are evaluated:
        * ``*_pos`` results: anchor features and 'inputs.Bias' features of
          all other samples,
        * results without suffix: anchor features of all other samples.

    Note that the returned ``*_dist_list*`` values hold feature dot
    products; the pairwise distance is only used for the selection.

    Args:
        data (dict-like): batch providing 'pointcloud', 'inputs' and
            'inputs.Bias' through ``data.get``
        epoch: not used by this method
        m: margin of the selection window described above

    Returns:
        tuple: (anchor chamfer loss,
                selected dots (pos pool), all dots (pos pool),
                selected chamfer (pos pool), all chamfer (pos pool),
                selected dots, all dots,
                selected chamfer, all chamfer,
                anchor/positive dots,
                chamfer distance of the GT points with themselves),
        every entry but the first converted to Python lists.
    '''
    self.model.train()

    anc_points = data.get('pointcloud').to(self.device)
    anc_inputs = data.get('inputs').to(self.device)
    pos_inputs = data.get('inputs.Bias').to(self.device)

    anc_points_out, anc_feature = self.model(anc_inputs)
    anc_loss = chamfer_distance(anc_points, anc_points_out).mean()
    _, pos_feature = self.model(pos_inputs)

    batch_size = anc_feature.shape[0]
    pdist = nn.PairwiseDistance(p=2)
    anc_pos_dist = pdist(anc_feature, pos_feature)

    # The loop index must not be called `m`: that name is the margin and
    # is needed unchanged for the selection window below.
    anc_pos_dot = [
        torch.dot(anc_feature[i], pos_feature[i])
        for i in range(batch_size)
    ]

    anc_pos_cham = chamfer_distance(anc_points, anc_points)

    def _negative_stats(n, neg_feature, neg_points):
        # Dots and chamfer distances of anchor n against one negative pool.
        n_neg = neg_feature.shape[0]
        anc_list = torch.stack([anc_feature[n]] * n_neg, dim=0)
        anc_neg_dist = pdist(anc_list, neg_feature)
        neg_indexes = [
            idx for idx, dist in enumerate(anc_neg_dist)
            if (anc_pos_dist[n] < dist) and (dist < anc_pos_dist[n] + m)
        ]
        anc_neg_dot = torch.Tensor([
            torch.dot(anc_list[k], neg_feature[k]) for k in range(n_neg)
        ])
        anc_points_list = torch.stack([anc_points[n]] * n_neg, dim=0)
        anc_mix_cham = chamfer_distance(anc_points_list, neg_points)
        return (anc_neg_dot[neg_indexes], anc_neg_dot,
                anc_mix_cham[neg_indexes], anc_mix_cham)

    def _as_list(values):
        # Pack a list of scalar tensors into a plain Python list.
        return torch.Tensor(values).cpu().detach().numpy().tolist()

    # Pool 1: anchor and positive features of all other samples.
    anc_neg_dist_list_pos = []
    mix_dist_list_pos = []
    anc_neg_cham_list_pos = []
    mix_cham_list_pos = []
    for n in range(batch_size):
        neg_feature = torch.cat([
            anc_feature[0:n], anc_feature[n + 1:],
            pos_feature[0:n], pos_feature[n + 1:]
        ], dim=0)
        # Positive inputs share the ground truth points of their anchors,
        # hence the anchor points appear twice.
        neg_points = torch.cat([
            anc_points[0:n], anc_points[n + 1:],
            anc_points[0:n], anc_points[n + 1:]
        ], dim=0)
        sel_dot, all_dot, sel_cham, all_cham = _negative_stats(
            n, neg_feature, neg_points)
        anc_neg_dist_list_pos.extend(sel_dot)
        mix_dist_list_pos.extend(all_dot)
        anc_neg_cham_list_pos.extend(sel_cham)
        mix_cham_list_pos.extend(all_cham)

    # Pool 2: anchor features of all other samples only.
    anc_neg_dist_list = []
    mix_dist_list = []
    anc_neg_cham_list = []
    mix_cham_list = []
    for n in range(batch_size):
        neg_feature = torch.cat(
            [anc_feature[0:n], anc_feature[n + 1:]], dim=0)
        neg_points = torch.cat(
            [anc_points[0:n], anc_points[n + 1:]], dim=0)
        sel_dot, all_dot, sel_cham, all_cham = _negative_stats(
            n, neg_feature, neg_points)
        anc_neg_dist_list.extend(sel_dot)
        mix_dist_list.extend(all_dot)
        anc_neg_cham_list.extend(sel_cham)
        mix_cham_list.extend(all_cham)

    return (
        anc_loss.item(),
        _as_list(anc_neg_dist_list_pos),
        _as_list(mix_dist_list_pos),
        _as_list(anc_neg_cham_list_pos),
        _as_list(mix_cham_list_pos),
        _as_list(anc_neg_dist_list),
        _as_list(mix_dist_list),
        _as_list(anc_neg_cham_list),
        _as_list(mix_cham_list),
        _as_list(anc_pos_dot),
        anc_pos_cham.cpu().detach().numpy().tolist(),
    )
def compute_loss(self, epoch, anc_points, anc_inputs, m, reg):
    r''' Computes the loss.

    The loss is the mean chamfer distance between the ground truth and
    the predicted points plus a regulariser. For every sample the
    regulariser compares two normalised similarity profiles over all
    other samples of the batch:

    * ``p``: derived from the chamfer distance ``d`` between the ground
      truth point clouds, ``exp(-d / (2 * sigma))`` for ``reg == 'gau'``
      and ``1 - d / vals`` otherwise,
    * ``p_hat``: the dot products between the sample's feature and the
      other features.

    Each profile is divided by its own sum and the regulariser is the
    mean absolute difference between the two.

    Args:
        epoch: not used by this method
        anc_points (tensor): GT point cloud data
        anc_inputs (tensor): input data for the model
        m: margin; sets ``sigma = m * 0.997 / 3`` of the 'gau' variant
        reg (str): 'gau' selects the exponential similarity, any other
            value the linear one

    Returns:
        tuple: (total loss, chamfer loss, regularisation loss)
    '''
    anc_points_out, anc_feature = self.model(anc_inputs)
    anc_loss = chamfer_distance(anc_points, anc_points_out).mean()

    batch_size = anc_feature.shape[0]
    sigma = m * 0.997 / 3
    # Normaliser of the linear similarity. Restored from the commented-out
    # assignment; without it the non-'gau' branch used an undefined name.
    # NOTE(review): confirm 0.11 is still the intended value.
    vals = 0.11

    reg_list = []
    for n in range(batch_size):
        # Pair sample n with every other sample of the batch.
        t_feature_list = torch.stack(
            [anc_feature[n]] * (batch_size - 1), dim=0)
        s_feature_list = torch.cat(
            [anc_feature[0:n], anc_feature[n + 1:]], dim=0)
        t_points_list = torch.stack(
            [anc_points[n]] * (batch_size - 1), dim=0)
        s_points_list = torch.cat(
            [anc_points[0:n], anc_points[n + 1:]], dim=0)

        # Computed once and reused for both the profile and its sum.
        cham = chamfer_distance(t_points_list, s_points_list)
        if reg == 'gau':
            similarity = torch.exp(-1 * cham / (2 * sigma))
        else:
            similarity = 1 - cham / vals
        p = similarity / torch.sum(similarity)

        dots = torch.einsum('bl, bl -> b', t_feature_list, s_feature_list)
        p_hat = dots / torch.sum(dots)

        reg_list.append(torch.mean(torch.abs(p_hat - p)))

    reg_loss = torch.mean(torch.stack(reg_list), dim=0)

    # Weights of the chamfer term and of the regulariser.
    w_a = 1
    w_r = 1
    loss = w_a * anc_loss + w_r * reg_loss
    return loss, anc_loss, reg_loss