Example #1
def validate(val_loader,
             disp_net,
             pose_exp_net,
             epoch,
             logger,
             output_writers=[]):
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i // 100 < len(
                output_writers):  # log first output of every 100th batch
            index = i // 100
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1])[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)

    return losses.avg
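
Note: Example #1 targets PyTorch < 0.4, where `Variable(..., volatile=True)` disabled autograd and scalars were read with `loss.data[0]`. A minimal sketch of the same evaluation step in modern PyTorch (>= 0.4), assuming the same network interfaces (`compute_val_outputs` is a hypothetical helper name); `loss.data[0]` becomes `loss.item()`:

import torch

@torch.no_grad()  # replaces volatile=True: no autograd graph is built
def compute_val_outputs(disp_net, pose_exp_net, tgt_img, ref_imgs, device):
    # plain tensors replace the Variable wrappers
    tgt_img = tgt_img.to(device)
    ref_imgs = [img.to(device) for img in ref_imgs]
    disp = disp_net(tgt_img)
    depth = 1 / disp
    explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)
    return disp, depth, explainability_mask, pose
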
Example #2
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask,
                                                 pose, args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)

            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0],
                                     max_value=None,
                                     colormap='magma'), n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h

                    tgt_img_scaled = F.interpolate(tgt_img, (h, w),
                                                   mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]

                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                        dim=1)

                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref,
                            scaled_depth[:, 0],
                            pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]
                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1,
                                             colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
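
All examples rely on an `AverageMeter` helper that tracks one or several running averages (`i` values, printed with `precision` digits) and formats as `val (avg)`. A minimal sketch consistent with the calls above; the repo's actual implementation may differ in details:

class AverageMeter:
    """Tracks `i` running averages; prints as 'val (avg)'."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0] * self.meters
        self.avg = [0] * self.meters
        self.sum = [0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        assert len(val) == self.meters
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        val = ' '.join('{:.{}f}'.format(v, self.precision) for v in self.val)
        avg = ' '.join('{:.{}f}'.format(a, self.precision) for a in self.avg)
        return '{} ({})'.format(val, avg)
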
Example #3
def validate_without_gt(args,
                        val_loader,
                        depth_net,
                        pose_net,
                        epoch,
                        logger,
                        output_writers=[],
                        **env):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.valid_bar.update(0)

    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        mid_index = (args.sequence_length - 1) // 2

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        loss_1 = 0
        loss_2 = 0

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, H, W]

            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)

            depth = upsample_depth_net(input_pair)
            disparity = 1 / depth
            total_indices = torch.arange(seq).long().unsqueeze(0).expand(
                batch_size, seq).to(device)
            tgt_id = total_indices[:, mid_index]

            ref_indices_tensor = total_indices[
                total_indices != tgt_id.unsqueeze(1)].view(
                    batch_size, seq - 1)

            photo_loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_indices_tensor,
                depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += photo_loss

            if log_output:
                output_writers[i].add_image(
                    'val Dispnet Output Normalized {}'.format(ref_index),
                    tensor2array(disparity[0], max_value=None,
                                 colormap='bone'), epoch)
                output_writers[i].add_image(
                    'val Depth Output {}'.format(ref_index),
                    tensor2array(depth[0].cpu(), max_value=args.max_depth),
                    epoch)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    output_writers[i].add_image(
                        'val Warped Outputs {} {}'.format(j, ref_index),
                        tensor2array(warped[0]), epoch)
                    output_writers[i].add_image(
                        'val Diff Outputs {} {}'.format(j, ref_index),
                        tensor2array(diff[0].abs() - 1), epoch)

            loss_2 += texture_aware_smooth_loss(
                disparity, tgt_imgs if args.texture_loss else None)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            output_writers[0].add_histogram('val poses_{}'.format(coeff_name),
                                            poses_values[:, k], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
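
`compensate_pose` re-expresses every pose in the sequence relative to the target pose, so the target's own compensated pose becomes the identity (as the inline comments note). A hedged sketch of that operation, under the assumption that poses are `[B, seq, 3, 4]` rigid transforms:

import torch

def compensate_pose_sketch(pose_matrices, tgt_poses):
    # invert the target pose: [R | t] -> [R^T | -R^T t]
    inv_rot = tgt_poses[:, :, :3].transpose(1, 2)        # [B, 3, 3]
    inv_trans = -inv_rot @ tgt_poses[:, :, 3:]           # [B, 3, 1]
    # left-multiply every pose of the sequence by the inverted target pose
    new_rot = inv_rot.unsqueeze(1) @ pose_matrices[..., :3]
    new_trans = inv_rot.unsqueeze(1) @ pose_matrices[..., 3:] + inv_trans.unsqueeze(1)
    return torch.cat([new_rot, new_trans], dim=-1)       # [B, seq, 3, 4]

For the target index itself this yields R^T R = I and R^T (t - t) = 0, i.e. the neutral pose.
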
Example #4
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i // 100 < len(
                output_writers):  # log first output of every 100th batch
            index = i // 100
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
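
The `disp_values` buffer above has a fixed per-batch layout: for each batch of size B it stores B minima, then B medians, then B maxima of the flattened disparity maps. A quick self-contained check of that layout:

import torch

B = 4
disp = torch.rand(B, 1, 8, 8)
disp_unraveled = disp.view(B, -1)
stats = torch.cat([disp_unraveled.min(-1)[0],
                   disp_unraveled.median(-1)[0],
                   disp_unraveled.max(-1)[0]])
assert stats.shape == (3 * B,)  # [min_0..min_3, med_0..med_3, max_0..max_3]
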
Example #5
def train(train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger,
          train_writer):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # compute output
        disparities = disp_net(tgt_img_var)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_2 = explainability_loss(explainability_mask)
        loss_3 = smooth_loss(disparities)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        train_writer.add_scalar('photometric_error', loss_1.data[0], n_iter)
        train_writer.add_scalar('explanability_loss', loss_2.data[0], n_iter)
        train_writer.add_scalar('disparity_smoothness_loss', loss_3.data[0],
                                n_iter)
        train_writer.add_scalar('total_loss', loss.data[0], n_iter)

        if n_iter % 200 == 0 and args.log_output:

            train_writer.add_image('train Input', tensor2array(ref_imgs[0][0]),
                                   n_iter - 1)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image('train Input', tensor2array(ref_imgs[1][0]),
                                   n_iter + 1)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=10,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output Normalized {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=None), n_iter)
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img_var.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img_var, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs_var
                ]

                intrinsics_scaled = torch.cat(
                    (intrinsics_var[:, 0:2] / downscale, intrinsics_var[:,
                                                                        2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv_var[:, :, 0:2] * downscale,
                     intrinsics_inv_var[:, :, 2:]),
                    dim=2)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(ref, scaled_depth[:, 0],
                                              pose[:, j], intrinsics_scaled,
                                              intrinsics_scaled_inv)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu(), max_value=1),
                        n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    train_writer.add_image(
                        'train Exp mask Outputs {} {}'.format(k, j),
                        tensor2array(explainability_mask[k][0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.data[0], args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.data[0], loss_1.data[0], loss_2.data[0], loss_3.data[0]])
        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write(
                'Train: Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format(
                    batch_time=batch_time, data_time=data_time, loss=losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg
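
The intrinsics rescaling in the multi-scale loop above is the standard one: resizing an image by 1/downscale divides the first two rows of K by downscale, and equivalently multiplies the first two columns of K^-1. A small numeric check, with a made-up (hypothetical) camera matrix:

import torch

K = torch.tensor([[718.0, 0.0, 607.0],
                  [0.0, 718.0, 185.0],
                  [0.0, 0.0, 1.0]])
downscale = 2.0
K_scaled = torch.cat((K[0:2] / downscale, K[2:]), dim=0)
K_inv_scaled = torch.cat((K.inverse()[:, 0:2] * downscale,
                          K.inverse()[:, 2:]), dim=1)
# the two rescalings stay mutually inverse
assert torch.allclose(K_scaled @ K_inv_scaled, torch.eye(3), atol=1e-3)
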
Example #6
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, training_writer, **env):
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):

        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq).long().to(device).unsqueeze(
            0).expand(batch_size, seq)
        batch_range = torch.arange(batch_size).long().to(device)
        ''' For each element of the batch, select a random picture in the sequence for
        which we will compute the depth; all poses are then converted so that the pose of this
        very picture is exactly the identity. At first, this image is always the middle of the sequence. '''

        if epoch > e2:
            tgt_id = torch.floor(torch.rand(batch_size) *
                                 seq).long().to(device)
        else:
            tgt_id = torch.zeros(batch_size).long().to(
                device) + args.sequence_length // 2
        '''
        Select which other picture we are going to feed DepthNet; it must not be the same
        as tgt_id. At first it is always the first picture of the sequence; it is chosen
        randomly once the first training milestone is reached.
        '''

        ref_indices = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)

        if epoch > e1:
            prior_id = torch.floor(torch.rand(batch_size) *
                                   (seq - 1)).long().to(device)
        else:
            prior_id = torch.zeros(batch_size).long().to(device)
        prior_id = ref_indices[batch_range, prior_id]

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]

        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, H, W]
        depth = upsample_depth_net(input_pair)
        disparities = [1 / d for d in depth]

        predicted_magnitude = prior_poses[:, :,
                                          -1:].norm(p=2, dim=1,
                                                    keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :,
                                                   -1:] * scale_factor  # [B, seq_length-1, 3]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)

        loss_1 = 0
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale
            loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_indices,
                scaled_depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += loss * size_ratio

            if log_output:
                training_writer.add_image(
                    'train Dispnet Output Normalized scale {}'.format(k),
                    tensor2array(disparities[k][0],
                                 max_value=None,
                                 colormap='bone'), n_iter)
                training_writer.add_image(
                    'train Depth Output scale {}'.format(k),
                    tensor2array(scaled_depth[0], max_value=args.max_depth),
                    n_iter)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    training_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(warped[0]), n_iter)
                    training_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(diff.abs()[0] - 1), n_iter)

        loss_2 = texture_aware_smooth_loss(
            depth, tgt_imgs if args.texture_loss else None)

        loss = w1 * loss_1 + w2 * loss_2

        if args.supervise_pose:
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            training_writer.add_scalar('photometric_error', loss_1.item(),
                                       n_iter)
            training_writer.add_scalar('disparity_smoothness_loss',
                                       loss_2.item(), n_iter)
            training_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            nominal_translation_magnitude = poses[:, -2, :3].norm(p=2, dim=-1)
            # last pose is always identity and penultimate translation magnitude is always 1, so you don't need to log them
            for j in range(args.sequence_length - 2):
                trans_mag = poses[:, j, :3].norm(p=2, dim=-1)
                training_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO: log a better value; this is the magnitude of the (yaw, pitch, roll) vector, which is not a physical quantity
                rot_mag = poses[:, j, 3:].norm(p=2, dim=-1)
                training_writer.add_histogram('rot {}'.format(j),
                                              rot_mag.detach().cpu().numpy(),
                                              n_iter)

            training_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                      n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
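
The translation normalization above rescales every translation in the sequence so that the prior frame's displacement matches `args.nominal_displacement`; DepthNet's depth is only defined up to that displacement, so the poses must be brought to its scale. A quick check of the broadcasting, with hypothetical shapes:

import torch

nominal_displacement = 0.3
compensated_poses = torch.randn(2, 5, 3, 4)     # [B, seq, 3, 4]
prior_poses = compensated_poses[:, 0]           # pretend prior_id == 0
predicted_magnitude = prior_poses[:, :, -1:].norm(p=2, dim=1, keepdim=True).unsqueeze(1)
scale_factor = nominal_displacement / (predicted_magnitude + 1e-5)
normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor
# the prior translation now has (almost) exactly the nominal magnitude
print(normalized_translation[:, 0].norm(p=2, dim=1).squeeze())  # ~0.3 for each batch element
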
Example #7
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    validate_pbar = tqdm(
        total=len(val_loader),
        bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    validate_pbar.set_description(
        'valid: Loss *.**** *.**** *.****(*.**** *.**** *.****)')
    validate_pbar.set_postfix_str('<Time *.***(*.***)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp

        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log the first output of the first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped, diff, explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        validate_pbar.clear()
        validate_pbar.update(1)
        validate_pbar.set_description('valid: Loss {}'.format(losses))
        validate_pbar.set_postfix_str('<Time {}>'.format(batch_time))
    validate_pbar.close()

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
        time.sleep(0.2)
    else:
        time.sleep(1)
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
Example #8
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < sample_nb_to_log - 1:  # log the first output of the first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  # [B, 4, 4]
            # target = frame 1 (ref_imgs[0]), ref = frame 3 (ref_imgs[1])
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[:, 0], T13, intrinsics, args.rotation_mode,
                args.padding_mode)
            diff = (ref_imgs[0] -
                    ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff.abs().mean()

            loss += loss_4

        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
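
The photo-consistency branch in Example #8 composes two relative poses through homogeneous coordinates: padding each [3, 4] pose with a [0, 0, 0, 1] row makes it an invertible 4x4 transform, so the frame-1-to-frame-3 motion is obtained as T13 = T23 @ T21^-1. A compact sketch of that composition (assuming `pose_vec2mat` outputs [B, 3, 4] matrices):

import torch

def chain_relative_poses(T21, T23):
    # pad [B, 3, 4] poses into invertible [B, 4, 4] homogeneous transforms
    batch_size = T21.size(0)
    homo_row = torch.tensor([[[0., 0., 0., 1.]]]).expand(batch_size, 1, 4)
    T21h = torch.cat((T21, homo_row), dim=1)
    T23h = torch.cat((T23, homo_row), dim=1)
    return T23h @ torch.inverse(T21h)  # = T23 @ T12, as in the example
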
Example #9
def validate_without_gt(args, val_loader, depth_net, pose_net, epoch, logger,
                        tb_writer, sample_nb_to_log, **env):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)

    for i, sample in enumerate(val_loader):
        log_output = i < sample_nb_to_log

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(img[0]), j)

        batch_size, seq = imgs.size()[:2]
        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        mid_index = (args.sequence_length - 1) // 2

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_ids = list(range(args.sequence_length))
        ref_ids.remove(mid_index)

        loss_1 = 0
        loss_2 = 0

        for ref_index in ref_ids:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, H, W]

            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)

            depth = depth_net(input_pair)
            disparity = 1 / depth

            tgt_id = torch.full((batch_size, ),
                                ref_index,
                                dtype=torch.int64,
                                device=device)
            ref_ids_tensor = torch.tensor(ref_ids,
                                          dtype=torch.int64,
                                          device=device).expand(
                                              batch_size, -1)
            photo_loss, *to_log = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_ids_tensor,
                depth,
                new_pose_matrices,
                intrinsics,
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)

            loss_1 += photo_loss

            if log_output:
                log_output_tensorboard(tb_writer, "train", i, ref_index, epoch,
                                       depth[0], disparity[0], *to_log)

            loss_2 += grad_diffusion_loss(disparity, tgt_imgs, args.kappa)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            tb_writer.add_histogram('val poses_{}'.format(coeff_name),
                                    poses_values[:, k], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
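
`pose_vec2mat` turns each 6-DoF vector [tx, ty, tz, rx, ry, rz] into a [3, 4] transform. A hedged sketch for the euler mode only (the repo also supports quaternions, and may compose the axis rotations in a different order):

import torch

def pose_vec2mat_euler_sketch(vec):
    t = vec[..., :3].unsqueeze(-1)                       # [..., 3, 1]
    rx, ry, rz = vec[..., 3], vec[..., 4], vec[..., 5]
    zeros, ones = torch.zeros_like(rx), torch.ones_like(rx)
    cx, sx = rx.cos(), rx.sin()
    cy, sy = ry.cos(), ry.sin()
    cz, sz = rz.cos(), rz.sin()
    Rx = torch.stack([ones, zeros, zeros,
                      zeros, cx, -sx,
                      zeros, sx, cx], dim=-1).view(*rx.shape, 3, 3)
    Ry = torch.stack([cy, zeros, sy,
                      zeros, ones, zeros,
                      -sy, zeros, cy], dim=-1).view(*ry.shape, 3, 3)
    Rz = torch.stack([cz, -sz, zeros,
                      sz, cz, zeros,
                      zeros, zeros, ones], dim=-1).view(*rz.shape, 3, 3)
    return torch.cat([Rz @ Ry @ Rx, t], dim=-1)          # [..., 3, 4]
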
Example #10
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()

    train_pbar = tqdm(total=min(len(train_loader), args.epoch_size),
                      bar_format='{desc} {percentage:3.0f}%|{bar}| {postfix}')
    train_pbar.set_description('Train: Total Loss=#.####(#.####)')
    train_pbar.set_postfix_str('<TIME: op=#.###(#.###) DataFlow=#.###(#.###)>')

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure DataFlow loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        train_pbar.clear()
        train_pbar.update(1)
        train_pbar.set_description('Train: Total Loss={}'.format(losses))
        train_pbar.set_postfix_str('<TIME: op={} DataFlow={}>'.format(
            batch_time, data_time))
        if i >= epoch_size - 1:
            break

        n_iter += 1
    train_pbar.close()
    time.sleep(1)
    return losses.avg[0]
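
`smooth_loss` penalizes spatial gradients of the predicted disparity/depth maps, with coarser scales weighted down. A hedged sketch in the spirit of those calls; the repo's exact gradient terms and scale weighting may differ:

import torch

def smooth_loss_sketch(pred_maps, weight_decay=2.3):
    if not isinstance(pred_maps, (list, tuple)):
        pred_maps = [pred_maps]
    loss, weight = 0., 1.
    for m in pred_maps:
        # first-order gradients along width (dx) and height (dy)
        dx = m[..., :, 1:] - m[..., :, :-1]
        dy = m[..., 1:, :] - m[..., :-1, :]
        # second-order terms keep the maps piecewise-smooth
        dx2 = (dx[..., :, 1:] - dx[..., :, :-1]).abs().mean()
        dy2 = (dy[..., 1:, :] - dy[..., :-1, :]).abs().mean()
        dxdy = (dx[..., 1:, :] - dx[..., :-1, :]).abs().mean()
        loss += weight * (dx2 + 2 * dxdy + dy2)
        weight /= weight_decay  # down-weight the later (assumed coarser) scales
    return loss
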
Example #11
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, tb_writer, **env):
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):

        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
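            # imgs is [B, seq, 3, H, W]; F.interpolate treats 5D input as
            # [N, C, D, H, W], so the target size (3, h, w) leaves the color
            # dimension untouched while area-downsampling H and W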
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq, dtype=torch.int64,
                                     device=device).expand(batch_size, seq)
        batch_range = torch.arange(batch_size,
                                   dtype=torch.int64,
                                   device=device)
        '''For each element of the batch, select one picture of the sequence for
        which the depth will be computed; all poses are then converted so that the
        pose of this very picture is exactly the identity. Until the second training
        milestone is reached, this image is always the middle of the sequence.'''

        if epoch > e2:
            tgt_id = torch.randint(0, seq, (batch_size, ), device=device)
        else:
            tgt_id = torch.full_like(batch_range, args.sequence_length // 2)

        ref_ids = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)
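        # e.g. with seq = 5 and tgt_id[b] = 2, ref_ids[b] = [0, 1, 3, 4]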
        '''
        Select the other picture that will be fed to DepthNet together with the
        target. Until the first training milestone is reached it is always the first
        picture of the sequence; afterwards it is chosen at random (the still-pair
        case prior_id == tgt_id is weighted by args.same_ratio and handled below).
        '''

        if epoch > e1:
            probs = torch.ones_like(total_indices, dtype=torch.float32)
            probs[batch_range, tgt_id] = args.same_ratio
            prior_id = torch.multinomial(probs, 1)[:, 0]
        else:
            prior_id = torch.zeros_like(batch_range)

        # Handle the case prior_id == tgt_id: the depth must then be max_depth, regardless of apparent movement

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]

        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
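        # compensate_pose presumably composes each pose with the inverse target
        # pose, compensated_i = inv(T_tgt) @ T_i, so that compensated_tgt is the
        # identity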
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, H, W]
        depth = depth_net(input_pair)

        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

        predicted_magnitude = prior_poses[:, :, -1:].norm(
            p=2, dim=1, keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor  # [B, seq, 3, 1]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)
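        # Scale normalization (as I read the intent): rescale all translations so
        # that the prior->target displacement has magnitude
        # args.nominal_displacement, pinning down the otherwise arbitrary metric
        # scale shared by poses and predicted depth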

        biggest_scale = depth[0].size(-1)

        # Construct valid sequence to compute photometric error,
        # make the rest converge to max_depth because nothing moved
        vb = batch_range[prior_id != tgt_id]
        same_range = batch_range[prior_id == tgt_id]  # batch of still pairs
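        # e.g. with tgt_id = [2, 1] and prior_id = [2, 0], same_range = [0]
        # (still pair) and vb = [1] (valid pair)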

        loss_1 = 0
        loss_1_same = 0
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale

            if len(same_range) > 0:
                # Frames are identical, so the true depth is infinite; here we push it toward max_depth
                still_depth = scaled_depth[same_range]
                loss_same = F.smooth_l1_loss(still_depth / args.max_depth,
                                             torch.ones_like(still_depth))
            else:
                loss_same = 0

            loss_valid, *to_log = photometric_reconstruction_loss(
                imgs[vb],
                tgt_id[vb],
                ref_ids[vb],
                scaled_depth[vb],
                new_pose_matrices[vb],
                intrinsics[vb],
                args.rotation_mode,
                ssim_weight=w3,
                upsample=args.upscale)

            loss_1 += loss_valid * size_ratio
            loss_1_same += loss_same * size_ratio

            if log_output and len(vb) > 0:
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       scaled_depth[0], disparities[k][0],
                                       *to_log)
        loss_2 = grad_diffusion_loss(disparities, tgt_imgs, args.kappa)

        loss = w1 * (loss_1 + loss_1_same) + w2 * loss_2
        if args.supervise_pose:
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output and len(vb) > 0:
            valid_poses = poses[vb]
            nominal_translation_magnitude = valid_poses[:, -2, :3].norm(
                p=2, dim=-1)
            # Log translation magnitudes relative to the translation between the
            # penultimate and last frames: for a perfectly constant displacement
            # magnitude, the ratios should be 2, 3, 4, and so on. The last pose is
            # always the identity and the penultimate translation magnitude is
            # always 1, so neither needs to be logged.
            for j in range(args.sequence_length - 2):
                trans_mag = valid_poses[:, j, :3].norm(p=2, dim=-1)
                tb_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO log a better value : this is magnitude of vector (yaw, pitch, roll) which is not a physical value
                rot_mag = valid_poses[:, j, 3:].norm(p=2, dim=-1)
                tb_writer.add_histogram('rot {}'.format(j),
                                        rot_mag.detach().cpu().numpy(), n_iter)

            tb_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
Example #12
def train(odometry_net, depth_net, feat_extractor, train_loader, epoch,
          optimizer):
    global device
    global data_parallel
    if data_parallel:
        odometry_net.module.set_fix_method(nfp.FIX_AUTO)
    else:
        odometry_net.set_fix_method(nfp.FIX_AUTO)
    odometry_net.train()
    depth_net.train()
    feat_extractor.train()
    total_loss = 0
    img_reconstruction_total = 0
    f_reconstruction_total = 0
    smooth_total = 0
    for batch_idx, (img_R1, img_L2, img_R2, intrinsics, inv_intrinsics, raw_K,
                    T_R2L) in tqdm(enumerate(train_loader),
                                   desc='Train epoch %d' % epoch,
                                   leave=False,
                                   ncols=80):
        img_R1 = img_R1.type(torch.FloatTensor).to(device)
        img_R2 = img_R2.type(torch.FloatTensor).to(device)
        img_L2 = img_L2.type(torch.FloatTensor).to(device)
        intrinsics = intrinsics.type(torch.FloatTensor).to(device)
        inv_intrinsics = inv_intrinsics.type(torch.FloatTensor).to(device)
        raw_K = raw_K.type(torch.FloatTensor).to(device)
        T_R2L = T_R2L.type(torch.FloatTensor).to(device)

        img_R = torch.cat((img_R2, img_R1), dim=1)

        inv_depth_img_R2 = depth_net(img_R2)
        T_2to1, _ = odometry_net(img_R)
        T_2to1 = T_2to1.view(T_2to1.size(0), -1)
        T_R2L = T_R2L.view(T_R2L.size(0), -1)

        depth = (1 / (inv_depth_img_R2 + 1e-4)).squeeze(1)
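        # the 1e-4 offset presumably avoids division by zero and caps depth at 1e4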

        img_reconstruction_error = photometric_reconstruction_loss(
            0.004 * img_R2, 0.004 * img_R1, 0.004 * img_L2, depth, T_2to1,
            T_R2L, intrinsics, inv_intrinsics)
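        # the 0.004 factor above is roughly 1/255, presumably rescaling 8-bit
        # image intensities to approximately [0, 1] before the photometric loss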
        smooth_error = smooth_loss(depth.unsqueeze(1))

        imgs = torch.cat((img_L2, img_R2, img_R1), dim=0)
        feat = feat_extractor(imgs)
        batch_size = img_R1.size(0)
        f_L2 = feat[:batch_size]
        f_R2 = feat[batch_size:2 * batch_size]
        f_R1 = feat[2 * batch_size:]

        feat_reconstruction_error = photometric_reconstruction_loss(
            f_R2, f_R1, f_L2, depth, T_2to1, T_R2L, intrinsics, inv_intrinsics)

        loss = img_reconstruction_error + 0.1 * feat_reconstruction_error + 10 * smooth_error

        total_loss += loss.item()
        img_reconstruction_total += img_reconstruction_error.item()
        f_reconstruction_total += feat_reconstruction_error.item()
        smooth_total += smooth_error.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(
        "Train epoch {}: loss: {:.9f} img-recon-loss: {:.9f} f-recon-loss: {:.9f} smooth-loss: {:.9f}"
        .format(epoch, total_loss / len(train_loader),
                img_reconstruction_total / len(train_loader),
                f_reconstruction_total / len(train_loader),
                smooth_total / len(train_loader)))
Example #13
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)
    # main training loop
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        #for (i, data) in enumerate(train_loader):#data(list): [tensor(B,3,H,W),list(B),(B,H,W),(b,h,w)]
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        #1 measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)  # (4, 3, 128, 416)
        ref_imgs = [img.to(device) for img in ref_imgs]  # previous and next frame for each of the batch_size images
        intrinsics = intrinsics.to(device)  # (4, 3, 3)
        """forward and loss"""
        #2 compute output
        disparities = disp_net(
            tgt_img
        )  # multi-scale list of tensors: (4,1,128,416), (4,1,64,208), (4,1,32,104), (4,1,16,52)

        explainability_mask, pose = pose_exp_net(
            tgt_img,
            ref_imgs)  # pose: tensor (batch_size, sequence_length-1, 6), relative camera poses

        depth = [1 / disp
                 for disp in disparities]  # depth = f*T/disparity (inversely proportional), so simply take the reciprocal

        #3 compute losses
        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        #4. Per-batch data recording to tensorboard; scalar names need no prior initialization (they are created automatically), values can be added directly
        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:  # write data into tensorboard-readable event files (file names start with "events" by default)
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        #csv record
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)

        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #14
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = inverse_depth_smooth_loss(depth, tgt_img)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #15
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        sample_nb_to_log=3):
    global device
    mse_l = torch.nn.MSELoss(reduction='mean')
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    # Output the logs throughout the whole dataset
    batches_to_log = list(
        np.linspace(0, len(val_loader) - 1, sample_nb_to_log).astype(int))
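    # e.g. with len(val_loader) == 100 and sample_nb_to_log == 3,
    # batches_to_log == [0, 49, 99]: evenly spread over the whole epoch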
    w1, w2, w3, wf, wp = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_loss_weight, args.prior_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            flow_maps) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)
        flow_maps = [flow_map.to(device) for flow_map in flow_maps]

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff, grid = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()

        if wf > 0:
            loss_f = flow_consistency_loss(grid, flow_maps, mse_l)
        else:
            loss_f = 0

        if wp > 0:
            loss_p = ground_prior_loss(disp)
        else:
            loss_p = 0

        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i in batches_to_log:  # log first output of the selected batches
            index = batches_to_log.index(i)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val Input {}/{}'.format(j, index),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', index, '', epoch,
                                   1. / disp, disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()
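            # torch.cat stacks the three [batch_size] tensors block-wise, so this
            # slice stores all mins, then all medians, then all maxes for the batch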

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + wf * loss_f + wp * loss_p
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Validation Total loss', 'Validation Photo loss', 'Validation Exp loss'
    ]
Example #16
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        #         print("***",len(depth),depth[0].size())
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if args.with_photocon_loss:
            batch_size = pose.size()[0]
            homo_row = torch.tensor([[0, 0, 0, 1]],
                                    dtype=torch.float).to(device)
            homo_row = homo_row.unsqueeze(0).expand(batch_size, -1, -1)
            T21 = pose_vec2mat(pose[:, 0])
            T21 = torch.cat((T21, homo_row), 1)
            T12 = torch.inverse(T21)
            T23 = pose_vec2mat(pose[:, 1])
            T23 = torch.cat((T23, homo_row), 1)
            T13 = torch.matmul(T23, T12)  #[B, 4, 4]
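            # compose relative poses: T13 = T23 @ inv(T21) relates the two
            # reference frames directly (frame 2 is the target), so frame 3 can be
            # warped toward frame 1 for an extra photometric-consistency term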
            #             print("----",T13.size())
            # target = 1 and ref = 3
            ref_img_warped, valid_points = inverse_warp_posemat(
                ref_imgs[1], depth[0][:, 0], T13, intrinsics,
                args.rotation_mode, args.padding_mode)
            diff_photocon = (ref_imgs[0] -
                             ref_img_warped) * valid_points.unsqueeze(1).float()
            loss_4 = diff_photocon.abs().mean()

            loss += loss_4

        if log_losses:
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]