import gc
import logging

import numpy as np
import torch
import MinkowskiEngine as ME

# Repo-local helpers (AverageMeter, Timer, batch_rotation_error,
# batch_translation_error) are assumed to be importable here; hedged
# sketches of their expected interfaces follow this method.


def _train_epoch(self, epoch):
    gc.collect()

    # Freeze the feature model and train the inlier model
    self.feat_model.eval()
    self.inlier_model.train()

    # Epoch starts from 1
    total_loss, total_num = 0, 0.0
    data_loader = self.data_loader
    iter_size = self.iter_size

    # Meters for statistics
    average_valid_meter = AverageMeter()
    loss_meter = AverageMeter()
    data_meter = AverageMeter()
    regist_succ_meter = AverageMeter()
    regist_rte_meter = AverageMeter()
    regist_rre_meter = AverageMeter()

    # Timers for profiling
    data_timer = Timer()
    nn_timer = Timer()
    inlier_timer = Timer()
    total_timer = Timer()

    if self.config.num_train_iter > 0:
        num_train_iter = self.config.num_train_iter
    else:
        num_train_iter = len(data_loader) // iter_size
    start_iter = (epoch - 1) * num_train_iter

    eps = np.finfo(np.float32).eps  # guards the precision/recall divisions
    tp, fp, tn, fn = 0, 0, 0, 0

    # Iterate over batches
    for curr_iter in range(num_train_iter):
        self.optimizer.zero_grad()

        batch_loss, data_time = 0, 0
        total_timer.tic()

        for iter_idx in range(iter_size):
            data_timer.tic()
            input_dict = self.get_data(self.train_data_loader_iter)
            data_time += data_timer.toc(average=False)

            # Initial inlier prediction with FCGF features and KNN matching
            reg_coords, reg_feats, pred_pairs, is_correct, feat_time, nn_time = \
                self.generate_inlier_input(
                    xyz0=input_dict['pcd0'],
                    xyz1=input_dict['pcd1'],
                    iC0=input_dict['sinput0_C'],
                    iC1=input_dict['sinput1_C'],
                    iF0=input_dict['sinput0_F'],
                    iF1=input_dict['sinput1_F'],
                    len_batch=input_dict['len_batch'],
                    pos_pairs=input_dict['correspondences'])
            nn_timer.update(nn_time)

            # Inlier prediction with the 6D ConvNet
            inlier_timer.tic()
            reg_sinput = ME.SparseTensor(reg_feats.contiguous(),
                                         coordinates=reg_coords.int(),
                                         device=self.device)
            reg_soutput = self.inlier_model(reg_sinput)
            inlier_timer.toc()

            logits = reg_soutput.F
            weights = logits.sigmoid()

            # Zero out weights below the threshold. In-place modification
            # would break autograd, so write into a fresh tensor instead.
            if self.clip_weight_thresh > 0:
                weights_tmp = torch.zeros_like(weights)
                valid_mask = weights > self.clip_weight_thresh
                weights_tmp[valid_mask] = weights[valid_mask]
                weights = weights_tmp

            # Weighted Procrustes
            pred_rots, pred_trans, ws = self.weighted_procrustes(
                xyz0s=input_dict['pcd0'],
                xyz1s=input_dict['pcd1'],
                pred_pairs=pred_pairs,
                weights=weights)

            # Batch registration loss
            gt_rots, gt_trans = self.decompose_rotation_translation(
                input_dict['T_gt'])
            rot_error = batch_rotation_error(pred_rots, gt_rots)
            trans_error = batch_translation_error(pred_trans, gt_trans)
            individual_loss = rot_error + self.config.trans_weight * trans_error

            # Select batch elements with at least 10 valid (weighted) correspondences
            valid_mask = ws > 10
            num_valid = valid_mask.sum().item()
            average_valid_meter.update(num_valid)

            # Registration loss against the registration ground truth
            loss = self.config.procrustes_loss_weight * individual_loss[
                valid_mask].mean()
            if not np.isfinite(loss.item()):
                logging.info('Loss is not finite, skipping this sub-iteration')
                continue

            # Direct inlier loss against the nearest-neighbor-searched ground truth
            target = torch.from_numpy(is_correct).squeeze()
            if self.config.inlier_use_direct_loss:
                inlier_loss = self.config.inlier_direct_loss_weight * self.crit(
                    logits.cpu().squeeze(), target.to(torch.float)) / iter_size
                loss += inlier_loss

            loss.backward()

            # Update statistics (no gradients needed here)
            with torch.no_grad():
                regist_rre_meter.update(rot_error.squeeze() * 180 / np.pi)
                regist_rte_meter.update(trans_error.squeeze())

                success = (trans_error.squeeze() < self.config.success_rte_thresh) * (
                    rot_error.squeeze() * 180 / np.pi < self.config.success_rre_thresh)
                regist_succ_meter.update(success.float())

                batch_loss += loss.mean().item()

                neg_target = ~target
                # logit > 0 is equivalent to sigmoid(logit) > 0.5.
                # TODO: make the decision threshold configurable.
                pred = logits.squeeze().detach().cpu() > 0
                pred_on_pos, pred_on_neg = pred[target], pred[neg_target]
                tp += pred_on_pos.sum().item()
                fp += pred_on_neg.sum().item()
                tn += (~pred_on_neg).sum().item()
                fn += (~pred_on_pos).sum().item()

        # Check the gradient to avoid stepping on non-finite values
        max_grad = torch.abs(
            self.inlier_model.final.kernel.grad).max().cpu().item()

        # Apply the optimizer update only if the gradient is finite
        if not np.isfinite(max_grad):
            self.optimizer.zero_grad()
            logging.info(f'Clearing the NaN gradient at iter {curr_iter}')
        else:
            self.optimizer.step()

        total_loss += batch_loss
        total_num += 1.0
        total_timer.toc()
        data_meter.update(data_time)
        loss_meter.update(batch_loss)

        # Output to logs
        if curr_iter % self.config.stat_freq == 0:
            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            f1 = 2 * (precision * recall) / (precision + recall + eps)
            tpr = tp / (tp + fn + eps)
            tnr = tn / (tn + fp + eps)
            balanced_accuracy = (tpr + tnr) / 2

            correspondence_accuracy = is_correct.sum() / len(is_correct)

            total, free = ME.get_gpu_memory_info()
            used = (total - free) / 2**30  # bytes -> GiB
            stat = {
                'loss': loss_meter.avg,
                'precision': precision,
                'recall': recall,
                'tpr': tpr,
                'tnr': tnr,
                'balanced_accuracy': balanced_accuracy,
                'f1': f1,
                'num_valid': average_valid_meter.avg,
                'gpu_used': used,
            }

            for k, v in stat.items():
                self.writer.add_scalar(f'train/{k}', v, start_iter + curr_iter)

            logging.info(' '.join([
                f"Train Epoch: {epoch} [{curr_iter}/{num_train_iter}],",
                f"Current Loss: {loss_meter.avg:.3e},",
                f"Correspondence acc: {correspondence_accuracy:.3e},",
                f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f},",
                f"TPR: {tpr:.4f}, TNR: {tnr:.4f}, BAcc: {balanced_accuracy:.4f},",
                f"RTE: {regist_rte_meter.avg:.3e}, RRE: {regist_rre_meter.avg:.3e},",
                f"Succ rate: {regist_succ_meter.avg:.3e},",
                f"Avg num valid: {average_valid_meter.avg:.3e},",
                f"\tData time: {data_meter.avg:.4f}, Train time: {total_timer.avg - data_meter.avg:.4f},",
                f"NN search time: {nn_timer.avg:.3e}, Total time: {total_timer.avg:.4f}"
            ]))

            loss_meter.reset()
            regist_rte_meter.reset()
            regist_rre_meter.reset()
            regist_succ_meter.reset()
            average_valid_meter.reset()
            data_meter.reset()
            total_timer.reset()
            tp, fp, tn, fn = 0, 0, 0, 0
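
# --- Hedged sketches of the repo-local helpers assumed above ----------------
# batch_rotation_error / batch_translation_error: minimal reconstructions of
# the batched pose-error metrics used by _train_epoch. The actual repo
# implementations may differ in details such as the clamping constants.

def batch_rotation_error(rots1, rots2):
    """Geodesic angle (radians) between two batches of 3x3 rotation matrices.

    Uses trace(R1^T R2) = elementwise_sum(R1 * R2), then
    angle = acos((trace - 1) / 2), clamped for numerical stability.
    """
    assert rots1.shape == rots2.shape
    trace_r1t_r2 = (rots1 * rots2).sum(dim=(1, 2))
    cos_angle = torch.clamp((trace_r1t_r2 - 1) / 2, min=-0.999, max=0.999)
    return torch.acos(cos_angle)


def batch_translation_error(trans1, trans2):
    """Euclidean distance between two batches of translation vectors."""
    assert trans1.shape == trans2.shape
    return torch.norm(trans1 - trans2, p=2, dim=1)


# AverageMeter / Timer: minimal sketches covering only the interface
# _train_epoch exercises (update/avg/reset and tic/toc(average=...)).
# The repo's own versions may track more state, e.g. per-element averages
# for tensor inputs.
import time


class AverageMeter:
    """Running average of scalar (or tensor-mean) values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        if hasattr(val, 'mean'):  # accept tensors / ndarrays
            val = float(val.mean())
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Timer:
    """Accumulates wall-clock durations between tic() and toc() calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total_time, self.calls, self.avg, self.diff = 0.0, 0, 0.0, 0.0
        self.start_time = 0.0

    def tic(self):
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.update(self.diff)
        return self.avg if average else self.diff

    def update(self, duration):
        self.total_time += duration
        self.calls += 1
        self.avg = self.total_time / self.calls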
def test(self):
    self.assertEqual(ME.is_cuda_available(), torch.cuda.is_available())
    if ME.is_cuda_available():
        print(ME.cuda_version())
        print(ME.get_gpu_memory_info())
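
# A companion smoke test (a hedged sketch, not part of the original suite):
# it builds a tiny SparseTensor the same way _train_epoch builds reg_sinput
# and checks that features and coordinates stay row-aligned. Coordinates
# follow the (batch_index, x, y, z) integer layout MinkowskiEngine expects.
def test_sparse_tensor_construction(self):
    coords = torch.IntTensor([[0, 0, 0, 0],
                              [0, 1, 0, 0],
                              [0, 0, 1, 1]])
    feats = torch.rand(3, 6)  # e.g. 6-channel features per point
    sinput = ME.SparseTensor(feats, coordinates=coords, device='cpu')
    self.assertEqual(sinput.F.shape[0], sinput.C.shape[0])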