Example #1
0
def validate_with_gt(args, val_loader, disp_net, epoch, logger, tb_writer, sample_nb_to_log=3):
    """Run ``disp_net`` over ``val_loader`` and score its depth against ground truth.

    Args:
        args: Options object; only ``args.print_freq`` is read here.
        val_loader: Sized iterable yielding ``(tgt_img, depth)`` batches.
        disp_net: Network called as ``disp_net(tgt_img)``. It is switched to
            eval mode; predicted depth is ``1 / output[:, 0]``.
        epoch: Current epoch, used as the TensorBoard step. The input image and
            the ground-truth depth/disparity are only logged when ``epoch == 0``.
        logger: Object exposing ``valid_bar.update`` and ``valid_writer.write``.
        tb_writer: Writer exposing ``add_image(tag, image, step)``.
        sample_nb_to_log: Number of batches, evenly spread over the loader,
            whose first sample is logged as images. ``0`` disables image logs.

    Returns:
        tuple: ``(errors.avg, error_names)`` where ``error_names`` lists the six
        metric names tracked by the ``errors`` meter.
    """
    # NOTE(review): nothing in this function disables autograd; confirm the
    # caller (or a decorator outside this view) wraps it in torch.no_grad().
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = sample_nb_to_log > 0
    # Output the logs throughout the whole dataset
    batches_to_log = list(np.linspace(0, len(val_loader)-1, sample_nb_to_log).astype(int))

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1/output_disp[:, 0]

        if log_outputs and i in batches_to_log:
            index = batches_to_log.index(i)
            if epoch == 0:
                tb_writer.add_image('val Input/{}'.format(index), tensor2array(tgt_img[0]), 0)
                # Clone before editing: depth[0] is a view of the ground-truth
                # batch, and the in-place replacement below must not leak into
                # the tensor handed to compute_depth_errors.
                depth_to_show = depth[0].clone()
                tb_writer.add_image('val target Depth Normalized/{}'.format(index),
                                    tensor2array(depth_to_show, max_value=None),
                                    epoch)
                # Replace zero depth by a large value so 1/depth stays finite
                # for display; the clamp bounds the displayed disparity.
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1/depth_to_show).clamp(0, 10)
                tb_writer.add_image('val target Disparity Normalized/{}'.format(index),
                                    tensor2array(disp_to_show, max_value=None, colormap='magma'),
                                    epoch)

            tb_writer.add_image('val Dispnet Output Normalized/{}'.format(index),
                                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                                epoch)
            tb_writer.add_image('val Depth Output Normalized/{}'.format(index),
                                tensor2array(output_depth[0], max_value=None),
                                epoch)
        errors.update(compute_depth_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
Example #2
0
@torch.no_grad()
def validate_with_gt(args,
                     val_loader,
                     depth_net,
                     pose_net,
                     epoch,
                     logger,
                     output_writers=None,
                     **env):
    """Validate ``depth_net`` and ``pose_net`` against ground-truth depth and pose.

    For every non-middle frame of each sequence, the frame is rotation-
    compensated twice (once with the ground-truth rotation, once with the
    predicted one), paired with the middle frame and fed to ``depth_net``.
    Depth errors are accumulated separately for the two cases ("stab" for
    ground truth, "unstab" for prediction), alongside pose errors.

    Autograd is disabled for the whole call since nothing here backpropagates.

    Args:
        args: Options object. Reads ``network_input_size``, ``sequence_length``,
            ``rotation_mode``, ``nominal_displacement``, ``max_depth`` and
            ``print_freq``.
        val_loader: Sized iterable yielding dict samples with keys ``'imgs'``,
            ``'intrinsics'``, ``'intrinsics_inv'``, ``'depth'`` and ``'pose'``.
        depth_net: Network called on a channel-wise concatenated image pair.
        pose_net: Network called on the stacked image sequence.
        epoch: Current epoch, used as the logging step. Inputs and ground
            truth are only logged when ``epoch == 1``.
        logger: Object exposing ``valid_bar.update`` and ``valid_writer.write``.
        output_writers: Sequence of writers exposing ``add_image``; batch ``i``
            is logged to ``output_writers[i]`` while ``i`` is in range.
            ``None`` (the default) or an empty sequence disables image logs.
        **env: Extra keyword arguments, accepted and ignored.

    Returns:
        OrderedDict: Metric name -> averaged value, ordered as pose errors,
        then "unstab" depth errors, then "stab" depth errors.
    """
    global device
    if output_writers is None:
        output_writers = []
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        # Only the first len(output_writers) batches get image logs.
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        # The 5-way unpack also asserts that imgs is 5-D.
        _, _, c, h, _ = imgs.size()

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if args.network_input_size is not None:
            # imgs is 5-D, so (c, H, W) is the 3-D output size: the channel
            # count is kept and only the last two dims are resized.
            imgs = F.interpolate(imgs, (c, *args.network_input_size),
                                 mode='area')

            # Rescale the intrinsics (and their inverse) by the height ratio.
            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :,
                                                                       2:]),
                dim=2)

        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)

        mid_index = (args.sequence_length - 1) // 2

        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)
            # .cpu() returns the same storage when the tensor already lives on
            # the CPU, so clone before the in-place edit below to keep
            # GT_depth (used for the metrics) untouched.
            depth_to_show = GT_depth[0].cpu().clone()
            # KITTI Like data routine to discard invalid data
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1],
                               inverted_pose_matrices.data[:, :-1]))

        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                # The middle frame is the reference: log it as-is and skip
                # depth evaluation for it.
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized',
                        tensor2array(sample['imgs'][j][0]), j)
                continue

            # Compute displacement magnitude for each element of batch, and
            # rescale depth accordingly.
            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:,
                                                               j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]

            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)

            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics,
                                                       intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            # Channel-wise concatenation with the target frame; tgt_img is
            # [B, c, h, w], so each pair is presumably [B, 2*c, h, w].
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img],
                dim=1)
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img],
                                      dim=1)

            # This is the depth from footage stabilized with GT pose, it should be better than depth from raw footage without any GT info
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                meter = stab_depth_errors if k == 0 else unstab_depth_errors
                meter.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    # Disparity is only needed for display.
                    disparity = 1 / depth
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(
                            prefix, j),
                        tensor2array(disparity[0],
                                     max_value=None,
                                     colormap='bone'), epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1],
                        unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg,
              *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])

    return OrderedDict(zip(error_names, errors))
