def main():
    global best_error, worst_error
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework
    elif args.gt_type == 'stillbox':
        from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework

    weights = torch.load(args.pretrained_depthnet)
    depth_net = DepthNet(depth_activation="elu", batch_norm='bn' in weights.keys() and weights['bn']).to(device)

    depth_net.load_state_dict(weights['state_dict'])
    depth_net.eval()

    if args.pretrained_posenet is None:
        args.stabilize_from_GT = True
        print('no PoseNet specified, stabilization will be done from ground truth')
        seq_length = 5
    else:
        weights = torch.load(args.pretrained_posenet)
        seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3)
        pose_net = PoseNet(seq_length=seq_length).to(device)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])]

    framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((7, len(test_files)), np.float32)

    args.output_dir = Path(args.output_dir)
    args.output_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        imgs = sample['imgs']
        intrinsics = sample['intrinsics'].copy()

        h,w,_ = imgs[0].shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in imgs]
            intrinsics[0] *= args.img_width/w
            intrinsics[1] *= args.img_height/h

        intrinsics_inv = np.linalg.inv(intrinsics)

        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).to(device)
        intrinsics_inv = torch.from_numpy(intrinsics_inv).unsqueeze(0).to(device)
        imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs]
        imgs = torch.stack(imgs).unsqueeze(0).to(device)
        imgs = 2*(imgs/255 - 0.5)

        tgt_img = imgs[:,sample['tgt_index']]

        # Construct a batch of all possible stabilized pairs (stabilized with PoseNet or with GT orientation); we will keep the output whose mean is closest to the target mean depth
        if args.stabilize_from_GT:
            poses_GT = torch.from_numpy(sample['poses']).to(device).unsqueeze(0)
            inv_poses_GT = invert_mat(poses_GT)
            tgt_pose = inv_poses_GT[:,sample['tgt_index']]
            inv_transform_matrices_tgt = compensate_pose(inv_poses_GT, tgt_pose)
        else:
            poses = pose_net(imgs)
            inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode)

            tgt_pose = inv_transform_matrices[:,sample['tgt_index']]
            inv_transform_matrices_tgt = compensate_pose(inv_transform_matrices, tgt_pose)

        stabilized_pairs = []
        corresponding_displ = []
        for i in range(seq_length):
            if i == sample['tgt_index']:
                continue
            img = imgs[:,i]
            img_pose = inv_transform_matrices_tgt[:,i]
            stab_img = inverse_rotate(img, img_pose[:,:,:3], intrinsics, intrinsics_inv)
            pair = torch.cat([stab_img, tgt_img], dim=1)  # [1, 6, H, W]
            stabilized_pairs.append(pair)

            GT_translations = sample['poses'][:,:,-1]
            real_displacement = np.linalg.norm(GT_translations[sample['tgt_index']] - GT_translations[i])
            corresponding_displ.append(real_displacement)
        stab_batch = torch.cat(stabilized_pairs)  # [seq-1, 6, H, W]
        depth_maps = depth_net(stab_batch)  # [seq-1, 1, H/4, W/4]

        selected_depth, selected_index = select_best_map(depth_maps, target_mean_depthnet_output)

        pred_depth = selected_depth.cpu().data.numpy() * corresponding_displ[selected_index] / args.nominal_displacement

        if args.save_output:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_depth.shape))
            predictions[j] = 1/pred_depth

        gt_depth = sample['gt_depth']
        pred_depth_zoomed = zoom(pred_depth,
                                 (gt_depth.shape[0]/pred_depth.shape[0],
                                  gt_depth.shape[1]/pred_depth.shape[1])
                                 ).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed_masked = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]
        else:
            pred_depth_zoomed_masked = pred_depth_zoomed
        errors[:,j] = compute_errors(gt_depth, pred_depth_zoomed_masked)
        if args.log_best_worst:
            if best_error > errors[0,j]:
                best_error = errors[0,j]
                log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'best')
            if worst_error < errors[0,j]:
                worst_error = errors[0,j]
                log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'worst')

    mean_errors = errors.mean(1)
    error_names = ['abs_rel','sq_rel','rms','log_rms','a1','a2','a3']

    print("Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
    print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors))

    if args.save_output:
        np.save(args.output_dir/'predictions.npy', predictions)
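
# `select_best_map` (used above) is not shown in this listing. A minimal
# sketch of what it might do, assuming it simply keeps the depth map whose
# mean is closest to the nominal mean DepthNet output (the helper name exists
# above, but this behaviour is an assumption, not the repo's confirmed code):
import torch

def select_best_map(depth_maps, target_mean):
    # depth_maps: [N, 1, H, W], one map per stabilized pair
    means = depth_maps.view(depth_maps.size(0), -1).mean(-1)  # [N]
    index = (means - target_mean).abs().argmin().item()
    return depth_maps[index, 0], index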
Example #2
def validate_with_gt(args,
                     val_loader,
                     depth_net,
                     pose_net,
                     epoch,
                     logger,
                     output_writers=[],
                     **env):
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if args.network_input_size is not None:
            imgs = F.interpolate(imgs, (c, *args.network_input_size),
                                 mode='area')

            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :,
                                                                       2:]),
                dim=2)

        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)

        mid_index = (args.sequence_length - 1) // 2

        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)
            depth_to_show = GT_depth[0].cpu()
            # KITTI-like data routine to discard invalid depth values
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1],
                               inverted_pose_matrices.data[:, :-1]))

        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized',
                        tensor2array(sample['imgs'][j][0]), j)
                continue
            # Compute the displacement magnitude for each element of the batch
            # and rescale depth accordingly.

            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:,
                                                               j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]

            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)

            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics,
                                                       intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img],
                dim=1)  # [B, 6, H, W]
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img],
                                      dim=1)  # [B, 6, H, W]

            # Depth from footage stabilized with the GT pose; it should be better than depth from footage stabilized with the predicted pose, which uses no GT info
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(
                            prefix, j),
                        tensor2array(disparity[0],
                                     max_value=None,
                                     colormap='bone'), epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1],
                        unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg,
              *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])

    return OrderedDict(zip(error_names, errors))
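
# `invert_mat` and `compensate_pose` are imported from the repo's utilities
# and not shown in this listing. A minimal sketch of the intended math on
# [..., 3, 4] rigid transforms [R|t], assuming "compensating" means
# re-expressing every pose relative to a reference pose so that the reference
# becomes the identity (the exact composition order in the repo may differ):
import torch

def invert_mat(mat):
    # mat: [..., 3, 4]; the inverse of [R|t] is [R^T | -R^T t]
    rot = mat[..., :3].transpose(-1, -2)
    tr = -rot @ mat[..., -1:]
    return torch.cat([rot, tr], dim=-1)

def compensate_pose(mat, ref):
    # mat: [B, seq, 3, 4], ref: [B, 3, 4]; compose inv(ref) with every pose
    inv_ref = invert_mat(ref).unsqueeze(1)                 # [B, 1, 3, 4]
    rot = inv_ref[..., :3] @ mat[..., :3]
    tr = inv_ref[..., :3] @ mat[..., -1:] + inv_ref[..., -1:]
    return torch.cat([rot, tr], dim=-1)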
def main():
    args = parser.parse_args()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if args.gt_type == 'KITTI':
        from kitti_eval.pose_evaluation_utils import test_framework_KITTI as test_framework
    elif args.gt_type == 'stillbox':
        from stillbox_eval.pose_evaluation_utils import test_framework_stillbox as test_framework

    weights = torch.load(args.pretrained_posenet)
    seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3)
    pose_net = PoseNet(seq_length=seq_length).to(device)
    pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    framework = test_framework(dataset_dir, args.sequences, seq_length)

    print('{} snippets to test'.format(len(framework)))
    errors = np.zeros((len(framework), 2), np.float32)
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        output_dir.makedirs_p()
        predictions_array = np.zeros((len(framework), seq_length, 3, 4))

    for j, sample in enumerate(tqdm(framework)):
        imgs = sample['imgs']

        h,w,_ = imgs[0].shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in imgs]

        imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs]
        imgs = torch.stack(imgs).unsqueeze(0).to(device)
        imgs = 2*(imgs/255 - 0.5)

        poses = pose_net(imgs)

        inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode)

        transform_matrices = invert_mat(inv_transform_matrices)

        # rot_matrices = np.linalg.inv(inv_transform_matrices[:,:,:3])
        # tr_vectors = rot_matrices @ inv_transform_matrices[:,:,-1:]

        # transform_matrices = np.concatenate([rot_matrices, tr_vectors], axis=-1)

        # first_transform = transform_matrices[0]
        # final_poses = np.linalg.inv(first_transform[:,:3]) @ transform_matrices
        # final_poses[:,:,-1:] -= np.linalg.inv(first_transform[:,:3]) @ first_transform[:,-1:]

        final_poses = compensate_pose(transform_matrices, transform_matrices[:,0])[0].cpu().numpy()

        if args.output_dir is not None:
            predictions_array[j] = final_poses

        ATE, RE = compute_pose_error(sample['poses'][1:], final_poses[1:])
        errors[j] = ATE, RE

    mean_errors = errors.mean(0)
    std_errors = errors.std(0)
    error_names = ['ATE','RE']
    print('')
    print("Results")
    print("\t {:>10}, {:>10}".format(*error_names))
    print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors))
    print("std \t {:10.4f}, {:10.4f}".format(*std_errors))

    if args.output_dir is not None:
        np.save(output_dir/'predictions.npy', predictions_array)
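
# `compute_pose_error` comes from the evaluation utilities and is not shown
# here. A minimal numpy sketch of the usual snippet metrics, assuming ATE is
# the mean translation error after a least-squares scale alignment of the
# predicted snippet and RE the mean geodesic rotation error (these formulas
# are an assumption, not the repo's confirmed definition):
import numpy as np

def compute_pose_error(gt, pred):
    # gt, pred: [N, 3, 4] poses expressed relative to the first frame
    gt_t, pred_t = gt[:, :, -1], pred[:, :, -1]
    scale = np.sum(gt_t * pred_t) / np.sum(pred_t ** 2)    # least-squares scale
    ATE = np.linalg.norm(gt_t - scale * pred_t, axis=1).mean()
    # relative rotation between prediction and GT, converted to an angle
    R = gt[:, :, :3] @ pred[:, :, :3].transpose(0, 2, 1)
    cos_angle = np.clip((np.trace(R, axis1=1, axis2=2) - 1) / 2, -1, 1)
    RE = np.arccos(cos_angle).mean()
    return ATE, RE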
Example #4
def validate_without_gt(args,
                        val_loader,
                        depth_net,
                        pose_net,
                        epoch,
                        logger,
                        output_writers=[],
                        **env):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.valid_bar.update(0)

    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        mid_index = (args.sequence_length - 1) // 2

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        loss_1 = 0
        loss_2 = 0

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, H, W]

            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)

            depth = upsample_depth_net(input_pair)
            disparity = 1 / depth
            total_indices = torch.arange(seq).long().unsqueeze(0).expand(
                batch_size, seq).to(device)
            tgt_id = total_indices[:, mid_index]

            # renamed to avoid shadowing the `ref_indices` list this loop iterates over
            ref_ids_tensor = total_indices[
                total_indices != tgt_id.unsqueeze(1)].view(
                    batch_size, seq - 1)

            photo_loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_ids_tensor,
                depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += photo_loss

            if log_output:
                output_writers[i].add_image(
                    'val Dispnet Output Normalized {}'.format(ref_index),
                    tensor2array(disparity[0], max_value=None,
                                 colormap='bone'), epoch)
                output_writers[i].add_image(
                    'val Depth Output {}'.format(ref_index),
                    tensor2array(depth[0].cpu(), max_value=args.max_depth),
                    epoch)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    output_writers[i].add_image(
                        'val Warped Outputs {} {}'.format(j, ref_index),
                        tensor2array(warped[0]), epoch)
                    output_writers[i].add_image(
                        'val Diff Outputs {} {}'.format(j, ref_index),
                        tensor2array(diff[0].abs() - 1), epoch)

            loss_2 += texture_aware_smooth_loss(
                disparity, tgt_imgs if args.texture_loss else None)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            output_writers[0].add_histogram('val poses_{}'.format(coeff_name),
                                            poses_values[:, k], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
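
# `models.UpSampleNet` wraps DepthNet so that full-resolution pairs can be fed
# while the network itself runs at a fixed working size. A minimal sketch of
# such a wrapper, assuming it downsamples the input and upsamples a
# single-tensor prediction back (the repo's module may differ; in particular
# it likely also handles the multi-scale list DepthNet returns in train mode):
import torch.nn as nn
import torch.nn.functional as F

class UpSampleNet(nn.Module):
    def __init__(self, net, network_input_size):
        super(UpSampleNet, self).__init__()
        self.net = net
        self.size = network_input_size  # (H, W) or None

    def forward(self, x):
        if self.size is None:
            return self.net(x)
        h, w = x.shape[-2:]
        downscaled = F.interpolate(x, size=self.size, mode='area')
        prediction = self.net(downscaled)
        # scale the prediction back to the input resolution
        return F.interpolate(prediction, size=(h, w),
                             mode='bilinear', align_corners=False)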
Example #5
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, training_writer, **env):
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):

        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq).long().to(device).unsqueeze(
            0).expand(batch_size, seq)
        batch_range = torch.arange(batch_size).long().to(device)
        ''' For each element of the batch, select a random picture in the sequence for
        which we will compute the depth; all poses are then converted so that the pose of this
        very picture is exactly the identity. At first, this image is always the middle of the sequence.'''

        if epoch > e2:
            tgt_id = torch.floor(torch.rand(batch_size) *
                                 seq).long().to(device)
        else:
            tgt_id = torch.zeros(batch_size).long().to(
                device) + args.sequence_length // 2
        '''
        Select the other picture we are going to feed DepthNet; it must not be the same
        as tgt_id. At first, it is always the first picture of the sequence; it is chosen
        randomly once the first training milestone is reached.
        '''

        ref_indices = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)

        if epoch > e1:
            prior_id = torch.floor(torch.rand(batch_size) *
                                   (seq - 1)).long().to(device)
        else:
            prior_id = torch.zeros(batch_size).long().to(device)
        prior_id = ref_indices[batch_range, prior_id]

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]

        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, H, W]
        depth = upsample_depth_net(input_pair)
        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

        predicted_magnitude = prior_poses[:, :,
                                          -1:].norm(p=2, dim=1,
                                                    keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :,
                                                   -1:] * scale_factor  # [B, seq, 3, 1]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)

        loss_1 = 0
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale
            loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_indices,
                scaled_depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += loss * size_ratio

            if log_output:
                training_writer.add_image(
                    'train Dispnet Output Normalized scale {}'.format(k),
                    tensor2array(disparities[k][0],
                                 max_value=None,
                                 colormap='bone'), n_iter)
                training_writer.add_image(
                    'train Depth Output scale {}'.format(k),
                    tensor2array(scaled_depth[0], max_value=args.max_depth),
                    n_iter)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    training_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(warped[0]), n_iter)
                    training_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(diff.abs()[0] - 1), n_iter)

        loss_2 = texture_aware_smooth_loss(
            depth, tgt_imgs if args.texture_loss else None)

        loss = w1 * loss_1 + w2 * loss_2

        if args.supervise_pose:
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            training_writer.add_scalar('photometric_error', loss_1.item(),
                                       n_iter)
            training_writer.add_scalar('disparity_smoothness_loss',
                                       loss_2.item(), n_iter)
            training_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            nominal_translation_magnitude = poses[:, -2, :3].norm(p=2, dim=-1)
            # the last pose is always the identity and the penultimate translation magnitude is always 1, so there is no need to log them
            for j in range(args.sequence_length - 2):
                trans_mag = poses[:, j, :3].norm(p=2, dim=-1)
                training_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO: log a better value; this is the magnitude of the (yaw, pitch, roll) vector, which is not a physical quantity
                rot_mag = poses[:, j, 3:].norm(p=2, dim=-1)
                training_writer.add_histogram('rot {}'.format(j),
                                              rot_mag.detach().cpu().numpy(),
                                              n_iter)

            training_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                      n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
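
# `pose_vec2mat` converts the 6-DoF vectors predicted by PoseNet into [R|t]
# matrices. A condensed sketch for the 'euler' rotation mode only, assuming
# the vector layout is (tx, ty, tz, rx, ry, rz) as the translation/rotation
# slicing above suggests (the repo's version also handles quaternions):
import torch

def pose_vec2mat(vec, rotation_mode='euler'):
    # vec: [B, seq, 6] -> [B, seq, 3, 4]
    assert rotation_mode == 'euler', "quaternion mode omitted in this sketch"
    t = vec[..., :3].unsqueeze(-1)                         # [B, seq, 3, 1]
    rx, ry, rz = vec[..., 3], vec[..., 4], vec[..., 5]
    zeros, ones = torch.zeros_like(rx), torch.ones_like(rx)
    cx, sx = rx.cos(), rx.sin()
    cy, sy = ry.cos(), ry.sin()
    cz, sz = rz.cos(), rz.sin()
    rot_x = torch.stack([ones, zeros, zeros,
                         zeros, cx, -sx,
                         zeros, sx, cx], dim=-1).view(*rx.shape, 3, 3)
    rot_y = torch.stack([cy, zeros, sy,
                         zeros, ones, zeros,
                         -sy, zeros, cy], dim=-1).view(*ry.shape, 3, 3)
    rot_z = torch.stack([cz, -sz, zeros,
                         sz, cz, zeros,
                         zeros, zeros, ones], dim=-1).view(*rz.shape, 3, 3)
    rot = rot_z @ rot_y @ rot_x
    return torch.cat([rot, t], dim=-1)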
Example #6
def adjust_shifts(args, train_set, adjust_loader, depth_net, pose_net, epoch,
                  logger, training_writer, **env):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1, precision=2)
    pose_net.eval()
    depth_net.eval()
    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()

    mid_index = (args.sequence_length - 1) // 2

    # we constrain the mean value of the DepthNet output from pair (0, mid_index)
    target_values = np.arange(
        -mid_index, mid_index + 1) / (args.target_mean_depth * mid_index)
    target_values = 1 / np.abs(
        np.concatenate(
            [target_values[:mid_index], target_values[mid_index + 1:]]))

    logger.reset_train_bar(len(adjust_loader))

    for i, sample in enumerate(adjust_loader):
        index = sample['index']

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        # compute output
        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        mean_depth_batch = []

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, H, W]

            depth = upsample_depth_net(input_pair)  # [B, 1, H, W]
            mean_depth = depth.view(batch_size, -1).mean(-1).cpu().numpy()  # B
            mean_depth_batch.append(mean_depth)

        for j, mean_values in zip(index, np.stack(mean_depth_batch, axis=-1)):
            ratio = mean_values / target_values  # if the mean value is too high, raise the shift; lower it otherwise
            train_set.reset_shifts(j, ratio[:mid_index], ratio[mid_index:])
            new_shifts.update(train_set.get_shifts(j))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustment: '
                                      'Time {} Data {} shifts {}'.format(
                                          batch_time, data_time, new_shifts))

    for i, shift in enumerate(new_shifts.avg):
        training_writer.add_scalar('shifts{}'.format(i), shift, epoch)

    return new_shifts.avg
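
# Worked numerical example of the target values computed above (an
# illustration, not extra repo code): with sequence_length = 5 (mid_index = 2)
# and args.target_mean_depth = 50, the expected mean DepthNet output per
# frame offset is
import numpy as np

mid_index, target_mean_depth = 2, 50
t = np.arange(-mid_index, mid_index + 1) / (target_mean_depth * mid_index)
t = 1 / np.abs(np.concatenate([t[:mid_index], t[mid_index + 1:]]))
print(t)  # [ 50. 100. 100.  50.] -> frames closer to the target expect a larger mean depth
# If the measured mean for an offset is above its target, the ratio is > 1 and
# `reset_shifts` enlarges that shift, so the apparent displacement grows and the depth shrinks.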
Example #7
def validate_without_gt(args, val_loader, depth_net, pose_net, epoch, logger,
                        tb_writer, sample_nb_to_log, **env):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)

    for i, sample in enumerate(val_loader):
        log_output = i < sample_nb_to_log

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(img[0]), j)

        batch_size, seq = imgs.size()[:2]
        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        mid_index = (args.sequence_length - 1) // 2

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_ids = list(range(args.sequence_length))
        ref_ids.remove(mid_index)

        loss_1 = 0
        loss_2 = 0

        for ref_index in ref_ids:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, H, W]

            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)

            depth = depth_net(input_pair)
            disparity = 1 / depth

            # the reconstruction target is the middle frame, whose depth was just predicted
            tgt_id = torch.full((batch_size, ),
                                mid_index,
                                dtype=torch.int64,
                                device=device)
            ref_ids_tensor = torch.tensor(ref_ids,
                                          dtype=torch.int64,
                                          device=device).expand(
                                              batch_size, -1)
            photo_loss, *to_log = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_ids_tensor,
                depth,
                new_pose_matrices,
                intrinsics,
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)

            loss_1 += photo_loss

            if log_output:
                log_output_tensorboard(tb_writer, "val", i, ref_index, epoch,
                                       depth[0], disparity[0], *to_log)

            loss_2 += grad_diffusion_loss(disparity, tgt_imgs, args.kappa)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            tb_writer.add_histogram('val poses_{}'.format(coeff_name),
                                    poses_values[:, k], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
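
# `inverse_rotate` warps the prior frame so that only translation remains
# between it and the target: a pure rotation maps pixels through the
# homography K R K^-1, so the frame can be de-rotated by resampling it through
# that warp. A compact sketch of the idea, matching the 3-argument call above
# (the exact rotation convention and sampling details are assumptions):
import torch
import torch.nn.functional as F

def inverse_rotate(img, rot, intrinsics):
    # img: [B, 3, H, W], rot: [B, 3, 3], intrinsics: [B, 3, 3]
    b, _, h, w = img.shape
    homography = intrinsics @ rot @ torch.inverse(intrinsics)  # [B, 3, 3]
    y, x = torch.meshgrid(torch.arange(h, dtype=img.dtype, device=img.device),
                          torch.arange(w, dtype=img.dtype, device=img.device))
    pix = torch.stack([x, y, torch.ones_like(x)], dim=0).view(1, 3, -1)  # [1, 3, H*W]
    warped = homography @ pix                                  # [B, 3, H*W]
    warped = warped[:, :2] / warped[:, 2:]                     # perspective divide
    # normalize pixel coordinates to [-1, 1] for grid_sample
    gx = 2 * warped[:, 0] / (w - 1) - 1
    gy = 2 * warped[:, 1] / (h - 1) - 1
    grid = torch.stack([gx, gy], dim=-1).view(b, h, w, 2)
    return F.grid_sample(img, grid, align_corners=True)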
Example #8
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, tb_writer, **env):
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):

        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq, dtype=torch.int64,
                                     device=device).expand(batch_size, seq)
        batch_range = torch.arange(batch_size,
                                   dtype=torch.int64,
                                   device=device)
        ''' For each element of the batch, select a random picture in the sequence for
        which we will compute the depth; all poses are then converted so that the pose of this
        very picture is exactly the identity. At first, this image is always the middle of the sequence.'''

        if epoch > e2:
            tgt_id = torch.randint(0, seq, (batch_size, ), device=device)
        else:
            tgt_id = torch.full_like(batch_range, args.sequence_length // 2)

        ref_ids = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)
        '''
        Select the other picture we are going to feed DepthNet; it must not be the same
        as tgt_id. At first, it is always the first picture of the sequence; it is chosen
        randomly once the first training milestone is reached.
        '''

        if epoch > e1:
            probs = torch.ones_like(total_indices, dtype=torch.float32)
            probs[batch_range, tgt_id] = args.same_ratio
            prior_id = torch.multinomial(probs, 1)[:, 0]
        else:
            prior_id = torch.zeros_like(batch_range)

        # Handle the case where prior_id == tgt_id: the depth must then be max_depth, regardless of apparent movement

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]

        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, H, W]
        depth = depth_net(input_pair)

        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

        predicted_magnitude = prior_poses[:, :,
                                          -1:].norm(p=2, dim=1,
                                                    keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :,
                                                   -1:] * scale_factor  # [B, seq, 3, 1]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)

        # Construct the valid sub-batch on which to compute the photometric error;
        # make the remaining (still) pairs converge to max_depth because nothing moved
        vb = batch_range[prior_id != tgt_id]
        same_range = batch_range[prior_id == tgt_id]  # batch of still pairs

        loss_1 = 0
        loss_1_same = 0
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale

            if len(same_range) > 0:
                # Frames are identical. The corresponding depth must be infinite. Here, we set it to max depth
                still_depth = scaled_depth[same_range]
                loss_same = F.smooth_l1_loss(still_depth / args.max_depth,
                                             torch.ones_like(still_depth))
            else:
                loss_same = 0

            loss_valid, *to_log = photometric_reconstruction_loss(
                imgs[vb],
                tgt_id[vb],
                ref_ids[vb],
                scaled_depth[vb],
                new_pose_matrices[vb],
                intrinsics[vb],
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)

            loss_1 += loss_valid * size_ratio
            loss_1_same += loss_same * size_ratio

            if log_output and len(vb) > 0:
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       scaled_depth[0], disparities[k][0],
                                       *to_log)
        loss_2 = grad_diffusion_loss(disparities, tgt_imgs, args.kappa)

        loss = w1 * (loss_1 + loss_1_same) + w2 * loss_2
        if args.supervise_pose:
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output and len(vb) > 0:
            valid_poses = poses[vb]
            nominal_translation_magnitude = valid_poses[:, -2, :3].norm(p=2,
                                                                        dim=-1)
            # Log the translation magnitude relative to the translation magnitude between the last and penultimate frames:
            # for a perfectly constant displacement magnitude, you should get ratios of 2, 3, 4 and so forth.
            # The last pose is always the identity and the penultimate translation magnitude is always 1, so there is no need to log them.
            for j in range(args.sequence_length - 2):
                trans_mag = valid_poses[:, j, :3].norm(p=2, dim=-1)
                tb_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO: log a better value; this is the magnitude of the (yaw, pitch, roll) vector, which is not a physical quantity
                rot_mag = valid_poses[:, j, 3:].norm(p=2, dim=-1)
                tb_writer.add_histogram('rot {}'.format(j),
                                        rot_mag.detach().cpu().numpy(), n_iter)

            tb_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
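
# `AverageMeter`, used throughout these loops, follows the familiar running
# average pattern; this variant can track several values at once (`i`) and
# keeps both the latest value and the running mean as lists. A minimal sketch
# consistent with the calls above (an approximation of the repo's helper):
class AverageMeter(object):
    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0] * self.meters
        self.avg = [0] * self.meters
        self.sum = [0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        self.val = list(val)
        self.count += n
        self.sum = [s + v * n for s, v in zip(self.sum, val)]
        self.avg = [s / self.count for s in self.sum]

    def __repr__(self):
        val = ' '.join('{:.{}f}'.format(v, self.precision) for v in self.val)
        avg = ' '.join('{:.{}f}'.format(a, self.precision) for a in self.avg)
        return '{} ({})'.format(val, avg)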