def evaluate(pred_transforms, data_loader: torch.utils.data.dataloader.DataLoader):
    """Evaluates the computed transforms against the groundtruth

    Args:
        pred_transforms: Predicted transforms (B, [iter], 3/4, 4)
        data_loader: Loader for dataset.

    Returns:
        Computed metrics (List of dicts), and summary metrics (only for last iter)
    """

    _logger.info('Evaluating transforms...')
    num_processed, num_total = 0, len(pred_transforms)

    if pred_transforms.ndim == 4:
        pred_transforms = torch.from_numpy(pred_transforms).to(_device)
    else:
        assert pred_transforms.ndim == 3 and \
               (pred_transforms.shape[1:] == (4, 4) or pred_transforms.shape[1:] == (3, 4))
        pred_transforms = torch.from_numpy(pred_transforms[:, None, :, :]).to(_device)

    metrics_for_iter = [defaultdict(list) for _ in range(pred_transforms.shape[1])]

    for data in tqdm(data_loader, leave=False):
        dict_all_to_device(data, _device)
        batch_size = data['points_src'].shape[0]

        for i_iter in range(pred_transforms.shape[1]):
            cur_pred_transforms = pred_transforms[num_processed:num_processed + batch_size, i_iter, :, :]
            metrics = compute_metrics(data, cur_pred_transforms)
            for k in metrics:
                metrics_for_iter[i_iter][k].append(metrics[k])
        num_processed += batch_size

    for i_iter in range(len(metrics_for_iter)):
        metrics_for_iter[i_iter] = {k: np.concatenate(metrics_for_iter[i_iter][k], axis=0)
                                    for k in metrics_for_iter[i_iter]}
        summary_metrics = summarize_metrics(metrics_for_iter[i_iter])
        print_metrics(_logger, summary_metrics, title='Evaluation result (iter {})'.format(i_iter))

    return metrics_for_iter, summary_metrics
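
# Usage sketch (illustrative, not part of the original pipeline): `evaluate` expects the
# stacked transforms produced by `inference` (defined later in this module), e.g. reloaded
# from disk. The file name and `test_loader` here are placeholders:
#
#     pred_transforms = np.load('pred_transforms.npy')  # (B, [n_iter], 3/4, 4) numpy array
#     metrics_per_iter, last_summary = evaluate(pred_transforms, test_loader)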

def evaluate(data_loader, model: torch.nn.Module):
    """Evaluates the model's prediction against the groundtruth"""

    _logger.info('Starting evaluation...')
    with torch.no_grad():
        all_test_metrics_np = defaultdict(list)
        for test_data in data_loader:
            dict_all_to_device(test_data, _device)
            pred_test_transforms, endpoints = model(test_data, _args.num_reg_iter)
            test_metrics = compute_metrics(test_data, pred_test_transforms[-1],
                                           endpoints['perm_matrices'][-1])
            for k in test_metrics:
                all_test_metrics_np[k].append(test_metrics[k])

        all_test_metrics_np = {k: np.concatenate(all_test_metrics_np[k])
                               for k in all_test_metrics_np}

    summary_metrics = summarize_metrics(all_test_metrics_np)
    print_metrics(_logger, summary_metrics, title='Evaluation results')
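
# Typical call sequence for this model-based `evaluate` (a sketch; `test_loader` is a
# placeholder). Note that the torch.no_grad() above only disables gradient tracking, so
# the caller should switch the model out of train mode first:
#
#     model = get_model(_args)
#     model.to(_device)
#     model.eval()
#     evaluate(test_loader, model)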

def validate(data_loader, model: torch.nn.Module, summary_writer: SummaryWriter, step: int):
    """Performs a single validation run, and saves results into tensorboard summaries"""

    _logger.info('Starting validation run...')

    with torch.no_grad():
        all_val_losses = defaultdict(list)
        all_val_metrics_np = defaultdict(list)
        for val_data in data_loader:
            dict_all_to_device(val_data, _device)
            pred_val_transforms, endpoints = model(val_data, _args.num_reg_iter)
            val_losses = compute_losses(val_data, pred_val_transforms, endpoints,
                                        loss_type=_args.loss_type, reduction='none')
            val_metrics = compute_metrics(val_data, pred_val_transforms[-1])

            for k in val_losses:
                all_val_losses[k].append(val_losses[k])
            for k in val_metrics:
                all_val_metrics_np[k].append(val_metrics[k])

        all_val_losses = {k: torch.cat(all_val_losses[k]) for k in all_val_losses}
        all_val_metrics_np = {k: np.concatenate(all_val_metrics_np[k]) for k in all_val_metrics_np}
        mean_val_losses = {k: torch.mean(all_val_losses[k]) for k in all_val_losses}

        # Rerun on random and worst data instances and save to summary
        rand_idx = random.randint(0, all_val_losses['total'].shape[0] - 1)
        worst_idx = torch.argmax(
            all_val_losses['{}_{}'.format(_args.loss_type, _args.num_reg_iter - 1)]).cpu().item()
        indices_to_rerun = [rand_idx, worst_idx]
        data_to_rerun = defaultdict(list)
        for i in indices_to_rerun:
            cur = data_loader.dataset[i]
            for k in cur:
                data_to_rerun[k].append(cur[k])
        for k in data_to_rerun:
            data_to_rerun[k] = torch.from_numpy(np.stack(data_to_rerun[k], axis=0))
        dict_all_to_device(data_to_rerun, _device)
        pred_transforms, endpoints = model(data_to_rerun, _args.num_reg_iter)

        summary_metrics = summarize_metrics(all_val_metrics_np)
        losses_by_iteration = torch.stack([mean_val_losses['{}_{}'.format(_args.loss_type, k)]
                                           for k in range(_args.num_reg_iter)]).cpu().numpy()
        print_metrics(_logger, summary_metrics, losses_by_iteration, 'Validation results')
        save_summaries(summary_writer, data=data_to_rerun, predicted=pred_transforms,
                       endpoints=endpoints, losses=mean_val_losses, metrics=summary_metrics,
                       step=step)

        score = -summary_metrics['chamfer_dist']
    return score
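
# `validate` negates the mean Chamfer distance so that a larger score means a better model,
# which (presumably by design) lets `run` below hand it directly to
# `saver.save(..., score=val_score)` for checkpoint selection. Standalone usage sketch:
#
#     val_writer = SummaryWriter(os.path.join(_log_path, 'val'), flush_secs=10)
#     score = validate(val_loader, model, val_writer, step=global_step)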

def run(train_set, val_set):
    """Main train/val loop"""

    _logger.debug('Trainer (PID=%d), %s', os.getpid(), _args)

    model = get_model(_args)
    model.to(_device)
    global_step = 0

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=_args.train_batch_size, shuffle=True,
                                               num_workers=_args.num_workers)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=_args.val_batch_size, shuffle=False,
                                             num_workers=_args.num_workers)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=_args.lr)

    # Summary writer and Checkpoint manager
    train_writer = SummaryWriter(os.path.join(_log_path, 'train'), flush_secs=10)
    val_writer = SummaryWriter(os.path.join(_log_path, 'val'), flush_secs=10)
    saver = CheckPointManager(os.path.join(_log_path, 'ckpt', 'model'),
                              keep_checkpoint_every_n_hours=0.5)
    if _args.resume is not None:
        global_step = saver.load(_args.resume, model, optimizer)

    # Training
    torch.autograd.set_detect_anomaly(_args.debug)
    model.train()

    steps_per_epoch = len(train_loader)
    if _args.summary_every < 0:
        _args.summary_every = abs(_args.summary_every) * steps_per_epoch
    if _args.validate_every < 0:
        _args.validate_every = abs(_args.validate_every) * steps_per_epoch

    for epoch in range(0, _args.epochs):
        _logger.info('Begin epoch {} (steps {} - {})'.format(epoch, global_step,
                                                             global_step + len(train_loader)))
        tbar = tqdm(total=len(train_loader), ncols=100)
        for train_data in train_loader:
            global_step += 1

            optimizer.zero_grad()

            # Forward through neural network
            dict_all_to_device(train_data, _device)
            pred_transforms, endpoints = model(train_data, _args.num_train_reg_iter)  # Use fewer iterations during training

            # Compute loss, and optimize
            train_losses = compute_losses(train_data, pred_transforms, endpoints,
                                          loss_type=_args.loss_type, reduction='mean')
            if _args.debug:
                with TorchDebugger():
                    train_losses['total'].backward()
            else:
                train_losses['total'].backward()
            optimizer.step()

            tbar.set_description('Loss:{:.3g}'.format(train_losses['total']))
            tbar.update(1)

            if global_step % _args.summary_every == 0:  # Save tensorboard logs
                save_summaries(train_writer, data=train_data, predicted=pred_transforms,
                               endpoints=endpoints, losses=train_losses, step=global_step)

            if global_step % _args.validate_every == 0:  # Validation loop. Also saves checkpoints
                model.eval()
                val_score = validate(val_loader, model, val_writer, global_step)
                saver.save(model, optimizer, step=global_step, score=val_score)
                model.train()

        tbar.close()

    _logger.info('Ending training. Number of steps = {}.'.format(global_step))
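
# The sign convention above for `summary_every` / `validate_every` means negative values
# count in epochs and positive values in steps. A hypothetical helper (not called by `run`,
# which inlines this logic) making the conversion explicit:
def _interval_in_steps(every: int, steps_per_epoch: int) -> int:
    """Convert a summary/validation interval to steps: negative values mean epochs."""
    # e.g. every=-2 with 100 steps per epoch -> an interval of 200 steps
    return abs(every) * steps_per_epoch if every < 0 else every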

def inference(data_loader, model: torch.nn.Module):
    """Runs inference over the entire dataset

    Args:
        data_loader (torch.utils.data.DataLoader): Dataset loader
        model (torch.nn.Module): Network model to evaluate

    Returns:
        pred_transforms_all: predicted transforms (B, n_iter, 3, 4) where B is total number of instances
        endpoints_out (Dict): Network endpoints
    """

    _logger.info('Starting inference...')
    model.eval()

    pred_transforms_all = []
    all_betas, all_alphas = [], []
    total_time = 0.0
    endpoints_out = defaultdict(list)
    total_rotation = []

    with torch.no_grad():
        for val_data in tqdm(data_loader):
            rot_trace = val_data['transform_gt'][:, 0, 0] + val_data['transform_gt'][:, 1, 1] + \
                        val_data['transform_gt'][:, 2, 2]
            rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
            total_rotation.append(np.abs(to_numpy(rotdeg)))

            dict_all_to_device(val_data, _device)
            time_before = time.time()
            pred_transforms, endpoints = model(val_data, _args.num_reg_iter)
            total_time += time.time() - time_before

            if _args.method == 'rpmnet':
                all_betas.append(endpoints['beta'])
                all_alphas.append(endpoints['alpha'])

            if isinstance(pred_transforms[-1], torch.Tensor):
                pred_transforms_all.append(to_numpy(torch.stack(pred_transforms, dim=1)))
            else:
                pred_transforms_all.append(np.stack(pred_transforms, axis=1))

            # Saves match matrix. We only save the top matches to save storage/time.
            # However, this still takes quite a bit of time to save. Comment out if not needed.
            if 'perm_matrices' in endpoints:
                perm_matrices = to_numpy(torch.stack(endpoints['perm_matrices'], dim=1))
                thresh = np.percentile(perm_matrices, 99.9, axis=[2, 3])  # Only retain top 0.1% of entries
                below_thresh_mask = perm_matrices < thresh[:, :, None, None]
                perm_matrices[below_thresh_mask] = 0.0

                for i_data in range(perm_matrices.shape[0]):
                    sparse_perm_matrices = []
                    for i_iter in range(perm_matrices.shape[1]):
                        sparse_perm_matrices.append(
                            sparse.coo_matrix(perm_matrices[i_data, i_iter, :, :]))
                    endpoints_out['perm_matrices'].append(sparse_perm_matrices)

    _logger.info('Total inference time: {}s'.format(total_time))
    total_rotation = np.concatenate(total_rotation, axis=0)
    _logger.info('Rotation range in data: {}(avg), {}(max)'.format(np.mean(total_rotation),
                                                                   np.max(total_rotation)))
    pred_transforms_all = np.concatenate(pred_transforms_all, axis=0)

    return pred_transforms_all, endpoints_out
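
# The rotation statistics in `inference` rely on the trace identity for a 3x3 rotation
# matrix R: trace(R) = 1 + 2*cos(theta), i.e. theta = arccos((trace(R) - 1) / 2). A minimal
# standalone check of this identity (a sketch assuming scipy is available; not called above):
def _check_rotation_angle_identity():
    from scipy.spatial.transform import Rotation

    rot = Rotation.from_euler('z', 30, degrees=True).as_matrix()  # known 30-degree rotation
    angle_deg = np.degrees(np.arccos(np.clip(0.5 * (np.trace(rot) - 1), -1.0, 1.0)))
    assert np.isclose(angle_deg, 30.0)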