Example #3
0
@torch.no_grad()
def validate_with_gt_pose(args,
                          val_loader,
                          disp_net,
                          pose_exp_net,
                          epoch,
                          logger,
                          tb_writer,
                          sample_nb_to_log=3):
    """Validate ``disp_net`` and ``pose_exp_net`` against ground-truth depth and poses.

    Autograd is disabled for the whole call: nothing here backpropagates, and
    the ``.numpy()`` conversions below would raise on tensors that still
    track gradients.

    Args:
        args: Options object. Reads ``batch_size``, ``sequence_length``,
            ``rotation_mode`` and ``print_freq``.
        val_loader: Sized iterable yielding
            ``(tgt_img, ref_imgs, gt_depth, gt_poses)`` batches.
        disp_net: Network called as ``disp_net(tgt_img)``; depth is its inverse.
        pose_exp_net: Network called as ``pose_exp_net(tgt_img, ref_imgs)``,
            returning ``(explainability_mask, output_poses)``.
        epoch: Current epoch, used as the TensorBoard step. Input images are
            only logged when ``epoch == 0``.
        logger: Object exposing ``valid_bar.update`` and ``valid_writer.write``.
        tb_writer: Writer exposing ``add_image`` and ``add_histogram``.
        sample_nb_to_log: Number of batches, evenly spread over the loader,
            whose first sample is logged. ``0`` disables images and histograms.

    Returns:
        tuple: ``(depth_errors.avg + pose_errors.avg,
        depth_error_names + pose_error_names)``.
    """
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    depth_errors = AverageMeter(i=len(depth_error_names), precision=4)
    pose_error_names = ['ATE', 'RTE']
    pose_errors = AverageMeter(i=len(pose_error_names), precision=4)
    log_outputs = sample_nb_to_log > 0
    n_batches = len(val_loader)
    # Output the logs throughout the whole dataset. The upper bound is the last
    # valid batch index: an endpoint of n_batches is never produced by
    # enumerate, which silently dropped the final logged sample.
    batches_to_log = list(
        np.linspace(0, n_batches - 1, sample_nb_to_log).astype(int))
    # Histogram buffers cover every batch except the last one, which is
    # skipped when filling them below.
    poses_values = np.zeros(
        ((n_batches - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((n_batches - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, gt_depth, gt_poses) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        gt_depth = gt_depth.to(device)
        gt_poses = gt_poses.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        b = tgt_img.shape[0]

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp
        explainability_mask, output_poses = pose_exp_net(tgt_img, ref_imgs)

        # Insert an all-zero pose vector in the middle of the predicted
        # sequence (presumably the target frame itself -- TODO confirm) so
        # the sequence lines up with gt_poses.
        middle = gt_poses.shape[1] // 2
        reordered_output_poses = torch.cat([
            output_poses[:, :middle],
            torch.zeros(b, 1, 6).to(output_poses),
            output_poses[:, middle:]
        ],
                                           dim=1)

        # pose_vec2mat only takes B, 6 tensors, so we simulate a batch dimension of B * seq_length
        unravelled_poses = reordered_output_poses.reshape(-1, 6)
        unravelled_matrices = pose_vec2mat(unravelled_poses,
                                           rotation_mode=args.rotation_mode)
        inv_transform_matrices = unravelled_matrices.reshape(b, -1, 3, 4)

        # Invert each rigid transform [R | t] into [R^T | -R^T t].
        rot_matrices = inv_transform_matrices[..., :3].transpose(-2, -1)
        tr_vectors = -rot_matrices @ inv_transform_matrices[..., -1:]

        transform_matrices = torch.cat([rot_matrices, tr_vectors], dim=-1)

        # Left-compose every inverted transform with the first frame's
        # transform: rotate by its rotation, then add its translation.
        first_inv_transform = inv_transform_matrices[:, :1]
        final_poses = first_inv_transform[..., :3] @ transform_matrices
        final_poses[..., -1:] += first_inv_transform[..., -1:]
        final_poses = final_poses.reshape(b, -1, 3, 4)

        if log_outputs and i in batches_to_log:  # log first output of wanted batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    # Target at step 0 and reference at step 1 share one tag.
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   output_depth, output_disp, None, None,
                                   explainability_mask)

        # The last batch is skipped here; the buffers were sized for
        # n_batches - 1 batches of args.batch_size samples each.
        if log_outputs and i < n_batches - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = output_poses.cpu().view(
                -1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = output_disp.cpu().view(args.batch_size, -1)
            # Per-sample min / median / max disparity, concatenated.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        depth_errors.update(compute_depth_errors(gt_depth, output_depth[:, 0]))
        pose_errors.update(compute_pose_errors(gt_poses, final_poses))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f}), ATE {:.4f} ({:.4f})'
                .format(batch_time, depth_errors.val[0], depth_errors.avg[0],
                        pose_errors.val[0], pose_errors.avg[0]))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        # One histogram per pose coefficient column.
        for col in range(poses_values.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[col]),
                                    poses_values[:, col], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(n_batches)
    return depth_errors.avg + pose_errors.avg, depth_error_names + pose_error_names