Example No. 1
def evaluate(pred_transforms,
             data_loader: torch.utils.data.dataloader.DataLoader):
    """ Evaluates the computed transforms against the groundtruth

    Args:
        pred_transforms: Predicted transforms (B, [iter], 3/4, 4)
        data_loader: Loader for dataset.

    Returns:
        Computed metrics (List of dicts), and summary metrics (only for last iter)
    """

    _logger.info('Evaluating transforms...')
    num_processed, num_total = 0, len(pred_transforms)

    if pred_transforms.ndim == 4:
        pred_transforms = torch.from_numpy(pred_transforms).to(_device)
    else:
        assert pred_transforms.ndim == 3 and \
               (pred_transforms.shape[1:] == (4, 4) or pred_transforms.shape[1:] == (3, 4))
        pred_transforms = torch.from_numpy(
            pred_transforms[:, None, :, :]).to(_device)

    metrics_for_iter = [
        defaultdict(list) for _ in range(pred_transforms.shape[1])
    ]

    for data in tqdm(data_loader, leave=False):
        dict_all_to_device(data, _device)

        batch_size = 0
        for i_iter in range(pred_transforms.shape[1]):
            batch_size = data['points_src'].shape[0]

            cur_pred_transforms = pred_transforms[num_processed:num_processed +
                                                  batch_size, i_iter, :, :]
            metrics = compute_metrics(data, cur_pred_transforms)
            for k in metrics:
                metrics_for_iter[i_iter][k].append(metrics[k])
        num_processed += batch_size

    for i_iter in range(len(metrics_for_iter)):
        metrics_for_iter[i_iter] = {
            k: np.concatenate(metrics_for_iter[i_iter][k], axis=0)
            for k in metrics_for_iter[i_iter]
        }
        summary_metrics = summarize_metrics(metrics_for_iter[i_iter])
        print_metrics(_logger,
                      summary_metrics,
                      title='Evaluation result (iter {})'.format(i_iter))

    return metrics_for_iter, summary_metrics
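
A minimal usage sketch (hypothetical test_set; model and inference as in Examples No. 4-5, with the module-level _device/_logger globals assumed initialized as in the original module):

import torch.utils.data

test_loader = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=False)
pred_transforms, _ = inference(test_loader, model)  # numpy array, (B, n_iter, 3, 4)
metrics_per_iter, last_summary = evaluate(pred_transforms, test_loader)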
Example No. 2
def evaluate(data_loader, model: torch.nn.Module):
    """ Evaluates the model's prediction against the groundtruth """
    _logger.info('Starting evaluation...')
    with torch.no_grad():
        all_test_metrics_np = defaultdict(list)
        for test_data in data_loader:
            dict_all_to_device(test_data, _device)
            pred_test_transforms, endpoints = model(test_data, _args.num_reg_iter)
            test_metrics = compute_metrics(test_data, pred_test_transforms[-1],
                                           endpoints['perm_matrices'][-1])
            for k in test_metrics:
                all_test_metrics_np[k].append(test_metrics[k])
        all_test_metrics_np = {k: np.concatenate(all_test_metrics_np[k]) for k in all_test_metrics_np}

    summary_metrics = summarize_metrics(all_test_metrics_np)
    print_metrics(_logger, summary_metrics, title='Evaluation results')
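
A hedged sketch of a call site for this variant (get_model and test_loader are assumptions, mirroring Example No. 4; note that this evaluate logs its summary and returns nothing):

model = get_model(_args)      # assumed factory, as in Example No. 4
model.to(_device)
model.eval()
evaluate(test_loader, model)  # prints summary metrics via _logger; no return value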
Example No. 3
def validate(data_loader, model: torch.nn.Module, summary_writer: SummaryWriter, step: int):
    """Perform a single validation run, and saves results into tensorboard summaries"""

    _logger.info('Starting validation run...')

    with torch.no_grad():
        all_val_losses = defaultdict(list)
        all_val_metrics_np = defaultdict(list)
        for val_data in data_loader:
            dict_all_to_device(val_data, _device)
            pred_test_transforms, endpoints = model(val_data, _args.num_reg_iter)
            val_losses = compute_losses(val_data, pred_test_transforms, endpoints,
                                        loss_type=_args.loss_type, reduction='none')
            val_metrics = compute_metrics(val_data, pred_test_transforms[-1])

            for k in val_losses:
                all_val_losses[k].append(val_losses[k])
            for k in val_metrics:
                all_val_metrics_np[k].append(val_metrics[k])

        all_val_losses = {k: torch.cat(all_val_losses[k]) for k in all_val_losses}
        all_val_metrics_np = {k: np.concatenate(all_val_metrics_np[k]) for k in all_val_metrics_np}
        mean_val_losses = {k: torch.mean(all_val_losses[k]) for k in all_val_losses}

    # Rerun on random and worst data instances and save to summary
    rand_idx = random.randint(0, all_val_losses['total'].shape[0] - 1)
    worst_idx = torch.argmax(all_val_losses['{}_{}'.format(_args.loss_type, _args.num_reg_iter - 1)]).cpu().item()
    indices_to_rerun = [rand_idx, worst_idx]
    data_to_rerun = defaultdict(list)
    for i in indices_to_rerun:
        cur = data_loader.dataset[i]
        for k in cur:
            data_to_rerun[k].append(cur[k])
    for k in data_to_rerun:
        data_to_rerun[k] = torch.from_numpy(np.stack(data_to_rerun[k], axis=0))
    dict_all_to_device(data_to_rerun, _device)
    pred_transforms, endpoints = model(data_to_rerun, _args.num_reg_iter)

    summary_metrics = summarize_metrics(all_val_metrics_np)
    losses_by_iteration = torch.stack([mean_val_losses['{}_{}'.format(_args.loss_type, k)]
                                       for k in range(_args.num_reg_iter)]).cpu().numpy()
    print_metrics(_logger, summary_metrics, losses_by_iteration, 'Validation results')

    save_summaries(summary_writer, data=data_to_rerun, predicted=pred_transforms, endpoints=endpoints,
                   losses=mean_val_losses, metrics=summary_metrics, step=step)

    score = -summary_metrics['chamfer_dist']
    return score
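
The returned score is the negated Chamfer distance, so a checkpoint manager that keeps the highest-scoring checkpoint effectively keeps the one with the lowest Chamfer distance. A sketch of the call site, as it appears in the training loop of Example No. 4 (names assumed):

model.eval()
val_score = validate(val_loader, model, val_writer, global_step)
saver.save(model, optimizer, step=global_step, score=val_score)  # higher score == lower chamfer_dist
model.train()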
Example No. 4
def run(train_set, val_set):
    """Main train/val loop"""

    _logger.debug('Trainer (PID=%d), %s', os.getpid(), _args)

    model = get_model(_args)
    model.to(_device)
    global_step = 0

    # dataloaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=_args.train_batch_size, shuffle=True, num_workers=_args.num_workers)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=_args.val_batch_size, shuffle=False, num_workers=_args.num_workers)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=_args.lr)

    # Summary writer and Checkpoint manager
    train_writer = SummaryWriter(os.path.join(_log_path, 'train'), flush_secs=10)
    val_writer = SummaryWriter(os.path.join(_log_path, 'val'), flush_secs=10)
    saver = CheckPointManager(os.path.join(_log_path, 'ckpt', 'model'), keep_checkpoint_every_n_hours=0.5)
    if _args.resume is not None:
        global_step = saver.load(_args.resume, model, optimizer)

    # training
    torch.autograd.set_detect_anomaly(_args.debug)
    model.train()

    steps_per_epoch = len(train_loader)
    if _args.summary_every < 0:
        _args.summary_every = abs(_args.summary_every) * steps_per_epoch
    if _args.validate_every < 0:
        _args.validate_every = abs(_args.validate_every) * steps_per_epoch

    for epoch in range(0, _args.epochs):
        _logger.info('Begin epoch {} (steps {} - {})'.format(epoch, global_step, global_step + len(train_loader)))
        tbar = tqdm(total=len(train_loader), ncols=100)
        for train_data in train_loader:
            global_step += 1

            optimizer.zero_grad()

            # Forward through neural network
            dict_all_to_device(train_data, _device)
            pred_transforms, endpoints = model(train_data, _args.num_train_reg_iter)  # Use fewer iterations during training

            # Compute loss, and optimize
            train_losses = compute_losses(train_data, pred_transforms, endpoints,
                                          loss_type=_args.loss_type, reduction='mean')
            if _args.debug:
                with TorchDebugger():
                    train_losses['total'].backward()
            else:
                train_losses['total'].backward()
            optimizer.step()

            tbar.set_description('Loss:{:.3g}'.format(train_losses['total']))
            tbar.update(1)

            if global_step % _args.summary_every == 0:  # Save tensorboard logs
                save_summaries(train_writer, data=train_data, predicted=pred_transforms, endpoints=endpoints,
                               losses=train_losses, step=global_step)

            if global_step % _args.validate_every == 0:  # Validation loop. Also saves checkpoints
                model.eval()
                val_score = validate(val_loader, model, val_writer, global_step)
                saver.save(model, optimizer, step=global_step, score=val_score)
                model.train()

        tbar.close()

    _logger.info('Ending training. Number of steps = {}.'.format(global_step))
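
Note the sign convention on summary_every/validate_every: a negative value is read as a count of epochs and converted to steps once steps_per_epoch is known. A self-contained check of that conversion (values are illustrative):

steps_per_epoch = 500  # e.g. len(train_loader)
summary_every = -2     # negative: fire every 2 epochs
if summary_every < 0:
    summary_every = abs(summary_every) * steps_per_epoch
assert summary_every == 1000  # i.e. every 1000 global steps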
Example No. 5
def inference(data_loader, model: torch.nn.Module):
    """Runs inference over entire dataset

    Args:
        data_loader (torch.utils.data.DataLoader): Dataset loader
        model (torch.nn.Module): Network model to evaluate

    Returns:
        pred_transforms_all: predicted transforms (B, n_iter, 3, 4) where B is total number of instances
        endpoints_out (Dict): Network endpoints
    """

    _logger.info('Starting inference...')
    model.eval()

    pred_transforms_all = []
    all_betas, all_alphas = [], []
    total_time = 0.0
    endpoints_out = defaultdict(list)
    total_rotation = []

    with torch.no_grad():
        for val_data in tqdm(data_loader):

            rot_trace = val_data['transform_gt'][:, 0, 0] + val_data['transform_gt'][:, 1, 1] + \
                        val_data['transform_gt'][:, 2, 2]
            rotdeg = torch.acos(
                torch.clamp(0.5 * (rot_trace - 1), min=-1.0,
                            max=1.0)) * 180.0 / np.pi
            total_rotation.append(np.abs(to_numpy(rotdeg)))

            dict_all_to_device(val_data, _device)
            time_before = time.time()
            pred_transforms, endpoints = model(val_data, _args.num_reg_iter)
            total_time += time.time() - time_before

            if _args.method == 'rpmnet':
                all_betas.append(endpoints['beta'])
                all_alphas.append(endpoints['alpha'])

            if isinstance(pred_transforms[-1], torch.Tensor):
                pred_transforms_all.append(
                    to_numpy(torch.stack(pred_transforms, dim=1)))
            else:
                pred_transforms_all.append(np.stack(pred_transforms, axis=1))

            # Saves match matrix. We only save the top matches to save storage/time.
            # However, this still takes quite a bit of time to save. Comment out if not needed.
            if 'perm_matrices' in endpoints:
                perm_matrices = to_numpy(
                    torch.stack(endpoints['perm_matrices'], dim=1))
                thresh = np.percentile(
                    perm_matrices, 99.9,
                    axis=[2, 3])  # Only retain top 0.1% of entries
                below_thresh_mask = perm_matrices < thresh[:, :, None, None]
                perm_matrices[below_thresh_mask] = 0.0

                for i_data in range(perm_matrices.shape[0]):
                    sparse_perm_matrices = []
                    for i_iter in range(perm_matrices.shape[1]):
                        sparse_perm_matrices.append(
                            sparse.coo_matrix(perm_matrices[i_data,
                                                            i_iter, :, :]))
                    endpoints_out['perm_matrices'].append(sparse_perm_matrices)

    _logger.info('Total inference time: {}s'.format(total_time))
    total_rotation = np.concatenate(total_rotation, axis=0)
    _logger.info('Rotation range in data: {}(avg), {}(max)'.format(
        np.mean(total_rotation), np.max(total_rotation)))
    pred_transforms_all = np.concatenate(pred_transforms_all, axis=0)

    return pred_transforms_all, endpoints_out
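
The rotation magnitude logged above is recovered from the trace of the ground-truth rotation via angle = arccos((trace(R) - 1) / 2). A self-contained numeric check of that identity (scipy is already a dependency here via sparse):

import numpy as np
from scipy.spatial.transform import Rotation

# A pure 30-degree rotation about z; its trace-based magnitude should be 30.
R = Rotation.from_euler('z', 30, degrees=True).as_matrix()
rot_trace = R[0, 0] + R[1, 1] + R[2, 2]
rotdeg = np.degrees(np.arccos(np.clip(0.5 * (rot_trace - 1), -1.0, 1.0)))
assert np.isclose(rotdeg, 30.0)

The sparse permutation matrices collected in endpoints_out['perm_matrices'] can be densified per instance and iteration with .toarray() when needed.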