Example #1
def train(args, train_loader, mvdnet, depth_cons, cons_loss_, optimizer,
          epoch_size, train_writer, epoch):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)
    cons_losses = AverageMeter(precision=4)

    # switch to training mode
    if args.train_cons:
        depth_cons.train()
    else:
        mvdnet.train()

    print("Training")
    end = time.time()

    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda())

        # compute output
        pose = torch.cat(ref_poses_var, 1)

        # mask out depths outside [mindepth, nlabel * mindepth] and NaNs
        # (tgt_depth_var == tgt_depth_var is False exactly at NaN entries)
        mask = (tgt_depth_var <= args.nlabel * args.mindepth) & (
            tgt_depth_var >= args.mindepth) & (tgt_depth_var == tgt_depth_var)
        mask.detach_()
        if not mask.any():
            continue

        if args.train_cons:
            with torch.no_grad():
                outputs = mvdnet(tgt_img_var, ref_imgs_var, pose,
                                 intrinsics_var, intrinsics_inv_var)
                output_depth1 = outputs[0]
                nmap1 = outputs[1]
        else:
            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            output_depth1 = outputs[1]
            nmap1 = outputs[2]

        if args.train_cons:
            outputs = depth_cons(output_depth1, nmap1)
            nmap = outputs[:, 1:]
            depths = [outputs[:, 0]]
        else:
            nmap = nmap1.permute(0, 3, 1, 2)
            depths = [output_depth1.squeeze(1)]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        cons_loss = 0.

        for l, depth in enumerate(depths):
            output = torch.squeeze(depth, 1)
            d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                               tgt_depth_var[mask])

        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask],
                                                 gt_nmap_var[n_mask])

        if args.train_cons:
            cons_loss = cons_loss + cons_loss_(
                depths[-1].unsqueeze(1), tgt_depth_var.unsqueeze(1),
                nmap.clone(), intrinsics_var, mask.unsqueeze(1))
            cons_losses.update(cons_loss.item(), args.batch_size)
        loss = (loss + args.d_weight * d_loss + args.n_weight * nmap_loss +
                args.c_weight * cons_loss)

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} ConsLoss {} Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        cons_losses, i, len(train_loader), epoch, args.epochs))

        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
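All of these examples rely on the same AverageMeter bookkeeping helper: it is constructed with an optional count of tracked values (i) and a print precision, updated with a scalar or a list plus a batch size, and formatted as "val (avg)" when printed. A minimal sketch consistent with those call sites; the exact class in each repo may differ in details:

class AverageMeter(object):
    """Tracks i running averages; prints as 'val (avg)' with fixed precision."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0.0] * self.meters
        self.avg = [0.0] * self.meters
        self.sum = [0.0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        assert len(val) == self.meters
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        fmt = '{:.' + str(self.precision) + 'f}'
        val = ' '.join(fmt.format(v) for v in self.val)
        avg = ' '.join(fmt.format(a) for a in self.avg)
        return '{} ({})'.format(val, avg)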
Example #2
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        tgt_depth = [1 / disp_net(tgt_img)]
        ref_depths = []
        for ref_img in ref_imgs:
            ref_depth = [1 / disp_net(ref_img)]
            ref_depths.append(ref_depth)

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(1 / tgt_depth[0][0],
                             max_value=None,
                             colormap='magma'), epoch)
            output_writers[i].add_image(
                'val Depth Output', tensor2array(tgt_depth[0][0],
                                                 max_value=10), epoch)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs,
                                                 intrinsics)

        loss_1, loss_3 = compute_photo_and_geometry_loss(
            tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
            poses_inv, args.num_scales, args.with_ssim, args.with_mask, False,
            args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss'
    ]
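compute_pose_with_inv is a helper from the surrounding repo and is not shown here. A hedged sketch matching how it is called above, assuming the pose network maps an (image, image) pair to a 6-DoF pose; the intrinsics argument is accepted only for call compatibility, though a real implementation may use it:

def compute_pose_with_inv(pose_net, tgt_img, ref_imgs, intrinsics=None):
    # one forward pose (tgt -> ref) and one inverse pose (ref -> tgt)
    # per reference view
    poses, poses_inv = [], []
    for ref_img in ref_imgs:
        poses.append(pose_net(tgt_img, ref_img))
        poses_inv.append(pose_net(ref_img, tgt_img))
    return poses, poses_inv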
Example #3
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer, n_iter, torch_device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.gt_pose_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv,
            pose_gt) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(torch_device)
        ref_imgs = [img.to(torch_device) for img in ref_imgs]
        tgt_lf = tgt_lf.to(torch_device)
        ref_lfs = [lf.to(torch_device) for lf in ref_lfs]
        intrinsics = intrinsics.to(torch_device)
        pose_gt = pose_gt.to(torch_device)

        # compute output
        disparities = disp_net(tgt_lf)
        depth = [1 / disp for disp in disparities]

        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)

        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        pred_pose_magnitude = pose[:, :, :3].norm(dim=2)
        pose_gt_magnitude = pose_gt[:, :, :3].norm(dim=2)
        pose_loss = (pred_pose_magnitude - pose_gt_magnitude).abs().mean()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * pose_loss

        if log_losses:
            tb_writer.add_scalar('train/photometric_error', loss_1.item(),
                                 n_iter)
            tb_writer.add_scalar('train/smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('train/total_loss', loss.item(), n_iter)
            tb_writer.add_scalar('train/pose_loss', pose_loss.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('train/explanability_loss', loss_2.item(),
                                     n_iter)

        if log_output:
            tb_writer.add_image('train/Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       *scaled_maps)
                break

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path + "/" + args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
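smooth_loss above is a multi-scale smoothness regularizer on the predicted depth maps. A sketch under that assumption, penalizing second-order depth gradients with a weight that decays at each coarser scale (the exact gradient order and decay in the original repo may differ):

def smooth_loss(pred_maps):
    # accept a single map or a list of multi-scale maps
    if not isinstance(pred_maps, (list, tuple)):
        pred_maps = [pred_maps]

    def gradients(m):
        dy = m[..., 1:, :] - m[..., :-1, :]
        dx = m[..., :, 1:] - m[..., :, :-1]
        return dx, dy

    loss, weight = 0.0, 1.0
    for scaled_map in pred_maps:
        dx, dy = gradients(scaled_map)
        dx2, dxdy = gradients(dx)
        dydx, dy2 = gradients(dy)
        loss += weight * (dx2.abs().mean() + dxdy.abs().mean() +
                          dydx.abs().mean() + dy2.abs().mean())
        weight /= 2.3  # damp the penalty at coarser scales
    return loss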
Example #4
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
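tensor2array converts a tensor into a CHW float array that add_image can display, colorizing single-channel maps. A minimal sketch consistent with the max_value and colormap keywords used above; the repo versions handle more input layouts:

import numpy as np
from matplotlib import cm

def tensor2array(tensor, max_value=None, colormap='rainbow'):
    tensor = tensor.cpu()
    if tensor.ndimension() == 2 or tensor.size(0) == 1:
        # single-channel map: normalize to [0, 1] and apply a colormap
        if max_value is None:
            max_value = tensor.max().item()
        norm = (tensor.squeeze().numpy() / max_value).clip(0, 1)
        array = cm.get_cmap(colormap)(norm)[:, :, :3]  # HxWx3, RGB in [0, 1]
        array = array.transpose(2, 0, 1)               # to CHW for add_image
    else:
        # 3-channel image: undo the mean=0.5/std=0.5 normalization
        array = 0.5 + tensor.numpy() * 0.5
    return array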
Example #5
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
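compute_all_epes aggregates several end-point-error (EPE) variants over the rigid, non-rigid, and combined flows. For reference, a hedged sketch of a single EPE term (not the repo's function): the mean L2 distance between predicted and ground-truth flow vectors, optionally restricted to a validity mask:

import torch

def flow_epe(flow_pred, flow_gt, valid_mask=None):
    # flows are [B, 2, H, W]; EPE is the per-pixel L2 error of the flow
    epe_map = (flow_pred - flow_gt).norm(p=2, dim=1)  # [B, H, W]
    if valid_mask is not None:
        epe_map = epe_map[valid_mask > 0.5]  # mask assumed shaped [B, H, W]
    return epe_map.mean()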
Example #6
def train(train_loader,
          alice_net,
          bob_net,
          mod_net,
          optimizer,
          epoch_size,
          logger=None,
          train_writer=None,
          mode='compete'):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    alice_net.train()
    bob_net.train()
    mod_net.train()

    end = time.time()

    for i, (img, target) in enumerate(train_loader):
        # measure data loading time
        #mode = 'compete' if (i%2)==0 else 'collaborate'

        data_time.update(time.time() - end)
        img_var = Variable(img.cuda())
        target_var = Variable(target.cuda())

        pred_alice = alice_net(img_var)
        pred_bob = bob_net(img_var)
        pred_mod = mod_net(img_var)

        loss_alice = F.cross_entropy(pred_alice, target_var, reduce=False)
        loss_bob = F.cross_entropy(pred_bob, target_var, reduce=False)

        if mode == 'compete':
            if args.fix_bob:
                if args.DEBUG: print("Training Alice Only")
                loss = loss_alice.mean()
            elif args.fix_alice:
                loss = loss_bob.mean()
            else:
                if args.DEBUG: print("Training Both Alice and Bob")

                pred_mod_soft = Variable(F.sigmoid(pred_mod).data,
                                         requires_grad=False)
                loss = pred_mod_soft * loss_alice + (1 -
                                                     pred_mod_soft) * loss_bob

                loss = loss.mean()

        elif mode == 'collaborate':
            loss_alice2 = Variable(loss_alice.data, requires_grad=False)
            loss_bob2 = Variable(loss_bob.data, requires_grad=False)

            loss1 = F.sigmoid(pred_mod) * loss_alice2 + (
                1 - F.sigmoid(pred_mod)) * loss_bob2

            loss2 = collaboration_loss(pred_mod, loss_alice2, loss_bob2)

            loss = (loss1.mean() + loss2.mean() +
                    args.wr * mod_regularization_loss(pred_mod))

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('loss_alice',
                                    loss_alice.mean().item(), n_iter)
            train_writer.add_scalar('loss_bob', loss_bob.mean().item(), n_iter)
            train_writer.add_scalar('mod_mean',
                                    F.sigmoid(pred_mod).mean().item(), n_iter)
            train_writer.add_scalar('mod_var',
                                    F.sigmoid(pred_mod).var().item(), n_iter)
            train_writer.add_scalar('loss_regularization',
                                    mod_regularization_loss(pred_mod).item(),
                                    n_iter)

            if mode == 'compete':
                train_writer.add_scalar('competetion_loss', loss.item(),
                                        n_iter)
            elif mode == 'collaborate':
                train_writer.add_scalar('collaboration_loss', loss.item(),
                                        n_iter)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_alice.mean().item(),
                loss_bob.mean().item()
            ])
        if args.log_terminal:
            logger.train_bar.update(i + 1)
            if i % args.print_freq == 0:
                logger.train_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
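mod_regularization_loss and collaboration_loss come from the competitive-collaboration training code and are not shown. Purely as a hypothetical illustration of the regularizer's role (keeping the moderator's mixing weights away from a degenerate all-Alice or all-Bob assignment), one plausible form:

import torch.nn.functional as F

def mod_regularization_loss(pred_mod):
    # hypothetical: penalize moderator outputs far from 0.5 so that both
    # Alice and Bob keep receiving gradient signal
    p = F.sigmoid(pred_mod)  # matching the (deprecated) F.sigmoid used above
    return ((p - 0.5) ** 2).mean()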
Example #7
def validate_with_gt(args, val_loader, mvdnet, epoch, output_writers=[]):
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    log_outputs = len(output_writers) > 0

    output_dir = Path(args.output_dir)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    # switch to evaluate mode
    mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)

            if (pose != pose).any():  # skip samples whose poses contain NaNs
                continue

            if args.dataset == 'sceneflow':
                factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
                factor = factor.view(-1, 1, 1)
            else:
                factor = torch.ones(
                    (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

            # get mask
            mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor *
                    3) & (tgt_depth_var >= args.mindepth * factor) & (
                        tgt_depth_var == tgt_depth_var)

            if not mask.any():
                continue

            output_depth, nmap = mvdnet(tgt_img_var,
                                        ref_imgs_var,
                                        pose,
                                        intrinsics_var,
                                        intrinsics_inv_var,
                                        factor=factor.unsqueeze(1))
            output_disp = args.nlabel * args.mindepth / (output_depth)
            if args.dataset == 'sceneflow':
                output_disp = (args.nlabel *
                               args.mindepth) * 3 / (output_depth)
                output_depth = (args.nlabel * 3) * (args.mindepth *
                                                    factor) / output_disp

            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                            tgt_depth_var)

            if args.dataset == 'sceneflow':
                output = torch.squeeze(output_disp.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_disp_var.cpu(), output,
                                               mask)
                test_errors_ = list(
                    compute_errors_test(tgt_disp_var.cpu()[mask],
                                        output[mask]))
            else:
                output = torch.squeeze(output_depth.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_depth, output, mask)
                test_errors_ = list(
                    compute_errors_test(tgt_depth[mask], output[mask]))

            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]
            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])

            mask_angles = total_angles_m[n_mask]
            total_angles_m[~n_mask] = 0
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())
            errors.update(errors_)
            test_errors.update(test_errors_)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if args.output_print:
                np.save(output_dir / '{:04d}{}'.format(i, '_depth.npy'),
                        output.numpy()[0])
                plt.imsave(output_dir / '{:04d}_gt{}'.format(i, '.png'),
                           tgt_depth.numpy()[0],
                           cmap='rainbow')
                imsave(output_dir / '{:04d}_aimage{}'.format(i, '.png'),
                       np.transpose(tgt_img.numpy()[0], (1, 2, 0)))
                np.save(output_dir / '{:04d}_cam{}'.format(i, '.npy'),
                        intrinsics_var.cpu().numpy()[0])
                np.save(output_dir / '{:04d}{}'.format(i, '_normal.npy'),
                        nmap.cpu().numpy()[0])

            if i % args.print_freq == 0:
                print(
                    'valid: Time {} Abs Error {:.4f} ({:.4f}) Abs angle Error {:.4f} ({:.4f}) Iter {}/{}'
                    .format(batch_time, test_errors.val[0], test_errors.avg[0],
                            test_errors.val[-1], test_errors.avg[-1], i,
                            len(val_loader)))
    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / args.ttype + 'angle_errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
Example #8
def train(train_loader, model, optimizer, epoch, args, log, mp=None):
    '''train given model and dataloader'''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()

        unary = None
        noise = None
        adv_mask1 = 0
        adv_mask2 = 0

        # train with clean images
        if args.train == 'vanilla':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var, target_var)
            loss = bce_loss(softmax(output), reweighted_target)

        # train with mixup images
        elif args.train == 'mixup':
            # process for Puzzle Mix
            if args.graph:
                # whether to add adversarial noise or not
                if args.adv_p > 0:
                    adv_mask1 = np.random.binomial(n=1, p=args.adv_p)
                    adv_mask2 = np.random.binomial(n=1, p=args.adv_p)
                else:
                    adv_mask1 = 0
                    adv_mask2 = 0

                # random start
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise = torch.zeros_like(input).uniform_(
                        -args.adv_eps / 255., args.adv_eps / 255.)
                    input_orig = input * args.std + args.mean
                    input_noise = input_orig + noise
                    input_noise = torch.clamp(input_noise, 0, 1)
                    noise = input_noise - input_orig
                    input_noise = (input_noise - args.mean) / args.std
                    input_var = Variable(input_noise, requires_grad=True)
                else:
                    input_var = Variable(input, requires_grad=True)
                target_var = Variable(target)

                # calculate saliency (unary)
                if args.clean_lam == 0:
                    model.eval()
                    output = model(input_var)
                    loss_batch = criterion_batch(output, target_var)
                else:
                    model.train()
                    output = model(input_var)
                    loss_batch = 2 * args.clean_lam * criterion_batch(
                        output, target_var) / args.num_classes

                loss_batch_mean = torch.mean(loss_batch, dim=0)
                loss_batch_mean.backward(retain_graph=True)

                unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

                # calculate adversarial noise
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise += (args.adv_eps + 2) / 255. * input_var.grad.sign()
                    noise = torch.clamp(noise, -args.adv_eps / 255.,
                                        args.adv_eps / 255.)
                    adv_mix_coef = np.random.uniform(0, 1)
                    noise = adv_mix_coef * noise

                if args.clean_lam == 0:
                    model.train()
                    optimizer.zero_grad()

            input_var, target_var = Variable(input), Variable(target)
            # perform mixup and calculate loss
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup=True,
                                              args=args,
                                              grad=unary,
                                              noise=noise,
                                              adv_mask1=adv_mask1,
                                              adv_mask2=adv_mask2,
                                              mp=mp)
            loss = bce_loss(softmax(output), reweighted_target)

        # for manifold mixup
        elif args.train == 'mixup_hidden':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup_hidden=True,
                                              args=args)
            loss = bce_loss(softmax(output), reweighted_target)
        else:
            raise AssertionError('wrong train type!!')

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
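bce_loss and softmax are globals in this training script, paired as a cross-entropy over the soft, reweighted targets returned by the mixup forward pass. A hedged sketch of one plausible pairing, assuming each reweighted_target row holds per-class mixing weights:

import torch
import torch.nn as nn

softmax = nn.Softmax(dim=1)

def bce_loss(probs, soft_targets, eps=1e-8):
    # per-class binary cross-entropy against soft targets, summed over
    # classes and averaged over the batch (hypothetical reconstruction)
    probs = probs.clamp(eps, 1 - eps)
    return -(soft_targets * probs.log() +
             (1 - soft_targets) * (1 - probs).log()).sum(dim=1).mean()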
Example #9
def validate(val_loader,
             model,
             log,
             fgsm=False,
             eps=4,
             rand_init=False,
             mean=None,
             std=None):
    '''evaluate trained model'''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(val_loader):
        if args.use_cuda:
            input = input.cuda()
            target = target.cuda()

        # check FGSM for adversarial training
        if fgsm:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)

            optimizer_input = torch.optim.SGD([input_var], lr=0.1)
            output = model(input_var)
            loss = criterion(output, target_var)
            optimizer_input.zero_grad()
            loss.backward()

            sign_data_grad = input_var.grad.sign()
            input = input * std + mean + eps / 255. * sign_data_grad
            input = torch.clamp(input, 0, 1)
            input = (input - mean) / std

        with torch.no_grad():
            input_var = Variable(input)
            target_var = Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

    if fgsm:
        print_log(
            'Attack (eps : {}) Prec@1 {top1.avg:.2f}'.format(eps, top1=top1),
            log)
    else:
        print_log(
            '  **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '
            .format(top1=top1, top5=top5, error1=100 - top1.avg,
                    losses=losses), log)
    return top1.avg, losses.avg
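The FGSM branch above perturbs in un-normalized pixel space: it un-normalizes the input, takes one signed-gradient step of eps/255, clamps to the valid [0, 1] range, then re-normalizes. That core step in isolation, as a sketch:

def fgsm_step(x_pixel, grad, eps):
    # x_pixel in [0, 1], eps in 0-255 pixel units; one signed-gradient
    # ascent step on the loss, then re-clip to the valid image range
    return (x_pixel + eps / 255. * grad.sign()).clamp(0, 1)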
Example #10
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                output_writers[i].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='bone'), epoch)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp[0], max_value=None, colormap='bone'),
                epoch)
            output_writers[i].add_image(
                'val Depth Output', tensor2array(output_depth[0], max_value=3),
                epoch)

        # debug alternative: median-scale the prediction before computing
        # the errors:
        #   scale_factor = torch.div(torch.median(depth), torch.median(output_depth))
        #   errors.update(compute_errors(depth, output_depth * scale_factor))
        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
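compute_errors above returns the standard depth metrics in the order listed in error_names. A hedged sketch of those metrics (the repo version may additionally crop, mask, or median-scale; this one assumes gt is already restricted to valid, positive values):

import torch

def depth_metrics(gt, pred):
    # threshold accuracies a1/a2/a3 and absolute/relative errors
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()
    abs_diff = (gt - pred).abs().mean()
    abs_rel = ((gt - pred).abs() / gt).mean()
    sq_rel = ((gt - pred) ** 2 / gt).mean()
    return [abs_diff, abs_rel, sq_rel, a1, a2, a3]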
Example #11
def train(train_loader, distilled_model, epoch, args):

    distilled_model.train()

    batch_time = AverageMeter()
    loss_teacher_rec = AverageMeter()
    loss_student_rec = AverageMeter()
    loss_student_perceptual = AverageMeter()
    loss_dehazing_network = AverageMeter()
    loss_psnr = AverageMeter()
    loss_ssim = AverageMeter()

    # Start counting time
    time_start = time.time()

    for i, item in enumerate(tqdm(train_loader)):

        gt, hazy = item["gt"], item["hazy"]

        if torch.cuda.is_available():
            gt, hazy = gt.cuda(), hazy.cuda()

        loss = distilled_model.backward(gt, hazy, args)

        loss_teacher_rec.update(loss["teacher_rec_loss"].item(), gt.size(0))
        loss_student_rec.update(loss["student_rec_loss"].item(), gt.size(0))
        loss_student_perceptual.update(loss["perceptual_loss"].item(),
                                       gt.size(0))
        loss_dehazing_network.update(loss["dehazing_loss"].item(), gt.size(0))
        loss_psnr.update(loss["loss_psnr"].item(), gt.size(0))
        loss_ssim.update(loss["loss_ssim"].item(), gt.size(0))

        # time
        time_end = time.time()
        batch_time.update(time_end - time_start)
        time_start = time_end

        if (i + 1) % args.log_interval == 0:
            print(
                '[Train] Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Teacher Reconstruction Loss {loss_teacher.val:.4f} ({loss_teacher.avg:.4f})\t'
                'Student Reconstruction Loss {loss_student.val:.4f} ({loss_student.avg:.4f})\t'
                'Student Perceptual Loss {loss_perc.val:.4f} ({loss_perc.avg:.4f})\t'
                'Dehazing Network Loss {loss_dehaze.val:.4f} ({loss_dehaze.avg:.4f})\t'
                'PSNR {loss_psnr.val:.4f} ({loss_psnr.avg:.4f})\t'
                'SSIM {loss_ssim.val:.4f} ({loss_ssim.avg:.4f})\t'.format(
                    epoch + 1,
                    i + 1,
                    len(train_loader),
                    batch_time=batch_time,
                    loss_teacher=loss_teacher_rec,
                    loss_student=loss_student_rec,
                    loss_perc=loss_student_perceptual,
                    loss_dehaze=loss_dehazing_network,
                    loss_psnr=loss_psnr,
                    loss_ssim=loss_ssim))

    losses = {
        "teacher_rec_loss": loss_teacher_rec,
        "student_rec_loss": loss_student_rec,
        "perceptual_loss": loss_student_perceptual,
        "dehazing_loss": loss_dehazing_network,
        "loss_psnr": loss_psnr,
        "loss_ssim": loss_ssim
    }

    return losses
Example #12
def adjust_shifts(args, train_set, adjust_loader, depth_net, pose_net, epoch,
                  logger, training_writer, **env):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1, precision=2)
    pose_net.eval()
    depth_net.eval()
    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()

    mid_index = (args.sequence_length - 1) // 2

    # we constrain the mean value of the depth net output for pairs 0 to mid_index
    target_values = np.arange(
        -mid_index, mid_index + 1) / (args.target_mean_depth * mid_index)
    target_values = 1 / np.abs(
        np.concatenate(
            [target_values[:mid_index], target_values[mid_index + 1:]]))

    logger.reset_train_bar(len(adjust_loader))

    for i, sample in enumerate(adjust_loader):
        index = sample['index']

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        # compute output
        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        mean_depth_batch = []

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, W, H]

            depth = upsample_depth_net(input_pair)  # [B, 1, H, W]
            mean_depth = depth.view(batch_size, -1).mean(-1).cpu().numpy()  # B
            mean_depth_batch.append(mean_depth)

        for j, mean_values in zip(index, np.stack(mean_depth_batch, axis=-1)):
            ratio = mean_values / target_values  # if mean value is too high, raise the shift, lower otherwise
            train_set.reset_shifts(j, ratio[:mid_index], ratio[mid_index:])
            new_shifts.update(train_set.get_shifts(j))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustment: '
                                      'Time {} Data {} shifts {}'.format(
                                          batch_time, data_time, new_shifts))

    for i, shift in enumerate(new_shifts.avg):
        training_writer.add_scalar('shifts{}'.format(i), shift, epoch)

    return new_shifts.avg
Example #13
def adjust_shifts(args, train_set, adjust_loader, pose_exp_net, epoch, logger,
                  tb_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1)
    pose_exp_net.train()
    poses = np.zeros(((len(adjust_loader) - 1) * args.batch_size *
                      (args.sequence_length - 1), 6))

    mid_index = (args.sequence_length - 1) // 2

    target_values = np.abs(np.arange(
        -mid_index, mid_index + 1)) * (args.target_displacement)
    target_values = np.concatenate(
        [target_values[:mid_index], target_values[mid_index + 1:]])

    end = time.time()

    for i, (indices, tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(adjust_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]

        # compute output
        explainability_mask, pose_batch = pose_exp_net(tgt_img, ref_imgs)

        if i < len(adjust_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose_batch.cpu().reshape(
                -1, 6).numpy()

        for index, pose in zip(indices, pose_batch):
            displacements = pose[:, :3].norm(p=2, dim=1).cpu().numpy()
            ratio = target_values / displacements

            train_set.reset_shifts(index, ratio[:mid_index], ratio[mid_index:])
            new_shifts.update(train_set.get_shifts(index))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustment: '
                                      'Time {} Data {} shifts {}'.format(
                                          batch_time, data_time, new_shifts))

    prefix = 'train poses'
    coeffs_names = ['tx', 'ty', 'tz']
    if args.rotation_mode == 'euler':
        coeffs_names.extend(['rx', 'ry', 'rz'])
    elif args.rotation_mode == 'quat':
        coeffs_names.extend(['qx', 'qy', 'qz'])
    for i in range(poses.shape[1]):
        tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                poses[:, i], epoch)

    return new_shifts.avg
Example #14
def validate_with_gt(args,
                     val_loader,
                     mvdnet,
                     depth_cons,
                     epoch,
                     output_writers=[]):
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    test_error_names1 = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    test_errors1 = AverageMeter(i=len(test_error_names1))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    if args.train_cons:
        depth_cons.eval()
    else:
        mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)
            if (pose != pose).any():  # skip samples whose poses contain NaNs
                continue

            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            output_depth = outputs[0]
            output_depth1 = output_depth.clone()
            nmap = outputs[1]
            nmap1 = nmap.clone()

            if args.train_cons:
                outputs = depth_cons(output_depth, nmap.permute(0, 3, 1, 2))
                nmap = outputs[:, 1:].permute(0, 2, 3, 1)
                output_depth = outputs[:, 0].unsqueeze(1)

            mask = (tgt_depth <= args.nlabel * args.mindepth) & (
                tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)
            #mask = (tgt_depth <= 10) & (tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth) #for DeMoN testing, to compare against DPSNet you might need to turn on this for fair comparison

            if not mask.any():
                continue

            output_depth1_ = torch.squeeze(output_depth1.data.cpu(), 1)
            output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

            errors_ = compute_errors_train(tgt_depth, output_depth_, mask)
            test_errors_ = list(
                compute_errors_test(tgt_depth[mask], output_depth_[mask]))
            test_errors1_ = list(
                compute_errors_test(tgt_depth[mask], output_depth1_[mask]))

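            # evaluate normals only where the GT normal map is defined (any
            # nonzero channel), then average the per-pixel angular error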
            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]
            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])
            total_angles_m1 = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap1[0])

            mask_angles = total_angles_m[n_mask]
            mask_angles1 = total_angles_m1[n_mask]
            total_angles_m[~n_mask] = 0
            total_angles_m1[~n_mask] = 0
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())
            test_errors1_.append(torch.mean(mask_angles1).item())
            errors.update(errors_)
            test_errors.update(test_errors_)
            test_errors1.update(test_errors1_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0 or i == len(val_loader) - 1:
                if args.train_cons:
                    print(
                        'valid: Time {} Prev Error {:.4f}({:.4f}) Curr Error {:.4f} ({:.4f}) Prev angle Error {:.4f} ({:.4f}) Curr angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors1.val[0],
                                test_errors1.avg[0], test_errors.val[0],
                                test_errors.avg[0], test_errors1.val[-1],
                                test_errors1.avg[-1], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))
                else:
                    print(
                        'valid: Time {} Rel Error {:.4f} ({:.4f}) Angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors.val[0],
                                test_errors.avg[0], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))
            if args.output_print:
                output_dir = Path(args.output_dir)
                if not os.path.isdir(output_dir):
                    os.mkdir(output_dir)
                plt.imsave(output_dir / '{:04d}_map{}'.format(i, '_dps.png'),
                           output_depth_.numpy()[0],
                           cmap='rainbow')
                np.save(output_dir / '{:04d}{}'.format(i, '_dps.npy'),
                        output_depth_.numpy()[0])
                if args.train_cons:
                    plt.imsave(output_dir /
                               '{:04d}_map{}'.format(i, '_prev.png'),
                               output_depth1_.numpy()[0],
                               cmap='rainbow')
                    np.save(output_dir / '{:04d}{}'.format(i, '_prev.npy'),
                            output_depth1_.numpy()[0])
                # np.save(output_dir/'{:04d}{}'.format(i,'_gt.npy'),tgt_depth.numpy()[0])
                # imsave(output_dir/'{:04d}_aimage{}'.format(i,'.png'), np.transpose(tgt_img.numpy()[0],(1,2,0)))
                # np.save(output_dir/'{:04d}_cam{}'.format(i,'.npy'),intrinsics_var.cpu().numpy()[0])
    if args.output_print:
        np.savetxt(output_dir / (args.ttype + 'errors.csv'),
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / (args.ttype + 'prev_errors.csv'),
                   test_errors1.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
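
A hypothetical epoch loop wiring this validation routine to the matching train()
function (names follow this example; a real script would also build the loaders,
models and optimizer, and the checkpoint path here is illustrative):

best_error = float('inf')
for epoch in range(args.epochs):
    train(args, train_loader, mvdnet, depth_cons, cons_loss_, optimizer,
          args.epoch_size, train_writer, epoch)
    errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                           depth_cons, epoch, output_writers)
    if errors[0] < best_error:  # error_names[0] is 'abs_rel'; lower is better
        best_error = errors[0]
        torch.save({'state_dict': mvdnet.state_dict()},
                   args.save_path / 'mvdnet_best.pth.tar')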
Example #15
def train(args, train_loader, disp_net, pose_exp_net, seg_net, optimizer,
          epoch_size, logger, tb_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.seg_loss

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()
    seg_net.eval()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_seg = seg_net(tgt_img)
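        # row-wise difference of the segmentation maps serves as an edge cue for disp_net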
        edge = tgt_seg[:, :, 0:-1, :] - tgt_seg[:, :, 1:, :]
        disparities = disp_net(tgt_img, edge)
        ref_seg = [seg_net(img) for img in ref_imgs]
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_seg, warped_seg, diff_seg = photometric_reconstruction_loss(
            tgt_seg, ref_seg, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_seg

        if log_losses:
            tb_writer.add_scalar('seg_loss', loss_seg.item(), n_iter)
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explainability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped_seg, diff_seg,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.item(),
                 loss_2.item() if w2 > 0 else 0,
                 loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #16
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    # args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=3,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=2).cuda()
    # mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    # masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    # mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    # mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        # explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        # print(len(explainability_mask))

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
            # flow_bwd is needed below; estimate it from the previous frame,
            # mirroring the FlowNetC6 branch in Example #24
            flow_bwd = flow_net(tgt_img_var, ref_imgs_var[1])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                             intrinsics_inv_var)

        # flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:,1], intrinsics_var, intrinsics_inv_var)
        #---------------------------------------------------------------

        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 0], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flow_fwd_list = []
        flow_fwd_list.append(flow_fwd)
        flow_bwd_list = []
        flow_bwd_list.append(flow_bwd)
        rigidity_mask_fwd = consensus_exp_masks(flows_cam_fwd,
                                                flows_cam_bwd,
                                                flow_fwd_list,
                                                flow_bwd_list,
                                                tgt_img_var,
                                                ref_imgs_var[1],
                                                ref_imgs_var[0],
                                                wssim=0.85,
                                                wrig=1.0,
                                                ws=0.1)[0]
        del flow_fwd_list
        del flow_bwd_list
        #--------------------------------------------------------------

        #rigidity_mask = 1 - (1-explainability_mask[:,1])*(1-explainability_mask[:,2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        #rigidity_mask_census_u = rigidity_mask_census_soft[:,0] < args.THRESH
        #rigidity_mask_census_v = rigidity_mask_census_soft[:,1] < args.THRESH
        #rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (rigidity_mask_census_v).type_as(flow_fwd)

        # rigidity_mask_census = ( torch.pow( (torch.pow(rigidity_mask_census_soft[:,0],2) + torch.pow(rigidity_mask_census_soft[:,1] , 2)), 0.5) < args.THRESH ).type_as(flow_fwd)
        THRESH_1 = 1
        THRESH_2 = 1
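        # adaptive consistency test: the squared camera/optical flow difference is
        # compared against THRESH_1 * (|flow_cam|^2 + |flow_fwd|^2) + THRESH_2,
        # so the tolerance grows with the flow magnitudes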
        rigidity_mask_census = (
            (torch.pow(rigidity_mask_census_soft[:, 0], 2) +
             torch.pow(rigidity_mask_census_soft[:, 1], 2)) < THRESH_1 *
            (flow_cam.pow(2).sum(dim=1) + flow_fwd.pow(2).sum(dim=1)) +
            THRESH_2).type_as(flow_fwd)

        # rigidity_mask_census = torch.zeros_like(rigidity_mask_census)
        rigidity_mask_fwd = torch.zeros_like(rigidity_mask_fwd)
        rigidity_mask_combined = 1 - (1 - rigidity_mask_fwd) * (
            1 - rigidity_mask_census)
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        # rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            torch.zeros_like(rigidity_mask_combined)) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

            flow_viz = flow_fwd.data[0].permute(1, 2, 0).cpu().numpy()
            flow_viz = flow_2_image(flow_viz)
            scipy.misc.imsave(viz_dir / str(i).zfill(3) + 'flow.png', flow_viz)

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
Example #17
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     segnet,
                     epoch,
                     logger,
                     tb_writer,
                     sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = sample_nb_to_log > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        tgt_seg = segnet(tgt_img)
        edge = tgt_seg[:, :, 0:-1, :] - tgt_seg[:, :, 1:, :]
        output_disp = disp_net(tgt_img, edge)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < sample_nb_to_log:
            if epoch == 0:
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                tb_writer.add_image(
                    'val target Depth Normalized/{}'.format(i),
                    tensor2array(depth_to_show, max_value=None), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                tb_writer.add_image(
                    'val target Disparity Normalized/{}'.format(i),
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='magma'), epoch)

            tb_writer.add_image(
                'val Dispnet Output Normalized/{}'.format(i),
                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                epoch)
            tb_writer.add_image('val Depth Output Normalized/{}'.format(i),
                                tensor2array(output_depth[0], max_value=None),
                                epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
Example #18
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, training_writer, **env):
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):

        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq).long().to(device).unsqueeze(
            0).expand(batch_size, seq)
        batch_range = torch.arange(batch_size).long().to(device)
        '''For each element of the batch, select a random picture in the sequence
        for which depth will be computed; all poses are then converted so that the
        pose of that picture is exactly the identity. Before the second training
        milestone, this image is always the middle of the sequence.'''

        if epoch > e2:
            tgt_id = torch.floor(torch.rand(batch_size) *
                                 seq).long().to(device)
        else:
            tgt_id = torch.zeros(batch_size).long().to(
                device) + args.sequence_length // 2
        '''
        Select the other picture to feed DepthNet; it must differ from tgt_id.
        It is the first picture of the sequence until the first training
        milestone is reached, after which it is chosen at random.
        '''

        ref_indices = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)

        if epoch > e1:
            prior_id = torch.floor(torch.rand(batch_size) *
                                   (seq - 1)).long().to(device)
        else:
            prior_id = torch.zeros(batch_size).long().to(device)
        prior_id = ref_indices[batch_range, prior_id]

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]

        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics, intrinsics_inv)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, W, H]
        depth = upsample_depth_net(input_pair)
        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

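        # the photometric loss is scale-ambiguous, so rescale all translations
        # such that the prior-to-target displacement equals args.nominal_displacement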
        predicted_magnitude = prior_poses[:, :,
                                          -1:].norm(p=2, dim=1,
                                                    keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :,
                                                   -1:] * scale_factor  # [B, seq, 3, 1]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)

        loss_1 = 0
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale
            loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_indices,
                scaled_depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += loss * size_ratio

            if log_output:
                training_writer.add_image(
                    'train Dispnet Output Normalized scale {}'.format(k),
                    tensor2array(disparities[k][0],
                                 max_value=None,
                                 colormap='bone'), n_iter)
                training_writer.add_image(
                    'train Depth Output scale {}'.format(k),
                    tensor2array(scaled_depth[0], max_value=args.max_depth),
                    n_iter)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    training_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(warped[0]), n_iter)
                    training_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(diff.abs()[0] - 1), n_iter)

        loss_2 = texture_aware_smooth_loss(
            depth, tgt_imgs if args.texture_loss else None)

        loss = w1 * loss_1 + w2 * loss_2

        if args.supervise_pose:
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            training_writer.add_scalar('photometric_error', loss_1.item(),
                                       n_iter)
            training_writer.add_scalar('disparity_smoothness_loss',
                                       loss_2.item(), n_iter)
            training_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            nominal_translation_magnitude = poses[:, -2, :3].norm(p=2, dim=-1)
            # last pose is always identity and penultimate translation magnitude is always 1, so you don't need to log them
            for j in range(args.sequence_length - 2):
                trans_mag = poses[:, j, :3].norm(p=2, dim=-1)
                training_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO log a better value : this is magnitude of vector (yaw, pitch, roll) which is not a physical value
                rot_mag = poses[:, j, 3:].norm(p=2, dim=-1)
                training_writer.add_histogram('rot {}'.format(j),
                                              rot_mag.detach().cpu().numpy(),
                                              n_iter)

            training_writer.add_image('train Input', tensor2array(tgt_imgs[0]),
                                      n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
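
Several of these examples call a compensate_pose helper to express every pose in the
sequence relative to a chosen target pose. Below is a hypothetical reconstruction,
assuming poses are [R|t] matrices of shape [B, seq, 3, 4] and the reference has shape
[B, 3, 4]; the actual helper may follow a different pose convention.

import torch

def compensate_pose(pose_matrices, ref_pose):
    # After compensation the reference pose becomes the identity:
    # R' = R_ref^T @ R_i and t' = R_ref^T @ (t_i - t_ref)
    R_ref = ref_pose[:, :, :3]                      # [B, 3, 3]
    t_ref = ref_pose[:, :, -1:]                     # [B, 3, 1]
    R = pose_matrices[:, :, :, :3]                  # [B, seq, 3, 3]
    t = pose_matrices[:, :, :, -1:]                 # [B, seq, 3, 1]
    R_ref_inv = R_ref.transpose(1, 2).unsqueeze(1)  # [B, 1, 3, 3]
    new_R = R_ref_inv @ R
    new_t = R_ref_inv @ (t - t_ref.unsqueeze(1))
    return torch.cat([new_R, new_t], dim=-1)        # [B, seq, 3, 4]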
Example #19
def train(args, train_loader, mvdnet, optimizer, epoch_size, train_writer,
          epoch):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)

    # switch to training mode
    mvdnet.train()

    print("Training")
    end = time.time()
    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda()).cuda()

        # compute output
        pose = torch.cat(ref_poses_var, 1)

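        # per-sample depth-range scaling: the 1050.0 below is the focal length
        # (in pixels) the SceneFlow ground truth assumes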
        if args.dataset == 'sceneflow':
            factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
            factor = factor.view(-1, 1, 1)
        else:
            factor = torch.ones(
                (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

        # get mask
        mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor * 3) & (
            tgt_depth_var >= args.mindepth * factor) & (tgt_depth_var
                                                        == tgt_depth_var)
        mask.detach_()
        if not mask.any():
            continue

        # warp the first reference view into the target frame (result unused below)
        targetimg = inverse_warp(ref_imgs_var[0], tgt_depth_var.unsqueeze(1),
                                 pose[:, 0], intrinsics_var,
                                 intrinsics_inv_var)

        outputs = mvdnet(tgt_img_var,
                         ref_imgs_var,
                         pose,
                         intrinsics_var,
                         intrinsics_inv_var,
                         factor=factor.unsqueeze(1))

        nmap = outputs[2].permute(0, 3, 1, 2)
        depths = outputs[0:2]

        if args.dataset == 'sceneflow':
            disps = [(args.mindepth * args.nlabel) * 3 / depth
                     for depth in depths]  # correct disps
            depths = [(args.mindepth * factor) * (args.nlabel * 3) / disp
                      for disp in disps]
        else:
            disps = [args.mindepth * args.nlabel / depth for depth in depths]
        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        if args.dataset == 'sceneflow':
            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                            tgt_depth_var)
            for l, disp in enumerate(disps):
                output = torch.squeeze(disp, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                                   tgt_disp_var[mask]) * pow(
                                                       0.7,
                                                       len(disps) - l - 1)
        else:
            for l, depth in enumerate(depths):
                output = torch.squeeze(depth, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                                   tgt_depth_var[mask]) * pow(
                                                       0.7,
                                                       len(depths) - l - 1)

        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask],
                                                 gt_nmap_var[n_mask])

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        i, len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #20
def validate_without_gt(args,
                        val_loader,
                        depth_net,
                        pose_net,
                        epoch,
                        logger,
                        output_writers=[],
                        **env):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.valid_bar.update(0)

    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        mid_index = (args.sequence_length - 1) // 2

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        loss_1 = 0
        loss_2 = 0

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, W, H]

            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :,
                                                       -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation],
                dim=-1)

            depth = upsample_depth_net(input_pair)
            disparity = 1 / depth
            total_indices = torch.arange(seq).long().unsqueeze(0).expand(
                batch_size, seq).to(device)
            tgt_id = total_indices[:, mid_index]

            # tensor of reference ids; a new name avoids rebinding the
            # ref_indices list being iterated above
            ref_ids = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
                batch_size, seq - 1)

            photo_loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_ids,
                depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)

            loss_1 += photo_loss

            if log_output:
                output_writers[i].add_image(
                    'val Dispnet Output Normalized {}'.format(ref_index),
                    tensor2array(disparity[0], max_value=None,
                                 colormap='bone'), epoch)
                output_writers[i].add_image(
                    'val Depth Output {}'.format(ref_index),
                    tensor2array(depth[0].cpu(), max_value=args.max_depth),
                    epoch)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    output_writers[i].add_image(
                        'val Warped Outputs {} {}'.format(j, ref_index),
                        tensor2array(warped[0]), epoch)
                    output_writers[i].add_image(
                        'val Diff Outputs {} {}'.format(j, ref_index),
                        tensor2array(diff[0].abs() - 1), epoch)

            loss_2 += texture_aware_smooth_loss(
                disparity, tgt_imgs if args.texture_loss else None)

        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            output_writers[0].add_histogram('val poses_{}'.format(coeff_name),
                                            poses_values[:, k], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
Example #21
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        # compute output
        disparities = disp_net(tgt_img_var)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disparities)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explainability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:

            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img_var.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img_var, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs_var
                ]

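                # downscaling the image by a factor s divides fx, fy, cx, cy by s:
                # scale the first two rows of K down and the first two columns of
                # K^-1 up by the same factor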
                intrinsics_scaled = torch.cat(
                    (intrinsics_var[:, 0:2] / downscale, intrinsics_var[:,
                                                                        2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv_var[:, :, 0:2] * downscale,
                     intrinsics_inv_var[:, :, 2:]),
                    dim=2)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref,
                        scaled_depth[:, 0],
                        pose[:, j],
                        intrinsics_scaled,
                        intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(explainability_mask[k][0,
                                                                j].data.cpu(),
                                         max_value=1,
                                         colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(), loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
Example #22
def validate_with_gt(args,
                     val_loader,
                     depth_net,
                     pose_net,
                     epoch,
                     logger,
                     output_writers=[],
                     **env):
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)

        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        if args.network_input_size is not None:
            imgs = F.interpolate(imgs, (c, *args.network_input_size),
                                 mode='area')

            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :,
                                                                       2:]),
                dim=2)

        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)

        mid_index = (args.sequence_length - 1) // 2

        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]),
                                            j)
            depth_to_show = GT_depth[0].cpu()
            # KITTI-like routine to discard invalid (zero-depth) pixels
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1],
                               inverted_pose_matrices.data[:, :-1]))

        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized',
                        tensor2array(sample['imgs'][j][0]), j)
                continue
            '''Compute the displacement magnitude for each element of the batch
            and rescale the GT depth accordingly.'''

            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:,
                                                               j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]

            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)

            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics,
                                                       intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img],
                dim=1)  # [B, 6, W, H]
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img],
                                      dim=1)  # [B, 6, W, H]

            # Depth from footage stabilized with the GT pose; it should be better
            # than depth from footage stabilized with the predicted pose, which
            # uses no GT information.
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(
                            prefix, j),
                        tensor2array(disparity[0],
                                     max_value=None,
                                     colormap='bone'), epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1],
                        unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg,
              *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])

    return OrderedDict(zip(error_names, errors))
Example #23
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     output_writers=[]):
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        depth = depth.cuda()

        # compute output
        output_disp = disp_net(tgt_img_var)
        output_depth = 1 / output_disp

        if log_outputs and i % 100 == 0 and i // 100 < len(output_writers):
            index = i // 100
            if epoch == 0:
                output_writers[index].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].cpu()
                output_writers[index].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[index].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='bone'), epoch)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(output_depth.data[0].cpu(), max_value=10), epoch)

        errors.update(compute_errors(depth, output_depth.data))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
Example #24
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer,
         global_vars_dict=None):
    # unpack shared state
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    data_time = AverageMeter()

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    # pose buffer (kept from the original; unused below)
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))

    disp_list = []
    flow_list = []
    mask_list = []

    # validation loop
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(tqdm(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)

        # forward pass: disparity
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # pose
        pose = pose_net(tgt_img, ref_imgs)

        # flow between the target frame and its previous/next neighbours
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # flow_fwd: [B,2,H,W]
        # flow_cam: [B,2,H,W]
        # flow_background
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)

        flows_cam_fwd = flow_cam  # identical pose/depth; no need to recompute
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [B,2,H,W]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # per-reference validity weights

        # list of 5 tensors: [4,4,128,512] ... [4,4,4,16], values roughly 0.33~0.63
        end = time.time()


        # 3.2 logging: inspect the forward-pass outputs
        # disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')  # [1,h,w], values ~0.5-10
        tb_writer.add_image('Disp/disp0', disp_to_show,i)
        disp_list.append(disp_to_show)

        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)


        # 3.3 flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))
        # 3.4 mask
        tb_writer.add_image('Mask/mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))

    return disp_list, disp_arr, flow_list, mask_list
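
Note: pose2flow is imported from elsewhere in this repo. The sketch below shows the standard rigid-flow construction such a helper typically implements (backproject with K^-1, apply the relative pose, reproject with K, subtract the pixel grid); the name rigid_flow and taking the pose as an explicit rotation matrix plus translation are illustrative assumptions, not the repo's signature.

import torch

def rigid_flow(depth, rot, trans, K, K_inv):
    """depth: [B,H,W]; rot: [B,3,3]; trans: [B,3]; K, K_inv: [B,3,3].
    Returns the optical flow [B,2,H,W] induced by camera motion alone."""
    b, h, w = depth.shape
    ys, xs = torch.meshgrid(torch.arange(h, dtype=depth.dtype),
                            torch.arange(w, dtype=depth.dtype), indexing="ij")
    pix = torch.stack([xs, ys, torch.ones_like(xs)], 0).reshape(1, 3, -1)
    pix = pix.to(depth.device).expand(b, -1, -1)          # homogeneous pixel grid
    cam = (K_inv @ pix) * depth.reshape(b, 1, -1)         # backproject to 3D
    cam = rot @ cam + trans.unsqueeze(-1)                 # rigid transform
    proj = K @ cam                                        # reproject
    proj = proj[:, :2] / proj[:, 2:].clamp(min=1e-6)      # perspective divide
    return (proj - pix[:, :2]).reshape(b, 2, h, w)
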
Example #25
0
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.geometry_consistency_weight

    # switch to train mode
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_depth, ref_depths = compute_depth(disp_net, tgt_img, ref_imgs)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs,
                                                 intrinsics)
        loss_1, loss_3 = compute_photo_and_geometry_loss(
            tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
            poses_inv, args.num_scales, args.with_ssim, args.with_mask,
            args.with_auto_mask, args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                    n_iter)
            train_writer.add_scalar('geometry_consistency_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.item(),
                 loss_1.item(),
                 loss_2.item(),
                 loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
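
Note: compute_smooth_loss is defined elsewhere in the repo. For orientation, here is a minimal sketch of the edge-aware first-order smoothness term such functions usually compute, with disparity gradients down-weighted where the image itself has strong gradients; the function name and the mean-normalization step are assumptions.

import torch

def edge_aware_smoothness(disp, img):
    """disp: [B,1,H,W]; img: [B,3,H,W]. Edge-aware first-order smoothness."""
    # normalize disparity so the loss is scale-invariant
    disp = disp / (disp.mean(dim=(2, 3), keepdim=True) + 1e-7)
    dx = (disp[:, :, :, :-1] - disp[:, :, :, 1:]).abs()
    dy = (disp[:, :, :-1, :] - disp[:, :, 1:, :]).abs()
    img_dx = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean(1, keepdim=True)
    img_dy = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean(1, keepdim=True)
    dx = dx * torch.exp(-img_dx)  # relax smoothness across image edges
    dy = dy * torch.exp(-img_dy)
    return dx.mean() + dy.mean()
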
Example #26
0
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    '''train given model and dataloader'''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for input, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()
        sc = None

        # train with clean images
        if not args.comix:
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(input)
            loss = bce_loss(softmax(output), target_reweighted)

        # train with Co-Mixup images
        else:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            A_dist = None

            # Calculate saliency (unary)
            if args.clean_lam == 0:
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(
                    output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            loss_batch_mean.backward(retain_graph=True)
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            # Calculate the distance between the most salient locations (compatibility)
            # Various distance measures could be used here
            with torch.no_grad():
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(args.batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')

            if args.clean_lam == 0:
                model.train()
                optimizer.zero_grad()

            # Perform mixup and calculate loss
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                device = input.device
                out, target_reweighted = mpp(input.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)

            else:
                out, target_reweighted = mixup_process(input,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)

            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
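
Note: to_one_hot, softmax and bce_loss are helpers defined elsewhere. Below is a minimal sketch of the soft-target cross-entropy pattern they appear to implement, needed because Co-Mixup produces mixed (non-one-hot) label vectors; soft_bce is a hypothetical name.

import torch
import torch.nn.functional as F

def to_one_hot(target, num_classes):
    return F.one_hot(target, num_classes).float()

def soft_bce(probs, soft_targets, eps=1e-7):
    # binary cross-entropy against (possibly mixed) soft labels
    probs = probs.clamp(eps, 1 - eps)
    return -(soft_targets * probs.log()
             + (1 - soft_targets) * (1 - probs).log()).sum(1).mean()

# usage sketch: loss = soft_bce(torch.softmax(logits, dim=1), to_one_hot(y, C))
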
Example #27
0
@torch.no_grad()  # evaluation only; avoids building autograd graphs
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # check gt
        if depth.nelement() == 0:
            continue

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].clone()  # clone so the in-place zero-fill below leaves gt intact
                output_writers[i].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='magma'), epoch)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                epoch)
            output_writers[i].add_image(
                'val Depth Output', tensor2array(output_depth[0],
                                                 max_value=10), epoch)

        if depth.nelement() != output_depth.nelement():
            b, h, w = depth.size()
            output_depth = torch.nn.functional.interpolate(
                output_depth.unsqueeze(1), [h, w]).squeeze(1)

        errors.update(compute_errors(depth, output_depth, args.dataset))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
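
Note: compute_errors is imported from the repo's metric module. A sketch of the standard depth metrics matching the error_names list above; the valid-pixel masking and exact reduction are assumptions.

import torch

def depth_metrics(gt, pred):
    """Standard monocular-depth metrics over valid (gt > 0) pixels."""
    valid = gt > 0
    gt, pred = gt[valid], pred[valid]
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()
    abs_diff = (gt - pred).abs().mean()
    abs_rel = ((gt - pred).abs() / gt).mean()
    sq_rel = ((gt - pred) ** 2 / gt).mean()
    return [abs_diff, abs_rel, sq_rel, a1, a2, a3]
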
Example #28
0
def main():
    # For CUDA multi-processing
    if args.parallel:
        import torch.multiprocessing as mp
        mp.set_start_method("spawn")

    # Set up the experiment directories
    if not args.log_off:
        exp_name = define_exp_name()
        exp_dir = os.path.join(args.root_dir, exp_name)
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
        copy_script_to_folder(os.path.abspath(__file__), exp_dir)
        result_png_path = os.path.join(exp_dir, 'results.png')
        log = open(os.path.join(exp_dir, 'log.txt'), 'w')
        print_log('save path : {}'.format(exp_dir), log)
    else:
        log = None
        exp_dir = None
        result_png_path = None

    global best_acc

    state = {k: v for k, v in args._get_kwargs()}
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Dataloader
    train_loader, _, _, test_loader, num_classes = load_data_subset(
        args.batch_size,
        0,
        args.dataset,
        args.data_dir,
        labels_per_class=args.labels_per_class,
        valid_labels_per_class=args.valid_labels_per_class)

    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3,
                                 dtype=torch.float32).reshape(1, 3, 1,
                                                              1).cuda()
        args.std = torch.tensor([0.5] * 3,
                                dtype=torch.float32).reshape(1, 3, 1,
                                                             1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],
                                 dtype=torch.float32).reshape(1, 3, 1,
                                                              1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]],
                                dtype=torch.float32).reshape(1, 3, 1,
                                                             1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]],
                                 dtype=torch.float32).reshape(1, 3, 1,
                                                              1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]],
                                dtype=torch.float32).reshape(1, 3, 1,
                                                             1).cuda()
        args.labels_per_class = 500
    else:
        raise ValueError('Given dataset is not supported!')

    # Create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(list(net.parameters()),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.parallel:
        mpp = MixupProcessParallel(args.m_part, args.batch_size, 1)
    else:
        mpp = None

    recorder = RecorderMeter(args.epochs)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("\n=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # Train for one epoch
        tr_acc, tr_acc5, tr_los = train(train_loader, net, optimizer, epoch,
                                        args, log, mpp)

        # Evaluate on validation set
        val_acc, val_los = validate(test_loader, net, log)

        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # Measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # Save log
        recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch + 1) % 100 == 0:
            recorder.plot_curve(result_png_path)

        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc

        pickle.dump(train_log, open(os.path.join(exp_dir, 'log.pkl'), 'wb'))
        plotting(exp_dir)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, exp_dir, 'checkpoint.pth.tar')

    acc_var = np.maximum(
        np.max(test_acc[-10:]) - np.median(test_acc[-10:]),
        np.median(test_acc[-10:]) - np.min(test_acc[-10:]))
    print_log(
        "\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(
            np.median(test_acc[-10:]), acc_var), log)

    if not args.log_off:
        log.close()

    if args.parallel:
        mpp.close()
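
Note: adjust_learning_rate is defined elsewhere. A plausible minimal sketch of the multi-step decay implied by the gammas/schedule arguments, where each milestone reached multiplies the rate by the matching gamma; passing base_lr as an argument is an assumption (the repo reads it from state).

def adjust_learning_rate(optimizer, epoch, gammas, schedule, base_lr=0.1):
    """Multiply base_lr by gamma for each milestone already passed."""
    lr = base_lr
    for gamma, milestone in zip(gammas, schedule):
        if epoch >= milestone:
            lr *= gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
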
Example #29
0
@torch.no_grad()  # validation only; avoids building autograd graphs
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        tb_writer,
                        torch_device,
                        sample_nb_to_log=2):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.gt_pose_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv,
            pose_gt) in enumerate(val_loader):
        tgt_img = tgt_img.to(torch_device)
        ref_imgs = [img.to(torch_device) for img in ref_imgs]
        tgt_lf = tgt_lf.to(torch_device)
        ref_lfs = [lf.to(torch_device) for lf in ref_lfs]
        intrinsics = intrinsics.to(torch_device)
        intrinsics_inv = intrinsics_inv.to(torch_device)
        pose_gt = pose_gt.to(torch_device)

        # compute output
        disp = disp_net(tgt_lf)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        pred_pose_magnitude = pose[:, :, :3].norm(dim=2)
        pose_gt_magnitude = pose_gt[:, :, :3].norm(dim=2)
        pose_loss = (pred_pose_magnitude - pose_gt_magnitude).abs().mean().item()  # plain float, like the other loss terms

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val/Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val/Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)

            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * pose_loss
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'val/total_loss', 'val/photometric_error', 'val/explainability_loss'
    ]
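
Note: explainability_loss comes from the repo's loss module. In SfMLearner-style codebases it is typically a binary cross-entropy pulling the mask toward 1, so the network pays a price whenever it discounts pixels; the sketch below is written under that assumption.

import torch
import torch.nn.functional as F

def explainability_loss(masks):
    """masks: a tensor [B,N,H,W] in [0,1], or a list of them (one per scale)."""
    if not isinstance(masks, (list, tuple)):
        masks = [masks]
    loss = 0
    for mask in masks:
        ones = torch.ones_like(mask)  # regularize toward "fully explainable"
        loss = loss + F.binary_cross_entropy(mask, ones)
    return loss
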
Example #30
0
def main():
    global args
    args = parser.parse_args()
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(
            root='/home/anuragr/datasets/kitti/kitti2015',
            sequence_length=5,
            transform=valid_flow_transform)
    elif args.dataset == "kitti2012":
        val_flow_set = ValidationFlowKitti2012(
            root='/is/ps2/aranjan/AllFlowData/kitti/kitti2012',
            sequence_length=5,
            transform=valid_flow_transform)

    val_flow_loader = torch.utils.data.DataLoader(val_flow_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=2,
                                                  pin_memory=True,
                                                  drop_last=True)

    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    if args.pretrained_flow:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_flow))
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])  # optionally strict=False

    flow_net = flow_net.cuda()
    flow_net.eval()
    error_names = ['epe_total', 'epe_non_rigid', 'epe_rigid', 'outliers']
    errors = AverageMeter(i=len(error_names))

    with torch.no_grad():  # Variable(..., volatile=True) is deprecated since PyTorch 0.4
        for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
                obj_map) in enumerate(tqdm(val_flow_loader)):
            tgt_img_var = tgt_img.cuda()
            if args.dataset == "kitti2015":
                ref_imgs_var = [img.cuda() for img in ref_imgs]
                ref_img_var = ref_imgs_var[1:3]
            elif args.dataset == "kitti2012":
                ref_img_var = ref_imgs.cuda()

            flow_gt_var = flow_gt.cuda()
            # compute output
            flow_fwd, flow_bwd, occ = flow_net(tgt_img_var, ref_img_var)
            obj_map_gt_var = obj_map.cuda()
            obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

            # single flow network: the rigid and non-rigid slots both receive flow_fwd
            epe = compute_all_epes(flow_gt_var, flow_fwd, flow_fwd,
                                   (1 - obj_map_gt_var_expanded))
            errors.update(epe)

    print("Averge EPE", errors.avg)