    def _train_epoch(self, epoch):
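        """Train the inlier model for one epoch.

        The feature model (FCGF) stays frozen; only the 6D inlier ConvNet is
        updated, using a weighted-Procrustes registration loss plus an
        optional direct BCE inlier loss.
        """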
        gc.collect()

        # Fix the feature model and train the inlier model
        self.feat_model.eval()
        self.inlier_model.train()

        total_loss, total_num = 0, 0.0
        data_loader = self.data_loader
        iter_size = self.iter_size

        # Meters for statistics
        average_valid_meter = AverageMeter()
        loss_meter = AverageMeter()
        data_meter = AverageMeter()
        regist_succ_meter = AverageMeter()
        regist_rte_meter = AverageMeter()
        regist_rre_meter = AverageMeter()

        # Timers for profiling
        data_timer = Timer()
        nn_timer = Timer()
        inlier_timer = Timer()
        total_timer = Timer()

        if self.config.num_train_iter > 0:
            num_train_iter = self.config.num_train_iter
        else:
            num_train_iter = len(data_loader) // iter_size
        # Epochs are 1-indexed, hence the (epoch - 1) offset
        start_iter = (epoch - 1) * num_train_iter

        tp, fp, tn, fn = 0, 0, 0, 0
        eps = np.finfo(np.float32).eps  # guards the divisions in the stats below

        # Iterate over batches
        for curr_iter in range(num_train_iter):
            self.optimizer.zero_grad()

            batch_loss, data_time = 0, 0
            total_timer.tic()

            for iter_idx in range(iter_size):
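                # With iter_size > 1, gradients from successive passes
                # accumulate, since zero_grad() is only called once per
                # outer iteration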
                data_timer.tic()
                input_dict = self.get_data(self.train_data_loader_iter)
                data_time += data_timer.toc(average=False)

                # Initial inlier prediction with FCGF and KNN matching
                reg_coords, reg_feats, pred_pairs, is_correct, feat_time, nn_time = self.generate_inlier_input(
                    xyz0=input_dict['pcd0'],
                    xyz1=input_dict['pcd1'],
                    iC0=input_dict['sinput0_C'],
                    iC1=input_dict['sinput1_C'],
                    iF0=input_dict['sinput0_F'],
                    iF1=input_dict['sinput1_F'],
                    len_batch=input_dict['len_batch'],
                    pos_pairs=input_dict['correspondences'])
                nn_timer.update(nn_time)
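                # reg_coords / reg_feats feed the inlier ConvNet below;
                # pred_pairs are feature-space NN matches, and is_correct
                # marks which of them agree with the GT correspondences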

                # Inlier prediction with 6D ConvNet
                inlier_timer.tic()
                reg_sinput = ME.SparseTensor(reg_feats.contiguous(),
                                             coordinates=reg_coords.int(),
                                             device=self.device)
                reg_soutput = self.inlier_model(reg_sinput)
                inlier_timer.toc()

                logits = reg_soutput.F
                weights = logits.sigmoid()
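                # One logit per candidate correspondence; the sigmoid maps it
                # to an inlier confidence weight in (0, 1)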

                # Zero out weights below the threshold. In-place modification
                # would break autograd, so build a new tensor instead.
                if self.clip_weight_thresh > 0:
                    weights_tmp = torch.zeros_like(weights)
                    valid_mask = weights > self.clip_weight_thresh
                    weights_tmp[valid_mask] = weights[valid_mask]
                    weights = weights_tmp

                # Weighted Procrustes
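                # Closed-form R, t minimizing the weighted squared distances
                # between matched points; ws holds the per-batch-element sum
                # of correspondence weights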
                pred_rots, pred_trans, ws = self.weighted_procrustes(
                    xyz0s=input_dict['pcd0'],
                    xyz1s=input_dict['pcd1'],
                    pred_pairs=pred_pairs,
                    weights=weights)

                # Get batch registration loss
                gt_rots, gt_trans = self.decompose_rotation_translation(
                    input_dict['T_gt'])
                rot_error = batch_rotation_error(pred_rots, gt_rots)
                trans_error = batch_translation_error(pred_trans, gt_trans)
                individual_loss = rot_error + self.config.trans_weight * trans_error
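                # Per-batch-element pose loss: rotation error plus weighted
                # translation error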

                # Keep batch elements whose correspondence weight sum exceeds
                # 10 (roughly, at least 10 confident correspondences)
                valid_mask = ws > 10
                num_valid = valid_mask.sum().item()
                average_valid_meter.update(num_valid)

                # Registration loss against registration GT
                loss = self.config.procrustes_loss_weight * individual_loss[
                    valid_mask].mean()
                if not np.isfinite(loss.item()):
                    logging.info(
                        f'Non-finite loss at iter {curr_iter}, skipping batch')
                    continue

                # Direct inlier loss against nearest neighbor searched GT
                target = torch.from_numpy(is_correct).squeeze().to(torch.bool)
                if self.config.inlier_use_direct_loss:
                    inlier_loss = self.config.inlier_direct_loss_weight * self.crit(
                        logits.cpu().squeeze(), target.to(
                            torch.float)) / iter_size
                    loss += inlier_loss
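                    # The division by iter_size above keeps the inlier-loss
                    # gradient comparable when iter_size batches accumulate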

                loss.backward()

                # Update statistics before the optimizer step
                with torch.no_grad():
                    regist_rre_meter.update(rot_error.squeeze() * 180 / np.pi)
                    regist_rte_meter.update(trans_error.squeeze())

                    success = (trans_error.squeeze() <
                               self.config.success_rte_thresh) * (
                                   rot_error.squeeze() * 180 / np.pi <
                                   self.config.success_rre_thresh)
                    regist_succ_meter.update(success.float())

                    batch_loss += loss.mean().item()

                    neg_target = ~target
                    # Threshold inlier logits at 0, i.e. sigmoid weight 0.5
                    pred = (logits > 0).cpu().squeeze()  # TODO: configurable threshold
                    pred_on_pos, pred_on_neg = pred[target], pred[neg_target]
                    tp += pred_on_pos.sum().item()
                    fp += pred_on_neg.sum().item()
                    tn += (~pred_on_neg).sum().item()
                    fn += (~pred_on_pos).sum().item()
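                    # Confusion counts accumulate across iterations and are
                    # reset after each stat_freq log below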

                    # Inspect the gradient on the final layer so that a
                    # non-finite update can be skipped below
                    max_grad = torch.abs(self.inlier_model.final.kernel.grad
                                         ).max().cpu().item()

                # Step the optimizer only if the gradient is finite
                if not np.isfinite(max_grad):
                    self.optimizer.zero_grad()
                    logging.info(
                        f'Clearing the non-finite gradient at iter {curr_iter}')
                else:
                    self.optimizer.step()

            total_loss += batch_loss
            total_num += 1.0
            total_timer.toc()
            data_meter.update(data_time)
            loss_meter.update(batch_loss)

            # Output to logs
            if curr_iter % self.config.stat_freq == 0:
                precision = tp / (tp + fp + eps)
                recall = tp / (tp + fn + eps)
                f1 = 2 * (precision * recall) / (precision + recall + eps)
                tpr = tp / (tp + fn + eps)
                tnr = tn / (tn + fp + eps)
                balanced_accuracy = (tpr + tnr) / 2
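                # Balanced accuracy is reported because inliers are typically
                # a small fraction of the candidates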

                correspondence_accuracy = is_correct.sum() / len(is_correct)

                total, free = ME.get_gpu_memory_info()
                used = (total - free) / 2**30  # bytes -> GiB
                stat = {
                    'loss': loss_meter.avg,
                    'precision': precision,
                    'recall': recall,
                    'tpr': tpr,
                    'tnr': tnr,
                    'balanced_accuracy': balanced_accuracy,
                    'f1': f1,
                    'num_valid': average_valid_meter.avg,
                    'gpu_used': used,
                }

                for k, v in stat.items():
                    self.writer.add_scalar(f'train/{k}', v,
                                           start_iter + curr_iter)

                logging.info(' '.join([
                    f"Train Epoch: {epoch} [{curr_iter}/{num_train_iter}],",
                    f"Current Loss: {loss_meter.avg:.3e},",
                    f"Correspondence acc: {correspondence_accuracy:.3e}",
                    f", Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f},",
                    f"TPR: {tpr:.4f}, TNR: {tnr:.4f}, BAcc: {balanced_accuracy:.4f}",
                    f"RTE: {regist_rte_meter.avg:.3e}, RRE: {regist_rre_meter.avg:.3e},",
                    f"Succ rate: {regist_succ_meter.avg:3e}",
                    f"Avg num valid: {average_valid_meter.avg:3e}",
                    f"\tData time: {data_meter.avg:.4f}, Train time: {total_timer.avg - data_meter.avg:.4f},",
                    f"NN search time: {nn_timer.avg:.3e}, Total time: {total_timer.avg:.4f}"
                ]))

                loss_meter.reset()
                regist_rte_meter.reset()
                regist_rre_meter.reset()
                regist_succ_meter.reset()
                average_valid_meter.reset()
                data_meter.reset()
                total_timer.reset()

                tp, fp, tn, fn = 0, 0, 0, 0