def save_summaries(writer: SummaryWriter, data: Dict, predicted: List, endpoints: Dict = None,
                   losses: Dict = None, metrics: Dict = None, step: int = 0):
    """Save tensorboard summaries.

    Args:
        writer: Summary writer that receives the meshes, histograms and scalars.
        data: Current mini-batch. Point clouds are only logged when it contains
            'points_src'; 'points_ref' is then read as well.
        predicted: Predicted transforms, one entry per iteration, each indexed
            along the batch dimension first.
        endpoints: Optional dict. When it contains 'perm_matrices' (one matrix
            per iteration), the row/column sums are logged as histograms and
            the column sums are additionally used to colour the reference cloud.
        losses: Optional mapping of loss name to value, logged as 'losses/<name>'.
        metrics: Optional mapping of metric name to value, logged as 'metrics/<name>'.
        step: Global step attached to every summary.
    """
    with torch.no_grad():
        # Save clouds
        if 'points_src' in data:
            _save_cloud_meshes(writer, data, predicted, endpoints, step)

        # Histograms of the match weights, for every iteration and the full batch
        if endpoints is not None and 'perm_matrices' in endpoints:
            for i_iter, perm_matrix in enumerate(endpoints['perm_matrices']):
                src_weights = torch.sum(perm_matrix, dim=2)
                ref_weights = torch.sum(perm_matrix, dim=1)
                writer.add_histogram('src_weights_{}'.format(i_iter), src_weights, global_step=step)
                writer.add_histogram('ref_weights_{}'.format(i_iter), ref_weights, global_step=step)

        # Write losses and metrics
        if losses is not None:
            for name, value in losses.items():
                writer.add_scalar('losses/{}'.format(name), value, step)
        if metrics is not None:
            for name, value in metrics.items():
                writer.add_scalar('metrics/{}'.format(name), value, step)

        writer.flush()


def _save_cloud_meshes(writer: SummaryWriter, data: Dict, predicted: List, endpoints: Dict, step: int):
    """Log the source/reference clouds of the first (at most two) batch instances as meshes."""
    # Only the first two instances are visualised. Clamp to the actual batch
    # size so that a batch holding a single sample does not raise IndexError.
    num_vis = min(2, data['points_src'].shape[0])
    if num_vis == 0:
        return
    subset = list(range(num_vis))

    points_src = data['points_src'][subset, ..., :3]
    points_ref = data['points_ref'][subset, ..., :3]

    # One colour per vertex: ORANGE for source points, BLUE for reference points
    # (module-level constants; presumably RGB triplets -- confirm where defined).
    colors = torch.from_numpy(
        np.concatenate([np.tile(ORANGE, (*points_src.shape[0:2], 1)),
                        np.tile(BLUE, (*points_ref.shape[0:2], 1))], axis=1))

    iters_to_save = [0, len(predicted) - 1] if len(predicted) > 1 else [0]

    # 'iter_0' is the untransformed input; 'iter_{k+1}' is the source cloud
    # after applying the transform predicted at iteration k (first and last).
    concat_cloud_input = torch.cat((points_src, points_ref), dim=1)
    writer.add_mesh('iter_0', vertices=concat_cloud_input, colors=colors, global_step=step)
    for i_iter in iters_to_save:
        src_transformed = se3.transform(predicted[i_iter][subset, ...], points_src)
        concat_cloud = torch.cat((src_transformed, points_ref), dim=1)
        writer.add_mesh('iter_{}'.format(i_iter + 1), vertices=concat_cloud, colors=colors, global_step=step)

    if endpoints is not None and 'perm_matrices' in endpoints:
        # NOTE(review): if `colormap` is matplotlib.cm, `get_cmap` was removed in
        # matplotlib 3.9 -- confirm the pinned version or pass the name directly.
        color_mapper = colormap.ScalarMappable(norm=None, cmap=colormap.get_cmap('coolwarm'))
        for i_iter in iters_to_save:
            # Colour each reference point by the total weight matched to it
            ref_weights = torch.sum(endpoints['perm_matrices'][i_iter][subset, ...], dim=1)
            ref_colors = color_mapper.to_rgba(ref_weights.detach().cpu().numpy())[..., :3]
            writer.add_mesh('ref_weights_{}'.format(i_iter), vertices=points_ref,
                            colors=torch.from_numpy(ref_colors) * 255, global_step=step)
