Example #1
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must output at least one value (disparity or depth)!')
        return

    # disp_net = DispNetS().to(device)
    disp_net = DispResNet(3, alpha=1).to(device)
    weights = torch.load(args.pretrained)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_disp = output_dir / 'disp'
    output_depth = output_dir / 'depth'
    output_disp.makedirs_p()
    output_depth.makedirs_p()

    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir / file for file in f.read().splitlines()]
    else:
        test_files = sum(
            [dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts],
            [])

    print('{} files to test'.format(len(test_files)))

    count = 0
    for file in tqdm(test_files, ncols=100):

        img = imread(file).astype(np.float32)

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            img = imresize(img, (args.img_height, args.img_width)).astype(
                np.float32)
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img).unsqueeze(0)
        tensor_img = ((tensor_img / 255 - 0.5) / 0.2).to(device)

        output = disp_net(tensor_img)[0][0]

        if args.output_disp:
            disp = (255 * tensor2array(
                output, max_value=None, colormap='bone',
                channel_first=False)).astype(np.uint8)
            img = np.transpose(img, (1, 2, 0))
            im_save = np.concatenate((disp, img), axis=1).astype(np.uint8)
            imsave(output_disp / '{}_disp{}'.format(count, file.ext), im_save)
        if args.output_depth:
            depth = 1 / output
            depth = (255 * tensor2array(
                depth, max_value=1, colormap='rainbow',
                channel_first=False)).astype(np.uint8)
            imsave(output_depth / '{}_depth{}'.format(count, file.ext), depth)
        count += 1
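All of the inference examples here share the same preprocessing: resize, HWC-to-CHW transpose, a batch dimension, then mean/std normalization. A minimal standalone sketch of that pipeline, using the 0.5/0.2 values from the example above (Example #5 uses 0.45/0.225 instead):

import numpy as np
import torch

def preprocess(img: np.ndarray, device: torch.device) -> torch.Tensor:
    """Turn an HxWx3 uint8 RGB image into the 1x3xHxW normalized float
    tensor these disparity networks expect."""
    img = img.astype(np.float32)
    img = np.transpose(img, (2, 0, 1))           # HWC -> CHW
    tensor = torch.from_numpy(img).unsqueeze(0)  # add batch dimension
    return ((tensor / 255 - 0.5) / 0.2).to(device)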
Example #2
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     tb_writer,
                     sample_nb_to_log=3):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = sample_nb_to_log > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    #logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < sample_nb_to_log:
            if epoch == 0:
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].clone()  # clone so the 0 -> 1000 edit below cannot alter the GT used by compute_errors
                tb_writer.add_image(
                    'val target Depth Normalized/{}'.format(i),
                    tensor2array(depth_to_show, max_value=None), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                tb_writer.add_image(
                    'val target Disparity Normalized/{}'.format(i),
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='magma'), epoch)

            tb_writer.add_image(
                'val Dispnet Output Normalized/{}'.format(i),
                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                epoch)
            tb_writer.add_image('val Depth Output Normalized/{}'.format(i),
                                tensor2array(output_depth[0], max_value=None),
                                epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        #logger.valid_bar.update(i+1)
        #if i % args.print_freq == 0:
        #logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    #logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
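compute_errors is not defined on this page. Given the metric names above, a plausible implementation of the standard monocular-depth metrics (a sketch; Example #6 passes an extra dataset argument, and the repo's exact code may differ):

import torch

@torch.no_grad()
def compute_errors(gt, pred):
    """abs_diff / abs_rel / sq_rel plus the 1.25^k accuracy thresholds,
    computed over valid (gt > 0) pixels of same-shape tensors."""
    valid = gt > 0
    gt, pred = gt[valid], pred[valid]

    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    abs_diff = (gt - pred).abs().mean()
    abs_rel = ((gt - pred).abs() / gt).mean()
    sq_rel = ((gt - pred) ** 2 / gt).mean()

    return [m.item() for m in (abs_diff, abs_rel, sq_rel, a1, a2, a3)]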
Example #3
def validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        tgt_depth = [1 / disp_net(tgt_img)]
        ref_depths = []
        for ref_img in ref_imgs:
            ref_depth = [1 / disp_net(ref_img)]
            ref_depths.append(ref_depth)

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)

            output_writers[i].add_image('val Dispnet Output Normalized',
                                        tensor2array(1/tgt_depth[0][0], max_value=None, colormap='magma'),
                                        epoch)
            output_writers[i].add_image('val Depth Output',
                                        tensor2array(tgt_depth[0][0], max_value=10),
                                        epoch)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs)

        loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths,
                                                         poses, poses_inv, args.num_scales, args.with_ssim,
                                                         args.with_mask, False, args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))

    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss']
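Every example tracks timings and losses with AverageMeter, constructed with optional i (number of tracked values) and precision arguments and updated with scalars or lists. A compatible sketch of that class (the real one lives in the repo's utils):

class AverageMeter(object):
    """Running current/average values for one or more metrics."""

    def __init__(self, i=1, precision=3):
        self.meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0.0] * self.meters
        self.avg = [0.0] * self.meters
        self.sum = [0.0] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, (list, tuple)):
            val = [val]
        self.val = list(val)
        self.count += n
        for k, v in enumerate(val):
            self.sum[k] += v * n
        self.avg = [s / self.count for s in self.sum]

    def __repr__(self):
        val = ' '.join('{:.{}f}'.format(v, self.precision) for v in self.val)
        avg = ' '.join('{:.{}f}'.format(a, self.precision) for a in self.avg)
        return '{} ({})'.format(val, avg)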
Example #4
def validate_flow_with_gt(val_loader, flow_net, epoch, logger, output_writers=[]):
    global args
    batch_time = AverageMeter()
    error_names = ['epe_total', 'epe_rigid', 'epe_non_rigid', 'outliers']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    flow_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[1])
            flow_bwd = flow_net(tgt_img_var, ref_imgs_var[0])


        if args.DEBUG:
            flow_fwd_x = flow_fwd[:,0].view(-1).abs().data
            flow_gt_var_x = flow_gt_var[:,0].view(-1).abs().data
       
        flow_fwd_non_rigid = flow_fwd
        total_flow = flow_fwd

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        if log_outputs and i % 10 == 0 and i/10 < len(output_writers):
            index = int(i//10)
            if epoch == 0:
                output_writers[index].add_image('val flow Input', tensor2array(tgt_img[0]), 0)
                flow_to_show = flow_gt[0][:2,:,:].cpu()
                output_writers[index].add_image('val target Flow', flow_to_image(tensor2array(flow_to_show)), epoch)

            output_writers[index].add_image('val Non-rigid Flow Output', flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())), epoch)
            
        if np.isnan(flow_gt.sum().item()) or np.isnan(total_flow.data.sum().item()):
            print('NaN encountered')
        _epe_errors = compute_all_epes(flow_gt_var, flow_fwd, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        batch_time.update(time.time() - end)
        end = time.time()
        if args.log_terminal:
            logger.valid_bar.update(i)
            if i % args.print_freq == 0:
                logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))

    if args.log_terminal:
        logger.valid_bar.update(len(val_loader))

    return errors.avg, error_names
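Example #5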
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must output at least one value (disparity or depth)!')
        return

    disp_net = DispResNet(args.resnet_layers, False).to(device)
    weights = torch.load(args.pretrained)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir / file for file in f.read().splitlines()]
    else:
        test_files = sum(
            [dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts],
            [])

    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):

        img = imread(file).astype(np.float32)

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            # img = imresize(img, (args.img_height, args.img_width)).astype(np.float32)
            img = np.array(
                Image.fromarray(np.uint8(img)).resize(
                    (args.img_width, args.img_height))).astype(np.float32)
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img).unsqueeze(0)

        tensor_img = ((tensor_img / 255 - 0.45) / 0.225).to(device)

        output = disp_net(tensor_img)[0]

        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        file_name = '-'.join(file_path.splitall())

        if args.output_disp:
            disp = (255 * tensor2array(output, max_value=None,
                                       colormap='bone')).astype(np.uint8)
            imsave(output_dir / '{}_disp{}'.format(file_name, file_ext),
                   np.transpose(disp, (1, 2, 0)))
        if args.output_depth:
            depth = 1 / output
            depth = (255 *
                     tensor2array(depth, max_value=5, colormap='gray')).astype(
                         np.uint8)
            imsave(output_dir / '{}_depth{}'.format(file_name, file_ext),
                   np.transpose(depth, (1, 2, 0)))
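tensor2array is the helper all of these snippets rely on but never define. A sketch consistent with how it is called here: single-channel maps are scaled by max_value (or their own max) and colormapped, while 3-channel tensors are treated as normalized images and mapped back to [0, 1]. The real implementation may differ in detail:

import numpy as np
import torch
from matplotlib.pyplot import get_cmap

def tensor2array(tensor, max_value=None, colormap='rainbow', channel_first=True):
    """Convert a tensor to a float image array in [0, 1] for add_image/imsave."""
    tensor = tensor.detach().cpu()
    if tensor.ndimension() == 2 or tensor.size(0) == 1:
        if max_value is None:
            max_value = tensor.max().item()
        norm = (tensor.squeeze().numpy() / max_value).clip(0, 1)
        array = get_cmap(colormap)(norm)[:, :, :3].astype(np.float32)
        if channel_first:
            array = array.transpose(2, 0, 1)   # HWC -> CHW
    else:
        array = 0.5 + tensor.numpy() * 0.5     # undo roughly [-1, 1] normalization
        if not channel_first:
            array = array.transpose(1, 2, 0)   # CHW -> HWC
    return array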
Example #6
def validate_with_gt(args, val_loader, disp_net, epoch, logger, output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # check gt
        if depth.nelement() == 0:
            continue

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1/output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].clone()  # clone so the 0 -> 1000 edit below cannot alter the GT used by compute_errors
                output_writers[i].add_image('val target Depth',
                                            tensor2array(depth_to_show, max_value=10),
                                            epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1/depth_to_show).clamp(0, 10)
                output_writers[i].add_image('val target Disparity Normalized',
                                            tensor2array(disp_to_show, max_value=None, colormap='magma'),
                                            epoch)

            output_writers[i].add_image('val Dispnet Output Normalized',
                                        tensor2array(output_disp[0], max_value=None, colormap='magma'),
                                        epoch)
            output_writers[i].add_image('val Depth Output',
                                        tensor2array(output_depth[0], max_value=10),
                                        epoch)

        if depth.nelement() != output_depth.nelement():
            b, h, w = depth.size()
            output_depth = torch.nn.functional.interpolate(output_depth.unsqueeze(1), [h, w]).squeeze(1)

        errors.update(compute_errors(depth, output_depth, args.dataset))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
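Example #7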
def validate_with_gt(val_loader, disp_net, epoch, logger, output_writers=[]):
    global args
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        depth = depth.cuda()

        # compute output
        output_disp = disp_net(tgt_img_var)
        output_depth = 1 / output_disp

        if log_outputs and i % 100 == 0 and i / 100 < len(output_writers):
            index = int(i // 100)
            if epoch == 0:
                output_writers[index].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].cpu()
                output_writers[index].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[index].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='bone'), epoch)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(output_depth.data[0].cpu(), max_value=10), epoch)

        errors.update(compute_errors(depth, output_depth.data))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
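Examples #4, #7 and #8 use the pre-0.4 Variable(..., volatile=True) API, which current PyTorch no longer accepts. The modern equivalent is a plain tensor evaluated under torch.no_grad(), for instance:

import torch

@torch.no_grad()  # replaces volatile=True: no autograd graph is built
def infer_disparity(disp_net, tgt_img, device='cuda'):
    """Modern equivalent of Variable(tgt_img.cuda(), volatile=True) followed
    by a forward pass (a sketch, drop-in for the validation loops above)."""
    disp_net.eval()
    return disp_net(tgt_img.to(device))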
Example #8
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must output at least one value (disparity or depth)!')
        return

    disp_net = DispNetS().cuda()
    weights = torch.load(args.pretrained)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir / file for file in f.read().splitlines()]
    else:
        test_files = sum(
            [dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts],
            [])

    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):

        img = imread(file).astype(np.float32)

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            img = imresize(img, (args.img_height, args.img_width)).astype(
                np.float32)
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img).unsqueeze(0)
        tensor_img = ((tensor_img / 255 - 0.5) / 0.2).cuda()
        var_img = torch.autograd.Variable(tensor_img, volatile=True)

        output = disp_net(var_img).data.cpu()[0]

        if args.output_disp:
            disp = (255 * tensor2array(output, max_value=10,
                                       colormap='bone')).astype(np.uint8)
            imsave(output_dir / '{}_disp{}'.format(file.namebase, file.ext),
                   disp)
        if args.output_depth:
            depth = 1 / output
            depth = (255 * tensor2array(depth, max_value=1,
                                        colormap='rainbow')).astype(np.uint8)
            imsave(output_dir / '{}_depth{}'.format(file.namebase, file.ext),
                   depth)
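Example #9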
def log_result(pred_depth, GT, input_batch, selected_index, folder, prefix):
    def save(path, to_save):
        to_save = (255 * to_save.transpose(1, 2, 0)).astype(np.uint8)
        imageio.imsave(path, to_save)

    pred_to_save = tensor2array(pred_depth, max_value=100)
    gt_to_save = tensor2array(torch.from_numpy(GT), max_value=100)

    prefix = folder / prefix
    save('{}_depth_pred.jpg'.format(prefix), pred_to_save)
    save('{}_depth_gt.jpg'.format(prefix), gt_to_save)
    disp_to_save = tensor2array(1 / pred_depth,
                                max_value=None,
                                colormap='magma')
    gt_disp = np.zeros_like(GT)
    valid_depth = GT > 0
    gt_disp[valid_depth] = 1 / GT[valid_depth]

    gt_disp_to_save = tensor2array(torch.from_numpy(gt_disp),
                                   max_value=None,
                                   colormap='magma')
    save('{}_disp_pred.jpg'.format(prefix), disp_to_save)
    save('{}_disp_gt.jpg'.format(prefix), gt_disp_to_save)
    to_save = tensor2array(input_batch.cpu().data[selected_index, :3])
    save('{}_input0.jpg'.format(prefix), to_save)
    to_save = tensor2array(input_batch.cpu()[selected_index, 3:])
    save('{}_input1.jpg'.format(prefix), to_save)
    for i, batch_elem in enumerate(input_batch.cpu().data):
        to_save = tensor2array(batch_elem[:3])
        save('{}_batch_{}_0.jpg'.format(prefix, i), to_save)
        to_save = tensor2array(batch_elem[3:])
        save('{}_batch_{}_1.jpg'.format(prefix, i), to_save)
Example #10
def validate_with_gt(args, val_loader, disp_net, epoch, output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1/output_disp[:,0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0].clone()  # clone so the 0 -> 1000 edit below cannot alter the GT used by compute_errors
                output_writers[i].add_image('val target Depth',
                                            tensor2array(depth_to_show, max_value=10),
                                            epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1/depth_to_show).clamp(0,10)
                output_writers[i].add_image('val target Disparity Normalized',
                                            tensor2array(disp_to_show, max_value=None, colormap='bone'),
                                            epoch)

            output_writers[i].add_image('val Dispnet Output Normalized',
                                        tensor2array(output_disp[0], max_value=None, colormap='bone'),
                                        epoch)
            output_writers[i].add_image('val Depth Output',
                                        tensor2array(output_depth[0], max_value=3),
                                        epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    return errors.avg, error_names
Example #11
def train(args, train_loader, disvo, optimizer, epoch_size, logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    disvo.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (img_ref, img_tar, poses_gt) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        img_ref = img_ref.to(device)
        img_tar = img_tar.to(device)

        # compute output
        _, poses_pred = disvo(img_ref, img_tar)

        # squared pose error weighted by learned uncertainty (the last 6 outputs
        # are per-component log-variances), plus a log-variance penalty
        loss = torch.sum((poses_pred[:, :6] - poses_gt) ** 2 * torch.exp(-poses_pred[:, 6:])
                         + poses_pred[:, 6:])

        if log_losses:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            # log the current training pair
            train_writer.add_image('train Input ref', tensor2array(img_ref[0]), n_iter)
            train_writer.add_image('train Input target', tensor2array(img_tar[0]), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path/args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        logger.train_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
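The training loss in Example #11 is a learned-uncertainty weighting: the network outputs six pose components plus six log-variances, squared errors are down-weighted where predicted variance is high, and the additive log-variance term keeps the variances from being inflated arbitrarily. Isolated as a self-contained sketch:

import torch

def uncertainty_pose_loss(poses_pred: torch.Tensor, poses_gt: torch.Tensor) -> torch.Tensor:
    """poses_pred: Bx12 (6 pose components + 6 log-variances); poses_gt: Bx6."""
    pose, log_var = poses_pred[:, :6], poses_pred[:, 6:]
    return torch.sum((pose - poses_gt) ** 2 * torch.exp(-log_var) + log_var)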
Example #12
def PhotoMask_Output(explainability_mask, disp_net, intrinsics, j, poses,
                     ref_imgs, save_dir):
    global tgt_img
    intrinsics = torch.tensor(intrinsics).unsqueeze(0)
    intrinsics = intrinsics.to(device)
    disp = disp_net(tgt_img)
    depth = 1 / disp
    ref_depths = []
    for ref_img in ref_imgs:
        ref_disparities = disp_net(ref_img)
        ref_depth = 1 / ref_disparities
        ref_depths.append(ref_depth)
    from loss_functions2 import photometric_reconstruction_and_depth_diff_loss
    reconstruction_loss, depth_diff_loss, warped_imgs, diff_maps, weighted_masks = photometric_reconstruction_and_depth_diff_loss(
        tgt_img,
        ref_imgs,
        intrinsics,
        depth,
        ref_depths,
        explainability_mask,
        poses,
        'euler',
        'zeros',
        isTrain=False)
    im_path = save_dir + '/seq_{}/'.format(j)
    if not os.path.exists(im_path):
        os.makedirs(im_path)
    # save tgt_img
    tgt_img = tensor2array(tgt_img[0]) * 255
    tgt_img = tgt_img.transpose(1, 2, 0)
    img = Image.fromarray(np.uint8(tgt_img)).convert('RGB')
    img.save(im_path + 'tgt.jpg')

    for i in range(len(warped_imgs[0])):
        warped_img = tensor2array(warped_imgs[0][i]) * 255
        warped_img = warped_img.transpose(1, 2, 0)
        img = Image.fromarray(np.uint8(warped_img)).convert('RGB')
        img.save(im_path + 'src_{}.jpg'.format(i))
    for i in range(len(weighted_masks[0])):
        weighted_mask = weighted_masks[0][i].cpu().clone().numpy() * 255
        img = Image.fromarray(weighted_mask)
        img = img.convert('L')
        img.save(im_path + 'photomask_{}.jpg'.format(i))
Example #13
def create_adversarial(file_in, file_bim_out, file_bim_entropy_out,
                       file_cw_out, file_cw_entropy_out, label_adv, label_true,
                       model):
    # Load original image
    img_original = utils.open_image_as_tensor(file_in)
    _img_original = utils.tensor2array(img_original)
    #_img_original = utils.open_image_properly(file_in, arch='inception')

    labels_adv = []
    while len(labels_adv) < 4:  # if many different labels have failed, skip this sample
        # Try a different label
        if len(labels_adv) > 0:
            label_adv = random.randint(0, NUM_CLASSES - 1)  # randint is inclusive on both ends
            while label_adv == label_true or label_adv in labels_adv:
                label_adv = random.randint(0, NUM_CLASSES - 1)

            labels_adv.append(label_adv)
        else:
            labels_adv = [label_adv]

        # Perform adversarial attack
        img_bim = attack(_img_original,
                         model,
                         "BIM",
                         label_adv,
                         label_true,
                         entropy_masking=False)
        if img_bim is None:
            continue
        img_bim_entropy = attack(_img_original,
                                 model,
                                 "BIM",
                                 label_adv,
                                 label_true,
                                 entropy_masking=True)
        if img_bim_entropy is None:
            continue

        # Save adversarial images
        skimage.io.imsave(file_bim_out, img_bim)
        skimage.io.imsave(file_bim_entropy_out, img_bim_entropy)

        break
Example #14
def train(train_loader, FCCM_net, optimizer, epoch_size,  train_writer=None):
    global args, n_iter
    average_loss = 0
    FCCM_net.train()

    for i, (CT_img, ground_truth) in enumerate(tqdm(train_loader)):
        
        CT_img = CT_img[0].cuda()  # Variable() is a no-op since PyTorch 0.4

        ground_truth = ground_truth.float().cuda().squeeze(1)  # already a tensor; no torch.tensor() copy needed
        if args.normalization == 'max':
            CT_img = max_normalize(CT_img)
        elif args.normalization == 'mean':
            CT_img = mean_normalize(CT_img)

        predict_result = FCCM_net(CT_img)
        # print(predict_result.size())
        #print(ground_truth.size())
        # print(ground_truth[:, -1],predict_result[:,-1])
        classification_loss = nn.functional.binary_cross_entropy(torch.sigmoid(predict_result), ground_truth[:, :-1])
        # regression_loss = ((predict_result[:,-1] - ground_truth[:, -1]) **2).mean()

        loss = classification_loss # + regression_loss
        

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        average_loss += loss.item()

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('classification_loss', classification_loss.item(), n_iter)
            # train_writer.add_scalar('regression_loss', regression_loss.item(), n_iter)
        if n_iter % (args.print_freq * 40) == 0:
            train_writer.add_image('Input image',
                                    tensor2array(CT_img.data[0].cpu(), max_value=None, colormap='bone'),
                                    n_iter)
        n_iter+=1
    return average_loss / (i + 1)  # average over the number of batches (i is zero-based)
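Example #14 applies torch.sigmoid and then binary_cross_entropy; the numerically safer idiom fuses the two. A sketch assuming the same prediction/target shapes as above:

import torch
import torch.nn.functional as F

def classification_loss(predict_result: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
    # equivalent to binary_cross_entropy(torch.sigmoid(x), y), but stable for large |x|
    return F.binary_cross_entropy_with_logits(predict_result, ground_truth[:, :-1])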
Example #15
def log_result(depthmap, GT, input_batch, selected_index, folder, prefix):
    pred_depth_t = torch.from_numpy(depthmap)
    to_save = tensor2array(pred_depth_t, max_value=100)
    gt_to_save = tensor2array(torch.from_numpy(GT), max_value=100)
    prefix = folder/prefix
    scipy.misc.imsave('{}_depth_pred.jpg'.format(prefix), to_save)
    scipy.misc.imsave('{}_depth_gt.jpg'.format(prefix), gt_to_save)
    to_save = tensor2array(input_batch.cpu().data[selected_index,:3])
    scipy.misc.imsave('{}_input0.jpg'.format(prefix), to_save)
    to_save = tensor2array(input_batch.cpu().data[selected_index,3:])
    scipy.misc.imsave('{}_input1.jpg'.format(prefix), to_save)
    for i, batch_elem in enumerate(input_batch.cpu().data):
        to_save = tensor2array(batch_elem[:3])
        scipy.misc.imsave('{}_batch_{}_0.jpg'.format(prefix, i), to_save)
        to_save = tensor2array(batch_elem[3:])
        scipy.misc.imsave('{}_batch_{}_1.jpg'.format(prefix, i), to_save)
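Example #16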
def validate_without_gt(args, val_loader, disp_net, pose_net,
                        epoch, logger, tb_writer, w1, w3, sample_nb_to_log=2):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = sample_nb_to_log > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.start()
    logger.valid_bar.update(0)
    for i, validdata in enumerate(val_loader):
        tgt_lf = validdata['tgt_lf'].to(device)
        ref_lfs = [ref.to(device) for ref in validdata['ref_lfs']]

        if args.lfformat == "epi" and args.cameras_epi == "full":
            tgt_lf_formatted_h = validdata['tgt_lf_formatted_h'].to(device)
            tgt_lf_formatted_v = validdata['tgt_lf_formatted_v'].to(device)
            ref_lfs_formatted_h = [lf.to(device) for lf in validdata['ref_lfs_formatted_h']]
            ref_lfs_formatted_v = [lf.to(device) for lf in validdata['ref_lfs_formatted_v']]

            tgt_stack = validdata['tgt_stack'].to(device)
            ref_stacks = [lf.to(device) for lf in validdata['ref_stacks']]

            # Encode the epi images further
            if args.without_disp_stack:
                # Stacked images should not be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, None, tgt_lf_formatted_h)
            else:
                # Stacked images should be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, tgt_stack, tgt_lf_formatted_h)

            tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted_v, tgt_stack,
                                                                  ref_lfs_formatted_v, ref_stacks,
                                                                  tgt_lf_formatted_h, ref_lfs_formatted_h)
        else:
            tgt_lf_formatted = validdata['tgt_lf_formatted'].to(device)
            ref_lfs_formatted = [lf.to(device) for lf in validdata['ref_lfs_formatted']]

            # Encode the images if necessary
            if disp_net.has_encoder():
                # This will only be called for epi with horizontal or vertical only encoding
                if args.without_disp_stack:
                    # Stacked images should not be concatenated with the encoded EPI images
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, None)
                else:
                    # NOTE: Here we stack all 17 images, not 5. Here the images missing from the encoding,
                    # are covered in the stack. We are not using this case in the paper at all.
                    # Stacked images should be concatenated with the encoded EPI images
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, tgt_lf)
            else:
                # This will be called for focal stack and stack, where there is no encoding
                tgt_lf_encoded_d = tgt_lf_formatted

            if pose_net.has_encoder():
                tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted, tgt_lf, ref_lfs_formatted, ref_lfs)
            else:
                tgt_lf_encoded_p = tgt_lf_formatted
                ref_lfs_encoded_p = ref_lfs_formatted
        
        # compute output
        disp = disp_net(tgt_lf_encoded_d)
        depth = 1/disp
        pose = pose_net(tgt_lf_encoded_p, ref_lfs_encoded_p)

        # compute photometric error
        intrinsics = validdata['intrinsics'].to(device)
        pose_gt_tgt_refs = validdata['pose_gt_tgt_refs'].to(device)
        metadata = validdata['metadata']
        photometric_error, warped, diff = multiwarp_photometric_loss(
            tgt_lf, ref_lfs, intrinsics, depth, pose, metadata, args.rotation_mode, args.padding_mode
        )

        photometric_error = photometric_error.item()                      # Photometric loss
        # smoothness_error = smooth_loss(depth).item()                      # Smoothness loss
        smoothness_error = total_variation_loss(depth, sum_or_mean="mean").item()   # Total variation loss
        # smoothness_error = total_variation_squared_loss(depth).item()             # Total variation loss squared version
        mean_distance_error, mean_angle_error = pose_loss(pose, pose_gt_tgt_refs)  # Pose loss
        mean_distance_error = mean_distance_error.item()  # .item() must be taken per tensor,
        mean_angle_error = mean_angle_error.item()        # not on the returned tuple

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            if args.lfformat == "epi" and args.cameras_epi == "full":
                b, n, h, w = tgt_lf_formatted_v.shape
                vis_img = tgt_lf_formatted_v[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            else:
                b, n, h, w = tgt_lf_formatted.shape
                vis_img = tgt_lf_formatted[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5

            vis_depth = tensor2array(depth[0, 0, :, :], colormap='magma')
            vis_disp = tensor2array(disp[0, 0, :, :], colormap='magma')

            tb_writer.add_image('val/target_image', vis_img, epoch)
            tb_writer.add_image('val/disp', vis_disp, epoch)
            tb_writer.add_image('val/depth', vis_depth, epoch)

        loss = w1 + torch.exp(-1.0 * w1) * photometric_error + w3 + torch.exp(-1.0 * w3) * smoothness_error
        losses.update([loss, photometric_error, mean_distance_error, mean_angle_error])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))

    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['val/total_loss', 'val/photometric_error', 'val/pose_distance_error', 'val/pose_angle_error']
Example #17
def sample_from_dataset(num_per_label=10,
                        num_different_labels=1000,
                        data_in_path=None,
                        data_out_path=None,
                        bad_images=None,
                        entropy_in=None):
    # Read labels
    file_labels = os.path.join(data_in_path, 'val.txt')
    labels_true = []
    with open(file_labels, mode='r') as file_in:
        labels_true = file_in.readlines()
        labels_true = [int(y.strip().split()[1]) for y in labels_true]

    # Enumerate all image files
    files = list_files(data_in_path)
    files.sort()

    # Select samples
    classes = random.sample(range(0, len(np.unique(labels_true))),
                            num_different_labels)

    df = pandas.DataFrame(list(zip(files, labels_true)),
                          columns=['file', 'label'])

    df_entropy = pandas.read_csv(entropy_in)  # Add mean entropy of each image
    df["entropy"] = df_entropy["entropy"]

    df_final = pandas.DataFrame(columns=['file', 'label', 'entropy'])

    if bad_images is not None:  # Remove "bad images" from the dataframe
        with open(bad_images, 'r') as f_in:
            bad_files = f_in.readlines()
            bad_files = [
                data_in_path + "/" + os.path.basename(i.strip())
                for i in bad_files
            ]
            df = df[~df['file'].isin(bad_files)]

    df = df.sample(frac=1,
                   random_state=42).reset_index(drop=True)  # Shuffle rows

    for y in classes:
        df_final = pandas.concat([df_final, df[df['label'] == y].head(num_per_label)],
                                 ignore_index=True)  # pandas.concat replaces the removed DataFrame.append

    # Create data.csv file
    with open(os.path.join(data_out_path, 'data.csv'),
              mode='w') as file_out_csv:
        header = [
            'original', 'original2', 'bim', 'bim_entropy', 'cw', 'cw_entropy',
            'label_true', 'label_adv', 'entropy'
        ]
        writer = csv.DictWriter(file_out_csv, fieldnames=header)
        writer.writeheader()

        for _, row in df_final.iterrows():
            # Load original image and apply some preprocessing (e.g. resizing)
            file_img = row.file
            label_true = row.label
            img_original = utils.open_image_as_tensor(file_img)
            _img_original = utils.tensor2array(img_original)
            #_img_original = utils.open_image_properly(file_img, arch='inception')

            # Generate random file names
            file_name = os.path.splitext(os.path.basename(file_img))[0]
            x = random.randint(42, 4242)
            file_original_out = os.path.join(data_out_path,
                                             file_name + str(x) + '.png')
            file_original2_out = os.path.join(data_out_path,
                                              file_name + str(x + 2) + '.png')
            file_bim_out = os.path.join(data_out_path,
                                        file_name + str(x - 1) + '.png')
            file_bim_entropy_out = os.path.join(
                data_out_path, file_name + str(x + 1) + '.png')
            file_cw_out = os.path.join(data_out_path,
                                       file_name + str(x + 3) + '.png')
            file_cw_entropy_out = os.path.join(data_out_path,
                                               file_name + str(x + 4) + '.png')

            # Save image
            skimage.io.imsave(file_original_out, _img_original)
            skimage.io.imsave(file_original2_out, _img_original)

            # Generate random target label
            label_adv = random.randint(0, NUM_CLASSES - 1)  # randint is inclusive on both ends
            while label_adv == label_true:
                label_adv = random.randint(0, NUM_CLASSES - 1)

            # Create new entry in data.csv
            writer.writerow({
                'original': file_original_out,
                'original2': file_original2_out,
                'bim': file_bim_out,
                'bim_entropy': file_bim_entropy_out,
                'cw': file_cw_out,
                'cw_entropy': file_cw_entropy_out,
                'label_true': label_true,
                'label_adv': label_adv,
                'entropy': row.entropy
            })
Example #18
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(
                output_writers):  # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j),
                                                tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff,
                                   explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
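Example #19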
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size, logger, tb_writer, w1, w3):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # Set the networks to training mode, batch norm and dropout are handled accordingly
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.start()
    logger.train_bar.update(0)

    for i, trainingdata in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_lf = trainingdata['tgt_lf'].to(device)
        ref_lfs = [img.to(device) for img in trainingdata['ref_lfs']]

        if args.lfformat == "epi" and args.cameras_epi == "full":
            # in this case we have separate horizontal and vertical epis
            tgt_lf_formatted_h = trainingdata['tgt_lf_formatted_h'].to(device)
            tgt_lf_formatted_v = trainingdata['tgt_lf_formatted_v'].to(device)
            ref_lfs_formatted_h = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_h']]
            ref_lfs_formatted_v = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_v']]

            # stacked images
            tgt_stack = trainingdata['tgt_stack'].to(device)
            ref_stacks = [lf.to(device) for lf in trainingdata['ref_stacks']]

            # Encode the epi images further
            if args.without_disp_stack:
                # Stacked images should not be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, None, tgt_lf_formatted_h)
            else:
                # Stacked images should be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, tgt_stack, tgt_lf_formatted_h)

            tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted_v, tgt_stack,
                                                                  ref_lfs_formatted_v, ref_stacks,
                                                                  tgt_lf_formatted_h, ref_lfs_formatted_h)
        else:
            tgt_lf_formatted = trainingdata['tgt_lf_formatted'].to(device)
            ref_lfs_formatted = [lf.to(device) for lf in trainingdata['ref_lfs_formatted']]

            # Encode the images if necessary
            if disp_net.has_encoder():
                # This will only be called for epi with horizontal or vertical only encoding
                if args.without_disp_stack:
                    # Stacked images should not be concatenated with the encoded EPI images
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, None)
                else:
                    # Stacked images should be concatenated with the encoded EPI images
                    # NOTE: Here we stack all 17 images, not 5. Here the images missing from the encoding,
                    # are covered in the stack. We are not using this case in the paper at all.
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, tgt_lf)
            else:
                # This will be called for focal stack and stack, where there is no encoding
                tgt_lf_encoded_d = tgt_lf_formatted

            if pose_net.has_encoder():
                tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted, tgt_lf,
                                                                      ref_lfs_formatted, ref_lfs)
            else:
                tgt_lf_encoded_p = tgt_lf_formatted
                ref_lfs_encoded_p = ref_lfs_formatted

        # compute output of networks
        disparities = disp_net(tgt_lf_encoded_d)
        depth = [1/disp for disp in disparities]
        pose = pose_net(tgt_lf_encoded_p, ref_lfs_encoded_p)

        # if i==0:
        #     tb_writer.add_graph(disp_net, tgt_lf_encoded_d)
        #     tb_writer.add_graph(pose_net, (tgt_lf_encoded_p, ref_lfs_encoded_p))

        # compute photometric error
        intrinsics = trainingdata['intrinsics'].to(device)
        pose_gt_tgt_refs = trainingdata['pose_gt_tgt_refs'].to(device)
        metadata = trainingdata['metadata']
        photometric_error, warped, diff = multiwarp_photometric_loss(
            tgt_lf, ref_lfs, intrinsics, depth, pose, metadata, args.rotation_mode, args.padding_mode
        )

        # smoothness_error = smooth_loss(depth)                             # smoothness error
        smoothness_error = total_variation_loss(depth, sum_or_mean="mean")  # total variation error
        # smoothness_error = total_variation_squared_loss(depth)            # total variation error squared version
        mean_distance_error, mean_angle_error = pose_loss(pose, pose_gt_tgt_refs)

        loss = w1 + torch.exp(-1.0 * w1) * photometric_error + w3 + torch.exp(-1.0 * w3) * smoothness_error

        if log_losses:
            tb_writer.add_scalar(tag='train/photometric_error', scalar_value=photometric_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/smoothness_loss', scalar_value=smoothness_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/total_loss', scalar_value=loss.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_distance_error', scalar_value=mean_distance_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_angle_error', scalar_value=mean_angle_error.item(), global_step=n_iter)
        if log_output:
            if args.lfformat == "epi" and args.cameras_epi == "full":
                b, n, h, w = tgt_lf_formatted_v.shape
                vis_img = tgt_lf_formatted_v[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            else:
                b, n, h, w = tgt_lf_formatted.shape
                vis_img = tgt_lf_formatted[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5

            b, n, h, w = depth[0].shape
            vis_depth = tensor2array(depth[0][0, 0, :, :], colormap='magma')
            vis_disp = tensor2array(disparities[0][0, 0, :, :], colormap='magma')
            vis_enc_f = tgt_lf_encoded_d[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            vis_enc_b = tgt_lf_encoded_d[0, -1, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5

            tb_writer.add_image('train/input', vis_img, n_iter)
            tb_writer.add_image('train/encoded_front', vis_enc_f, n_iter)
            tb_writer.add_image('train/encoded_back', vis_enc_b, n_iter)
            tb_writer.add_image('train/depth', vis_depth, n_iter)
            tb_writer.add_image('train/disp', vis_disp, n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path + "/" + args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), photometric_error.item(), smoothness_error.item(),
                             mean_distance_error.item(), mean_angle_error.item()])
        logger.train_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    logger.train_bar.finish()
    return losses.avg[0]
Example #20
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask,
                                                 pose, args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explainability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)

            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0],
                                     max_value=None,
                                     colormap='magma'), n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h

                    tgt_img_scaled = F.interpolate(tgt_img, (h, w),
                                                   mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]

                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                        dim=1)

                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref,
                            scaled_depth[:, 0],
                            pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]
                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1,
                                             colormap='bone'), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
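
A minimal sketch of what the explainability_loss referenced above might look like, assuming the SfmLearner-style formulation (a binary cross-entropy term that pulls every predicted mask value toward 1, so the mask only drops where doing so pays for itself in photometric error). This is an illustration, not necessarily the repository's exact implementation:

import torch
import torch.nn.functional as F

def explainability_loss(mask):
    # the network may predict one mask per output scale
    if type(mask) not in (tuple, list):
        mask = [mask]
    loss = 0
    for mask_scaled in mask:
        # regularize every mask value toward 1 with binary cross-entropy
        loss += F.binary_cross_entropy(mask_scaled, torch.ones_like(mask_scaled))
    return loss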
Exemplo n.º 21
0
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = inverse_depth_smooth_loss(depth, tgt_img)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explainability_loss', loss_2.item(),
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter,
                                       *scaled_maps)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
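
For reference, a minimal single-scale sketch of an edge-aware smoothness term like the inverse_depth_smooth_loss used above, assuming the common formulation that down-weights depth gradients wherever the image itself has strong gradients (illustrative only):

import torch

def edge_aware_smooth_loss(depth, img):
    # depth: (B, 1, H, W), img: (B, 3, H, W)
    d_dx = torch.abs(depth[:, :, :, :-1] - depth[:, :, :, 1:])
    d_dy = torch.abs(depth[:, :, :-1, :] - depth[:, :, 1:, :])
    i_dx = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    i_dy = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
    # suppress the smoothness penalty across image edges
    return (d_dx * torch.exp(-i_dx)).mean() + (d_dy * torch.exp(-i_dy)).mean()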
Exemplo n.º 22
0
def main():
    args = parser.parse_args()
    output_dir = Path(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             sequence_length=args.sequence_length)

    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    dpsnet = PSNet(args.nlabel, args.mindepth).cuda()
    weights = torch.load(args.pretrained_dps)
    dpsnet.load_state_dict(weights['state_dict'])
    dpsnet.eval()

    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    errors = np.zeros((2, 8, int(len(val_loader) / args.print_freq) + 1),
                      np.float32)
    with torch.no_grad():
        for ii, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv,
                 tgt_depth, scale_) in enumerate(val_loader):
            if ii % args.print_freq == 0:
                i = int(ii / args.print_freq)
                tgt_img_var = Variable(tgt_img.cuda())
                ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
                ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
                intrinsics_var = Variable(intrinsics.cuda())
                intrinsics_inv_var = Variable(intrinsics_inv.cuda())
                tgt_depth_var = Variable(tgt_depth.cuda())
                scale = scale_.numpy()[0]

                # compute output
                pose = torch.cat(ref_poses_var, 1)
                start = time.time()
                output_depth = dpsnet(tgt_img_var, ref_imgs_var, pose,
                                      intrinsics_var, intrinsics_inv_var)
                elps = time.time() - start
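                # validity mask; the tgt_depth == tgt_depth term filters out NaNs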
                mask = (tgt_depth <= args.maxdepth) & (
                    tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)

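                # depth -> plane-sweep disparity (nlabel planes, nearest plane at mindepth)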
                tgt_disp = args.mindepth * args.nlabel / tgt_depth
                output_disp = args.mindepth * args.nlabel / output_depth

                output_disp_ = torch.squeeze(output_disp.data.cpu(), 1)
                output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

                errors[0, :,
                       i] = compute_errors_test(tgt_depth[mask] / scale,
                                                output_depth_[mask] / scale)
                errors[1, :,
                       i] = compute_errors_test(tgt_disp[mask] / scale,
                                                output_disp_[mask] / scale)

                print('Elapsed Time {} Abs Error {:.4f}'.format(
                    elps, errors[0, 0, i]))

                if args.output_print:
                    output_disp_n = (output_disp_).numpy()[0]
                    np.save(output_dir / '{:04d}{}'.format(i, '.npy'),
                            output_disp_n)
                    disp = (255 * tensor2array(torch.from_numpy(output_disp_n),
                                               max_value=args.nlabel,
                                               colormap='bone')).astype(
                                                   np.uint8)
                    imsave(output_dir / '{:04d}_disp{}'.format(i, '.png'),
                           disp)

    mean_errors = errors.mean(2)
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3'
    ]
    print("{}".format(args.output_dir))
    print("Depth Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".
          format(*error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*mean_errors[0]))

    print("Disparity Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".
          format(*error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*mean_errors[1]))

    np.savetxt(output_dir / 'errors.csv',
               mean_errors,
               fmt='%1.4f',
               delimiter=',')
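
A hedged sketch of the eight metrics compute_errors_test above is expected to return, in the printed order (standard depth-evaluation formulas; the real helper may differ in details such as clamping or masking):

import torch

def compute_errors_test_sketch(gt, pred):
    # gt, pred: 1-D tensors of valid, positive depth values
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()
    abs_rel = torch.mean(torch.abs(gt - pred) / gt)
    abs_diff = torch.mean(torch.abs(gt - pred))
    sq_rel = torch.mean((gt - pred) ** 2 / gt)
    rms = torch.sqrt(torch.mean((gt - pred) ** 2))
    log_rms = torch.sqrt(torch.mean((torch.log(gt) - torch.log(pred)) ** 2))
    return abs_rel, abs_diff, sq_rel, rms, log_rms, a1, a2, a3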
Exemplo n.º 23
0
def main():
    global args
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = Path('checkpoints') / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomRotate(),
        # custom_transforms.RandomHorizontalFlip(),
        # custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array([542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0, 1]).astype(np.float32).reshape((3, 3))
    
    inference_set = SequenceFolder(
        root = args.dataset_dir,
        intrinsics = intrinsics,
        transform=train_transform,
        train=False,
        sequence_length=args.sequence_length
    )

    print('{} samples found in {} train scenes'.format(len(inference_set), len(inference_set.scenes)))
    inference_loader = torch.utils.data.DataLoader(
        inference_set, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    mask_net = MaskResNet6.MaskResNet6().cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()
    mask_net = torch.nn.DataParallel(mask_net)

    masknet_weights = torch.load(args.pretrained_mask)
    posenet_weights = torch.load(args.pretrained_pose)
    mask_net.load_state_dict(masknet_weights['state_dict'])
    # pose_net.load_state_dict(posenet_weights['state_dict'])
    pose_net.eval()
    mask_net.eval()

    # inference

    for i, (rgb_tgt_img, rgb_ref_imgs, intrinsics, intrinsics_inv) in enumerate(tqdm(inference_loader)):
        #print(rgb_tgt_img)
        tgt_img_var = Variable(rgb_tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in rgb_ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        
        after_mask = tensor2array(ref_imgs_var[0][0]*explainability_mask[0,0]).transpose(1,2,0)
        x = Image.fromarray(np.uint8(after_mask*255))
        x.save(args.save_path/str(i).zfill(3)+'multi.png')
        
        explainability_mask = (explainability_mask[0,0].detach().cpu()).numpy()
        # print(explainability_mask.shape)
        y = Image.fromarray(np.uint8(explainability_mask*255))
        y.save(args.save_path/str(i).zfill(3)+'mask.png')
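
Note: Variable(..., volatile=True) as used above is pre-0.4 PyTorch; on current versions volatile has no effect and inference should instead run under torch.no_grad(), e.g.:

with torch.no_grad():
    explainability_mask = mask_net(tgt_img_var, ref_imgs_var)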
Exemplo n.º 24
0
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must at least output one value !')
        return

    argoverse_loader = ArgoverseTrackingLoader(args.argoverse_data_path)
    camera = argoverse_loader.CAMERA_LIST[0]
    argoverse_data = argoverse_loader.get(args.argoverse_log)
    num_lidar = len(argoverse_data.get_image_list_sync(camera))

    disp_net = DispResNet(args.resnet_layers, False).to(device)
    weights = torch.load(args.pretrained, map_location=device)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    #dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # if args.dataset_list is not None:
    #     with open(args.dataset_list, 'r') as f:
    #         test_files = [dataset_dir/file for file in f.read().splitlines()]
    # else:
    #     test_files = sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])

    # print('{} files to test'.format(len(test_files)))

    for frame in tqdm(range(0, num_lidar - 1)):

        file = argoverse_data.get_image_sync(frame, camera, load=False)

        img = imread(file).astype(np.float32)

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            img = imresize(img, (args.img_height, args.img_width)).astype(
                np.float32)
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img).unsqueeze(0)
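        # scale to [0, 1], then normalize with mean 0.45 and std 0.225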
        tensor_img = ((tensor_img / 255 - 0.45) / 0.225).to(device)

        output = disp_net(tensor_img)[0]

        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        file_name = '-'.join(file_path.splitall())

        if args.output_disp:
            disp = (255 * tensor2array(output, max_value=None,
                                       colormap='bone')).astype(np.uint8)
            imsave(output_dir / '{}_disp{}'.format(file_name, file_ext),
                   np.transpose(disp, (1, 2, 0)))
        if args.output_depth:
            depth = 1 / output
            depth = (255 * tensor2array(
                depth, max_value=10, colormap='rainbow')).astype(np.uint8)
            imsave(output_dir / '{}_depth{}'.format(file_name, file_ext),
                   np.transpose(depth, (1, 2, 0)))
Exemplo n.º 25
0
def train(train_loader, mask_net, pose_net, optimizer, epoch_size,
          train_writer):
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.pose_loss_weight

    mask_net.train()
    pose_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            mask_tgt_img, mask_ref_imgs, intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        mask_tgt_img_var = Variable(mask_tgt_img.cuda())
        mask_ref_imgs_var = [Variable(img.cuda()) for img in mask_ref_imgs]

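        # binarize the masks: positive labels become 1, everything else 0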
        mask_tgt_img_var = torch.where(mask_tgt_img_var > 0,
                                       torch.ones_like(mask_tgt_img_var),
                                       torch.zeros_like(mask_tgt_img_var))
        mask_ref_imgs_var = [
            torch.where(img > 0, torch.ones_like(img), torch.zeros_like(img))
            for img in mask_ref_imgs_var
        ]

        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # print(explainability_mask[0].size()) #torch.Size([4, 2, 384, 512])
        # print()
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3 consensus loss (the mask from networks and the mask from residual)
        loss3 = consensus_loss(explainability_mask[0], mask_ref_imgs_var)

        # loss 4 pose loss
        valid_pixel_mask = [
            torch.where(depth_ref_imgs_var[0] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var)),
            torch.where(depth_ref_imgs_var[1] == 0,
                        torch.zeros_like(depth_tgt_img_var),
                        torch.ones_like(depth_tgt_img_var))
        ]  # zero is invalid

        loss4, ref_img_warped, diff = pose_loss(
            valid_pixel_mask, mask_ref_imgs_var, rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('pose loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train gt mask ',
                tensor2array(mask_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train diff', tensor2array(diff[0]), n_iter)
            train_writer.add_image('train warped img',
                                   tensor2array(ref_img_warped[0]), n_iter)

        n_iter += 1

    return average_loss / (i + 1)
Exemplo n.º 26
0
def main():
    args = parser.parse_args()
    if not(args.output_disp or args.output_depth):
        # print("args.output_disp:\n", args.output_disp)
        # print("args.output_depth:\n", args.output_depth)
        print('You must at least output one value !')
        return

    disp_net = DispNetS().to(device)
    weights = torch.load(args.pretrained)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()
    print("dataset_list:\n", args.dataset_list)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir/file for file in f.read().splitlines()]
    else:
        print("Else!")
        test_files = sum([list(dataset_dir.walkfiles('*.{}'.format(ext))) for ext in args.img_exts], [])
    print(dataset_dir)
    print("dataset_list:\n", args.dataset_list)
    print("test_files:\n", test_files)
    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):
        # print("file:\n", file)
        img = imread(file)

        h,w,_ = img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            img = np.array(Image.fromarray(img).resize((args.img_width, args.img_height)))
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img.astype(np.float32)).unsqueeze(0)
        tensor_img = ((tensor_img/255 - 0.5)/0.5).to(device)

        output = disp_net(tensor_img)[0]
        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        print(file_path)
        print(file_path.splitall())
        file_name = '-'.join(file_path.splitall()[1:])
        print(file_name)

        if args.output_disp:
            disp = (255*tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
            # imsave(output_dir/'{}_disp{}'.format(file_name, file_ext), np.transpose(disp, (1,2,0)))
        if args.output_depth:
            depth = 1/output
            # depth = (255*tensor2array(depth, max_value=10, colormap='rainbow')).astype(np.uint8)
            # depth = (2550*tensor2array(depth, max_value=10, colormap='bone')).astype(np.uint8)
            # imsave(output_dir/'{}_depth{}'.format(file_name, file_ext), np.transpose(depth, (1,2,0)))
            pred = depth.to(device)

            gt = tifffile.imread('/home/zyd/respository/sfmlearner_results/endo_testset/left_depth_map_d4k1_000000.tiff')
            gt = gt[:, :, 2]
            gt = torch.from_numpy(gt.astype(np.float32)).to(device)

            # crop away the invalid border region of the 1024x1280 frames
            crop_mask = torch.zeros_like(gt, dtype=torch.bool)
            y1, y2 = int(0.40810811 * 1024), int(0.99189189 * 1024)
            x1, x2 = int(0.03594771 * 1280), int(0.96405229 * 1280)
            crop_mask[y1:y2, x1:x2] = True

            # assumes the prediction and the GT share the same resolution
            valid = (gt > 0) & (gt < 80) & crop_mask
            valid_gt = gt[valid]
            valid_pred = pred[0][valid].clamp(1e-3, 80)

            # scale factor determined by GT/prediction median ratio
            valid_pred = valid_pred * torch.median(valid_gt)/torch.median(valid_pred)

            thresh = torch.max((valid_gt / valid_pred), (valid_pred / valid_gt))
            a1 = (thresh < 1.25).float().mean()
            a2 = (thresh < 1.25 ** 2).float().mean()
            a3 = (thresh < 1.25 ** 3).float().mean()

            abs_diff = torch.mean(torch.abs(valid_gt - valid_pred))
            abs_rel = torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt)
            sq_rel = torch.mean(((valid_gt - valid_pred)**2) / valid_gt)
            rms = torch.sqrt(torch.mean((valid_gt - valid_pred)**2))
            log_rms = torch.sqrt(torch.mean((torch.log(valid_gt) - torch.log(valid_pred))**2))
            abs_log = torch.mean(torch.abs(torch.log(valid_gt) - torch.log(valid_pred)))

            error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'rms', 'log_rms', 'abs_log', 'a1', 'a2', 'a3']
            errors = [e.item() for e in [abs_diff, abs_rel, sq_rel, rms, log_rms, abs_log, a1, a2, a3]]

            print("Results with scale factor determined by GT/prediction ratio (like the original paper) : ")
            print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
            print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*errors))
Exemplo n.º 27
0
def main():
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(
            dim=1).unsqueeze(1).sqrt()  #.normalize()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max(
        )
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

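        # soft logical OR of the two masks: 1 - (1 - a) * (1 - b)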
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())
            # rigidity_mask_dir rigidity_mask.numpy()
            # rigidity_census_mask_dir rigidity_mask_census.numpy()

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            ####### changed the two vstack calls to hstack ###############
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            ######## additional code ####################
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
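
The IoU summary above follows the usual tp / (tp + fp + fn) per class. Below is a minimal, hypothetical sketch of a mask_error-style helper producing the six counts consumed by the AverageMeters (simplified: it ignores the semantic-map argument the real helper receives):

import numpy as np

def mask_error_counts(gt_mask, pred_mask):
    # background is class 0, foreground (moving objects) is class 1
    gt_fg = gt_mask > 0
    pred_fg = pred_mask > 0.5
    tp_0 = np.sum(~gt_fg & ~pred_fg)
    fp_0 = np.sum(gt_fg & ~pred_fg)
    fn_0 = np.sum(~gt_fg & pred_fg)
    tp_1 = np.sum(gt_fg & pred_fg)
    fp_1 = np.sum(~gt_fg & pred_fg)
    fn_1 = np.sum(gt_fg & ~pred_fg)
    return tp_0, fp_0, fn_0, tp_1, fp_1, fn_1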
Exemplo n.º 28
0
def validate(val_loader,
             disp_net,
             pose_exp_net,
             epoch,
             logger,
             output_writers=[]):
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i // 100 < len(
                output_writers):  # log the first output of every 100 batches
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1])[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)

    return losses.avg
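
Note: the .data[0] reads in this and the next example are pre-0.4 PyTorch; on current versions a scalar loss value is extracted with .item(), e.g. loss_1 = loss_1.item().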
Exemplo n.º 29
0
def validate_without_gt(args,
                        val_loader,
                        disp_net,
                        pose_exp_net,
                        epoch,
                        logger,
                        output_writers=[]):
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i // 100 < len(
                output_writers):  # log the first output of every 100 batches
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
Exemplo n.º 30
0
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
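
For context, the epe_* numbers printed above follow the standard end-point-error definition. A minimal sketch (hypothetical helper, assuming (B, 2, H, W) flow tensors and a boolean validity mask):

import torch

def end_point_error(flow_gt, flow_pred, valid):
    # Euclidean distance between ground-truth and predicted flow vectors,
    # averaged over the valid pixels
    diff = flow_gt - flow_pred
    epe_map = torch.sqrt(diff[:, 0] ** 2 + diff[:, 1] ** 2)
    return epe_map[valid].mean()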