def compute_metrics(data, pred_transforms):
    """Compute metrics required in the paper.

    Args:
        data: Dict providing 'transform_gt', 'points_src', 'points_ref' and
            'points_raw'. Only the first three channels of each cloud are used.
        pred_transforms: Predicted transforms for the batch; the rotation is
            read from [:, :3, :3] and the translation from [:, :3, 3].

    Returns:
        Dict with one value per batch instance for each of:
            'r_mse', 'r_mae': squared / absolute Euler-angle error (numpy arrays)
            't_mse', 't_mae': squared / absolute translation error
            'err_r_deg': residual rotation angle in degrees
            'err_t': magnitude of the residual translation
            'chamfer_dist': modified Chamfer distance
    """

    def square_distance(src, dst):
        """Squared distance between every point of src and every point of dst."""
        return torch.sum((src[:, :, None, :] - dst[:, None, :, :]) ** 2, dim=-1)

    with torch.no_grad():
        gt_transforms = data['transform_gt']
        points_src = data['points_src'][..., :3]
        points_ref = data['points_ref'][..., :3]
        points_raw = data['points_raw'][..., :3]

        # Euler angles, Individual translation errors (Deep Closest Point convention)
        # TODO Change rotation to torch operations
        # detach().cpu() first: Tensor.numpy() fails for tensors that live on
        # the GPU or that require grad.
        r_gt_euler_deg = dcm2euler(gt_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        r_pred_euler_deg = dcm2euler(pred_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        t_gt = gt_transforms[:, :3, 3]
        t_pred = pred_transforms[:, :3, 3]
        r_mse = np.mean((r_gt_euler_deg - r_pred_euler_deg) ** 2, axis=1)
        r_mae = np.mean(np.abs(r_gt_euler_deg - r_pred_euler_deg), axis=1)
        t_mse = torch.mean((t_gt - t_pred) ** 2, dim=1)
        t_mae = torch.mean(torch.abs(t_gt - t_pred), dim=1)

        # Rotation, translation errors (isotropic, i.e. doesn't depend on error
        # direction, which is more representative of the actual error)
        concatenated = se3.concatenate(se3.inverse(gt_transforms), pred_transforms)
        rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
        # Clamp before acos so rounding error cannot push the cosine outside [-1, 1]
        residual_rotdeg = torch.acos(
            torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
        residual_transmag = concatenated[:, :, 3].norm(dim=-1)

        # Modified Chamfer distance
        src_transformed = se3.transform(pred_transforms, points_src)
        ref_clean = points_raw
        src_clean = se3.transform(
            se3.concatenate(pred_transforms, se3.inverse(gt_transforms)), points_raw)
        dist_src = torch.min(square_distance(src_transformed, ref_clean), dim=-1)[0]
        dist_ref = torch.min(square_distance(points_ref, src_clean), dim=-1)[0]
        chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)

        metrics = {
            'r_mse': r_mse,
            'r_mae': r_mae,
            't_mse': to_array(t_mse),
            't_mae': to_array(t_mae),
            'err_r_deg': to_array(residual_rotdeg),
            'err_t': to_array(residual_transmag),
            'chamfer_dist': to_array(chamfer_dist),
        }

    return metrics
def compute_losses(data: Dict, pred_transforms: List, endpoints: Dict,
                   loss_type: str = 'mae', reduction: str = 'mean') -> Dict:
    """Compute losses

    Args:
        data: Current mini-batch data
        pred_transforms: Predicted transform, to compute main registration loss.
          One entry per iteration; must be non-empty since the total stacks
          the per-iteration losses.
        endpoints: Endpoints for training. For computing outlier penalty
        loss_type: Registration loss type, either 'mae' (Mean absolute error,
          used in paper) or 'mse'
        reduction: Either 'mean' or 'none' (case-insensitive). Use 'none' to
          accumulate losses outside (useful for accumulating losses for entire
          validation dataset)

    Returns:
        losses: Dict containing various fields. Total loss to be optimized is in
          losses['total']

    Raises:
        NotImplementedError: If loss_type or reduction is not supported.
    """
    # Normalise once so the criterion and the branches below agree on the
    # spelling, and reject unsupported modes up front instead of failing later
    # with an empty-stack error.
    reduction = reduction.lower()
    if reduction not in ('mean', 'none'):
        raise NotImplementedError(
            "Unsupported reduction '{}'; expected 'mean' or 'none'".format(reduction))

    # Loss to the groundtruth (does not take into account possible symmetries)
    if loss_type == 'mse':
        criterion = nn.MSELoss(reduction=reduction)
    elif loss_type == 'mae':
        criterion = nn.L1Loss(reduction=reduction)
    else:
        raise NotImplementedError

    losses = {}
    num_iter = len(pred_transforms)

    # Registration loss: compare source points under the predicted and the
    # groundtruth transform, once per iteration.
    src_xyz = data['points_src'][..., :3]
    gt_src_transformed = se3.transform(data['transform_gt'], src_xyz)
    for i, pred_transform in enumerate(pred_transforms):
        pred_src_transformed = se3.transform(pred_transform, src_xyz)
        loss = criterion(pred_src_transformed, gt_src_transformed)
        if reduction == 'none':
            # Criterion returned element-wise values; average over the last two
            # dimensions, keeping one value per batch instance.
            loss = torch.mean(loss, dim=[-1, -2])
        losses['{}_{}'.format(loss_type, i)] = loss

    # Penalize outliers
    for i in range(num_iter):
        perm_matrix = endpoints['perm_matrices'][i]
        ref_outliers_strength = (1.0 - torch.sum(perm_matrix, dim=1)) * _args.wt_inliers
        src_outliers_strength = (1.0 - torch.sum(perm_matrix, dim=2)) * _args.wt_inliers
        if reduction == 'mean':
            losses['outlier_{}'.format(i)] = \
                torch.mean(ref_outliers_strength) + torch.mean(src_outliers_strength)
        else:
            losses['outlier_{}'.format(i)] = \
                torch.mean(ref_outliers_strength, dim=1) + torch.mean(src_outliers_strength, dim=1)

    discount_factor = 0.5  # Early iterations will be discounted
    total_losses = []
    for key, value in losses.items():
        # Every key ends in '_<iteration index>'; the last iteration gets weight 1
        iteration = int(key.rsplit('_', 1)[1])
        total_losses.append(value * discount_factor ** (num_iter - iteration - 1))
    losses['total'] = torch.sum(torch.stack(total_losses), dim=0)

    return losses
def forward(self, data, num_iter: int = 1):
    """Run the iterative registration forward pass of RPMNet.

    Args:
        data: Dict containing the following fields:
                'points_src': Source points (B, J, 6)
                'points_ref': Reference points (B, K, 6)
        num_iter (int): Number of iterations. Recommended to be 2 for training

    Returns:
        transforms: List holding the transform estimated at each iteration
        endpoints: Dict of per-iteration intermediates with keys
          'perm_matrices_init', 'perm_matrices', 'weighted_ref' (lists) and
          'beta', 'alpha' (stacked numpy arrays)
    """
    # First three channels are coordinates, the next three are normals
    ref_cloud = data['points_ref']
    xyz_ref, norm_ref = ref_cloud[:, :, :3], ref_cloud[:, :, 3:6]
    src_cloud = data['points_src']
    xyz_src, norm_src = src_cloud[:, :, :3], src_cloud[:, :, 3:6]

    # Source cloud as moved by the most recent estimate
    cur_xyz, cur_norm = xyz_src, norm_src

    transforms = []
    gammas, perm_matrices, weighted_refs = [], [], []
    betas, alphas = [], []

    for _ in range(num_iter):
        beta, alpha = self.weights_net([cur_xyz, xyz_ref])
        feat_src = self.feat_extractor(cur_xyz, cur_norm)
        feat_ref = self.feat_extractor(xyz_ref, norm_ref)

        distance = match_features(feat_src, feat_ref)
        affinity = self.compute_affinity(beta, distance, alpha=alpha)

        # Compute weighted coordinates
        log_perm = sinkhorn(affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
        perm = torch.exp(log_perm)
        row_mass = torch.sum(perm, dim=2, keepdim=True)
        weighted_ref = perm @ xyz_ref / (row_mass + _EPS)

        # Compute transform and transform points. The transform is always
        # estimated from the original source cloud; it is detached before
        # being applied to produce the next iteration's input.
        transform = compute_rigid_transform(xyz_src, weighted_ref,
                                            weights=torch.sum(perm, dim=2))
        cur_xyz, cur_norm = se3.transform(transform.detach(), xyz_src, norm_src)

        transforms.append(transform)
        gammas.append(torch.exp(affinity))
        perm_matrices.append(perm)
        weighted_refs.append(weighted_ref)
        betas.append(to_numpy(beta))
        alphas.append(to_numpy(alpha))

    endpoints = {
        'perm_matrices_init': gammas,
        'perm_matrices': perm_matrices,
        'weighted_ref': weighted_refs,
        'beta': np.stack(betas, axis=0),
        'alpha': np.stack(alphas, axis=0),
    }

    return transforms, endpoints