Code Example #1
File: student_code.py Project: binli123/cs838_hw2
def get_train_transforms(normalize):
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  return train_transforms
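
Note: Scale, RandomColor, RandomRotate and RandomSizedCrop above come from the assignment's own transforms module, not stock torchvision. For comparison, here is a minimal sketch of roughly the same pipeline in current torchvision; mapping the custom ops onto ColorJitter, RandomRotation and RandomResizedCrop is an assumption:

import torchvision.transforms as T

def get_train_transforms_torchvision(normalize):
  # rough stock-torchvision equivalent of the custom pipeline above
  return T.Compose([
      T.Resize(160),                  # ~ transforms.Scale(160)
      T.RandomHorizontalFlip(),
      T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),  # ~ RandomColor(0.15)
      T.RandomRotation(15),           # ~ RandomRotate(15)
      T.RandomResizedCrop(128),       # ~ RandomSizedCrop(128)
      T.ToTensor(),
      normalize,
  ])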
Code Example #2
File: main.py Project: zhuoliny/flowattack
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = Path('checkpoints') / save_path  # optionally add a /timestamp subdirectory
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)

    if args.DEBUG:
        # truncate to 32 samples for quick debugging; slicing train_set.samples
        # is what actually shortens the dataset (assigning to the instance
        # attribute __len__ has no effect on len(train_set))
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # alternative: pwc_net.pth.tar
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model paramters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()

    cudnn.benchmark = True
    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size,
                                                     args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size,
                                               args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net,
                                                    epoch, logger,
                                                    output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
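
init_patch_circle, init_patch_square and init_patch_from_image are helpers defined elsewhere in the flowattack project. As a rough sketch of what the square initializer could look like, assuming a uniform-random patch whose area is a fraction of the image (the signature is taken from the call site above; the body is a guess):

import numpy as np

def init_patch_square(image_size, patch_frac):
    # hypothetical sketch: square patch covering patch_frac of the image
    # area, returned in NCHW layout with values in [0, 1]
    side = int(np.sqrt(image_size ** 2 * patch_frac))
    patch = np.random.rand(1, 3, side, side)
    return patch, patch.shape

The caller then keeps a copy as patch_init and builds mask = np.ones(patch_shape), matching the usage above.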
Code Example #3
File: main.py Project: binli123/cs838_hw2
def main(args):
  # parse args
  best_acc1 = 0.0

  if args.gpu >= 0:
    torch.cuda.set_device(args.gpu)
    print("Use GPU: {}".format(args.gpu))
  else:
    print('You are using CPU for computing!',
          'Yet we assume you are using a GPU.',
          'You will NOT be able to switch between CPU and GPU training!')

  # fix the random seeds (the best we can)
  fixed_random_seed = 2019
  torch.manual_seed(fixed_random_seed)
  np.random.seed(fixed_random_seed)
  random.seed(fixed_random_seed)

  # set up the model + loss
  if args.use_custom_conv:
    print("Using custom convolutions in the network")
    model = default_model(conv_op=CustomConv2d, num_classes=100)
  elif args.use_resnet18:
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(512, 100)
  elif args.use_adv_training:
    model = AdvSimpleNet(num_classes=100)
  else:
    model = default_model(num_classes=100)
  model_arch = "simplenet"
  criterion = nn.CrossEntropyLoss()
  # put everything to gpu
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)

  # setup the optimizer
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      checkpoint = torch.load(args.resume)
      args.start_epoch = checkpoint['epoch']
      best_acc1 = checkpoint['best_acc1']
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      # only load the optimizer if necessary
      if (not args.evaluate) and (not args.attack):
        optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))

  # set up transforms for data augmentation
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
  train_transforms = get_train_transforms(normalize)
  # val transforms
  val_transforms = []
  val_transforms.append(transforms.Scale(160, interpolations=None))
  val_transforms.append(transforms.ToTensor())
  val_transforms.append(normalize)
  val_transforms = transforms.Compose(val_transforms)
  if (not args.evaluate) and (not args.attack):
    print("Training time data augmentations:")
    print(train_transforms)

  # setup dataset and dataloader
  train_dataset = MiniPlacesLoader(args.data_folder,
                  split='train', transforms=train_transforms)
  val_dataset = MiniPlacesLoader(args.data_folder,
                  split='val', transforms=val_transforms)

  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=100, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

  # testing only
  if (args.evaluate == args.attack) and args.evaluate:
    print("Can't set evaluate and attack to True at the same time!")
    return

  # set up visualizer
  if args.vis:
    visualizer = default_attention(criterion)
  else:
    visualizer = None

  # evaluation
  if args.resume and args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args, visualizer=visualizer)
    return

  # attack
  if args.resume and args.attack:
    print("Generating adversarial samples for the model ..")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args,
             attacker=default_attack(criterion),
             visualizer=visualizer)
    return

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True

  # warmup the training
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, "train", args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, epoch, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
      'epoch': epoch + 1,
      'model_arch': model_arch,
      'state_dict': model.state_dict(),
      'best_acc1': best_acc1,
      'optimizer' : optimizer.state_dict(),
    }, is_best)
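
save_checkpoint is imported from elsewhere in the assignment. A minimal sketch of the conventional PyTorch pattern it most likely follows (the filenames are assumptions):

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
  # write the latest state; keep a separate copy of the best model so far
  torch.save(state, filename)
  if is_best:
    shutil.copyfile(filename, 'model_best.pth.tar')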
Code Example #4
def main():
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (
            flow_cam - flow_fwd).pow(2).sum(dim=1).unsqueeze(1).sqrt()  # .normalize()
        rigidity_mask_census_soft = (
            1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max())
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())
            # rigidity_mask_dir rigidity_mask.numpy()
            # rigidity_census_mask_dir rigidity_mask_census.numpy()

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            ####### changed the two vstacks to hstacks ###############
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            ######## manually added code ####################
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
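
Each IoU above is derived from the accumulated true-positive, false-positive and false-negative counts as TP / (TP + FP + FN). Restated as a tiny helper:

def iou_from_counts(tp, fp, fn):
    # IoU = TP / (TP + FP + FN), as used for bg_iou and fg_iou above,
    # e.g. bg_iou = iou_from_counts(errors.sum[0], errors.sum[1], errors.sum[2])
    return tp / (tp + fp + fn)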
Code Example #5
File: test_flow.py Project: zenithfang/cc
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
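
AverageMeter(i=len(error_names)) accumulates a vector of running sums so that errors.sum can be indexed and errors.avg unpacked into the format string above. A minimal sketch of such a meter; the project's actual implementation may differ:

class AverageMeter(object):
    # running sums and averages for a fixed-length vector of metrics
    def __init__(self, i=1):
        self.sum = [0.0] * i
        self.avg = [0.0] * i
        self.count = 0

    def update(self, val, n=1):
        self.sum = [s + v * n for s, v in zip(self.sum, val)]
        self.count += n
        self.avg = [s / self.count for s in self.sum]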
Code Example #6
File: main2.py Project: cudnn/cc-1
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # for choosing the best model

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # train set and loader: only one is created
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
        train_set = SequenceFolder(  #mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  #5
            target_transform=None)
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    else:
        return

    if args.DEBUG:
        # truncate to 32 samples for quick debugging; slicing train_set.samples
        # is what actually shortens the dataset
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

# val sets and loaders are created one by one

# if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.val_without_gt:
        from datasets.sequence_folders2 import SequenceFolder  # just one extra folder level
        val_set_without_gt = SequenceFolder(  # images only
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet

        val_set_with_depth_gt = ValidationSet(args.data_dir,
                                              transform=valid_transform)

        val_loader_depth = torch.utils.data.DataLoader(
            val_set_with_depth_gt,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.val_with_flow_gt:  # not available yet
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,
            # batch size is 1 since images in kitti have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(
            len(val_set_without_gt), len(val_set_without_gt.scenes)))

#1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    #1.2 pose_net
    pose_net = getattr(models, args.posenet)(
        nb_ref_imgs=args.sequence_length - 1).cuda()

    #1.3.flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models,
                           args.flownet)(nlevels=args.nlevels,
                                         pre_normalization=normalize).cuda()
    elif args.flownet == 'FlowNetC6':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'FlowNetS':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # 1.4 mask_net
    mask_net = getattr(models,
                       args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=True).cuda()

    #2 load parameters
    #2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    # import ipdb; ipdb.set_trace()
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path /
                                     'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path /
                                     'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path /
                                     'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_cam_loss', 'photo_flow_loss',
            'explainability_loss', 'smooth_loss'
        ])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.epoch_bar.start()
    else:
        logger = None

# evaluate once before the training loop

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.val_without_gt:
            pass
            #val_loss = validate_without_gt(val_loader,disp_net,pose_net,mask_net,flow_net,epoch=0, logger=logger, tb_writer=tb_writer,nb_writers=3,global_vars_dict = global_vars_dict)
            #val_loss =0

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=0,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)


#3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop
        #3.1 decide which of the four sub-networks to train
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()
        #validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))
        #3.2 train for one epoch---------
        #train_loss=0
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net,
                              flow_net, optimizer, logger, tb_writer,
                              global_vars_dict)

        #3.3 evaluate on validation set-----

        if args.val_without_gt:
            val_loss = validate_without_gt(val_loader,
                                           disp_net,
                                           pose_net,
                                           mask_net,
                                           flow_net,
                                           epoch=0,
                                           logger=logger,
                                           tb_writer=tb_writer,
                                           nb_writers=3,
                                           global_vars_dict=global_vars_dict)

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=epoch,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

        if args.val_with_flow_gt:
            pass
            #flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer)

            #for error, name in zip(flow_errors, flow_error_names):
            #    training_writer.add_scalar(name, error, epoch)

        #----------------------

        #3.4 It is up to you to choose the most relevant error to measure your model's performance; be careful, some measures are to be maximized (such as a1, a2, a3)

        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]    # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]      #depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]    #epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  #flow_errors[3]     # percent outliers

        #3.5 log
        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()
        # epoch data logged to tensorboard
        #train loss
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        #val_without_gt loss
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)

        if args.val_with_depth_gt:
            #val with depth gt
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint

        if best_error < 0:
            best_error = train_loss

        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
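
The four fix_* blocks inside the epoch loop repeat the same freezing pattern. The same idea as a small helper (hypothetical, not part of the project):

def set_requires_grad(net, flag):
    # freeze (flag=False) or unfreeze (flag=True) all parameters of one sub-network
    for p in net.parameters():
        p.requires_grad = flag

# e.g. set_requires_grad(flow_net, not args.fix_flownet)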
Code Example #7
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    # args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=3,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=2).cuda()
    # mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    # masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    # mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    # mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        # explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        # print(len(explainability_mask))

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                             intrinsics_inv_var)

        # flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:,1], intrinsics_var, intrinsics_inv_var)
        #---------------------------------------------------------------

        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 0], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flow_fwd_list = []
        flow_fwd_list.append(flow_fwd)
        flow_bwd_list = []
        flow_bwd_list.append(flow_bwd)
        rigidity_mask_fwd = consensus_exp_masks(flows_cam_fwd,
                                                flows_cam_bwd,
                                                flow_fwd_list,
                                                flow_bwd_list,
                                                tgt_img_var,
                                                ref_imgs_var[1],
                                                ref_imgs_var[0],
                                                wssim=0.85,
                                                wrig=1.0,
                                                ws=0.1)[0]
        del flow_fwd_list
        del flow_bwd_list
        #--------------------------------------------------------------

        #rigidity_mask = 1 - (1-explainability_mask[:,1])*(1-explainability_mask[:,2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        #rigidity_mask_census_u = rigidity_mask_census_soft[:,0] < args.THRESH
        #rigidity_mask_census_v = rigidity_mask_census_soft[:,1] < args.THRESH
        #rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (rigidity_mask_census_v).type_as(flow_fwd)

        # rigidity_mask_census = ( torch.pow( (torch.pow(rigidity_mask_census_soft[:,0],2) + torch.pow(rigidity_mask_census_soft[:,1] , 2)), 0.5) < args.THRESH ).type_as(flow_fwd)
        THRESH_1 = 1
        THRESH_2 = 1
        rigidity_mask_census = (
            (torch.pow(rigidity_mask_census_soft[:, 0], 2) +
             torch.pow(rigidity_mask_census_soft[:, 1], 2)) < THRESH_1 *
            (flow_cam.pow(2).sum(dim=1) + flow_fwd.pow(2).sum(dim=1)) +
            THRESH_2).type_as(flow_fwd)

        # rigidity_mask_census = torch.zeros_like(rigidity_mask_census)
        rigidity_mask_fwd = torch.zeros_like(rigidity_mask_fwd)
        rigidity_mask_combined = 1 - (1 - rigidity_mask_fwd) * (
            1 - rigidity_mask_census)  #
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        # rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            torch.zeros_like(rigidity_mask_combined)) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None):
            tmp1 = flow_fwd.data[0].permute(1, 2, 0).cpu().numpy()
            tmp1 = flow_2_image(tmp1)
            scipy.misc.imsave(viz_dir / str(i).zfill(3) + 'flow.png', tmp1)

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
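
The epe_* metrics are endpoint errors. compute_all_epes is defined elsewhere in the project; the core quantity it averages is the per-pixel L2 distance between predicted and ground-truth flow, sketched here without the masking and validity handling the real function performs:

import torch

def mean_endpoint_error(flow_gt, flow_pred):
    # mean L2 distance between flow vectors; both flows are (B, 2, H, W)
    return torch.norm(flow_gt - flow_pred, p=2, dim=1).mean()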
Code Example #8
def main():
    global args
    args = parser.parse_args()
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(
            root='/home/anuragr/datasets/kitti/kitti2015',
            sequence_length=5,
            transform=valid_flow_transform)
    elif args.dataset == "kitti2012":
        val_flow_set = ValidationFlowKitti2012(
            root='/is/ps2/aranjan/AllFlowData/kitti/kitti2012',
            sequence_length=5,
            transform=valid_flow_transform)

    val_flow_loader = torch.utils.data.DataLoader(val_flow_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=2,
                                                  pin_memory=True,
                                                  drop_last=True)

    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    if args.pretrained_flow:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_flow))
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])  #, strict=False)

    flow_net = flow_net.cuda()
    flow_net.eval()
    error_names = ['epe_total', 'epe_non_rigid', 'epe_rigid', 'outliers']
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map) in enumerate(tqdm(val_flow_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        if args.dataset == "kitti2015":
            ref_imgs_var = [
                Variable(img.cuda(), volatile=True) for img in ref_imgs
            ]
            ref_img_var = ref_imgs_var[1:3]
        elif args.dataset == "kitti2012":
            ref_img_var = Variable(ref_imgs.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        # compute output
        flow_fwd, flow_bwd, occ = flow_net(tgt_img_var, ref_img_var)
        #epe = compute_epe(gt=flow_gt_var, pred=flow_fwd)
        obj_map_gt_var = Variable(obj_map.cuda(), volatile=True)
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)
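        # In the EPE call below, flow_fwd fills both the rigid and non-rigid
        # prediction slots, so only the inverted object mask changes which
        # regions are scored.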

        epe = compute_all_epes(flow_gt_var, flow_fwd, flow_fwd,
                               (1 - obj_map_gt_var_expanded))
        #print(i, epe)
        errors.update(epe)

    print("Average EPE", errors.avg)
Code example #9
File: submit_flow.py Project: maxuanquang/cc
def main():
    global args
    args = parser.parse_args()
    args.pretrained_path = Path(args.pretrained_path)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        testing_dir = args.output_dir / 'testing'
        testing_dir_flo = args.output_dir / 'testing_flo'

        image_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        testing_dir.makedirs_p()
        testing_dir_flo.makedirs_p()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                 sequence_length=5,
                                 transform=valid_flow_transform)

    if args.DEBUG:
        print("DEBUG MODE: Using Training Set")
        val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                     sequence_length=5,
                                     transform=valid_flow_transform,
                                     phase='training')

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_path /
                                 'dispnet_model_best.pth.tar')
    posenet_weights = torch.load(args.pretrained_path /
                                 'posenet_model_best.pth.tar')
    masknet_weights = torch.load(args.pretrained_path /
                                 'masknet_model_best.pth.tar')
    flownet_weights = torch.load(args.pretrained_path /
                                 'flownet_model_best.pth.tar')
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            tgt_img_original) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
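        # DispNet predicts disparity; its inverse serves as depth (up to the
        # scale ambiguity of monocular training).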
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet == 'Back2Future':
            flow_fwd, _, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
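        # flow_cam is the flow a fully rigid scene would induce, computed
        # from the predicted depth and the target-to-reference camera motion.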

        # Combine the explainability masks of the two middle reference
        # frames into a per-pixel rigidity prior; unsqueeze after combining
        # so the result has shape (B, 1, H, W) for any batch size.
        rigidity_mask = (1 - (1 - explainability_mask[:, 1]) *
                         (1 - explainability_mask[:, 2])).unsqueeze(1) > 0.5

        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        _, _, h_pred, w_pred = flow_cam.size()
        _, _, h_gt, w_gt = tgt_img_original.size()
        rigidity_pred_mask = nn.functional.upsample(rigidity_mask_combined,
                                                    size=(h_pred, w_pred),
                                                    mode='bilinear')

        non_rigid_pred = (rigidity_pred_mask <= args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_fwd
        rigid_pred = (rigidity_pred_mask > args.THRESH
                      ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_pred = non_rigid_pred + rigid_pred

        pred_fullres = nn.functional.upsample(total_pred,
                                              size=(h_gt, w_gt),
                                              mode='bilinear')
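        # Upsampling only resamples the field spatially; u/v must also be
        # scaled by the resize ratios so the vectors are in full-res pixels.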
        pred_fullres[:, 0, :, :] = pred_fullres[:, 0, :, :] * (w_gt / w_pred)
        pred_fullres[:, 1, :, :] = pred_fullres[:, 1, :, :] * (h_gt / h_pred)

        flow_fwd_fullres = nn.functional.upsample(flow_fwd,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_fwd_fullres[:, 0, :, :] = flow_fwd_fullres[:, 0, :, :] * (w_gt / w_pred)
        flow_fwd_fullres[:, 1, :, :] = flow_fwd_fullres[:, 1, :, :] * (h_gt / h_pred)

        flow_cam_fullres = nn.functional.upsample(flow_cam,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_cam_fullres[:, 0, :, :] = flow_cam_fullres[:, 0, :, :] * (w_gt / w_pred)
        flow_cam_fullres[:, 1, :, :] = flow_cam_fullres[:, 1, :, :] * (h_gt / h_pred)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            pred_u = pred_fullres[0][0].data.cpu().numpy()
            pred_v = pred_fullres[0][1].data.cpu().numpy()
            flow_io.flow_write_png(testing_dir / str(i).zfill(6) + '_10.png',
                                   u=pred_u,
                                   v=pred_v)
            flow_io.flow_write(testing_dir_flo / str(i).zfill(6) + '_10.flo',
                               pred_u, pred_v)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_combined.data[0].cpu(),
                                    max_value=1,
                                    colormap='magma')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam_fullres.data[0].cpu()),
                           tensor2array(flow_fwd_fullres.data[0].cpu()),
                           tensor2array(pred_fullres.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))

            row1_viz_im = Image.fromarray(
                (255 * row1_viz.transpose(1, 2, 0)).astype('uint8'))
            row2_viz_im = Image.fromarray(
                (255 * row2_viz.transpose(1, 2, 0)).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Done!")
Code example #10
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False,False,True])

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()


    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
 

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
                            custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is the same type as
    # the training set, so the photometric warping loss can still be measured
    
    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )

    if args.with_flow_gt:
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                        sequence_length=args.sequence_length, transform=valid_flow_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.with_flow_gt:
        val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1,               # batch size is 1 since images in kitti have different sizes
                        shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    
    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels, pre_normalization=normalize).cuda()
    else:
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # load pre-trained weights

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    # else:
    #     flow_net.init_weights()


    if args.resume:
        print("=> resuming from checkpoint")  
        flownet_weights = torch.load(args.save_path/'flownet_checkpoint.pth.tar')
        flow_net.load_state_dict(flownet_weights['state_dict'])


    cudnn.benchmark = True
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    parameters = chain(flow_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
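    # Adam's (beta1, beta2) are taken from args.momentum and args.beta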

    milestones = [300]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)

    if args.min:
        print("using min method")

    if args.resume and (args.save_path/'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path/'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss', 'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        scheduler.step()

        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        train_loss = train(train_loader, flow_net, optimizer, args.epoch_size, logger, training_writer)

        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()


        if args.with_flow_gt:
            flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, flow_net, epoch, logger, output_writers)

            error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(flow_error_names, flow_errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(flow_errors, flow_error_names):
                training_writer.add_scalar(name, error, epoch)

        
        # NOTE: the code below assumes args.with_flow_gt is set, since
        # flow_errors is only defined inside the block above
        decisive_error = flow_errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': flow_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': optimizer.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
Code example #11
def main():
    global args
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'results' / save_path  #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    output_vis_dir = args.save_path / 'images'
    output_vis_dir.makedirs_p()

    args.batch_size = 1

    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    # valid_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
    #                         custom_transforms.ArrayToTensor(), normalize])
    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    if args.valset == "kitti2015":
        # from datasets.validation_flow import ValidationFlowKitti2015MV
        # val_set = ValidationFlowKitti2015MV(root='/ps/project/datasets/AllFlowData/kitti/kitti2015', transform=valid_transform, compression=args.compression, raw_root='/is/rg/avg/jjanai/data/Kitti_2012_2015/Raw', example=args.example, true_motion=args.true_motion)
        from datasets.validation_flow import ValidationFlowKitti2015
        # # val_set = ValidationFlowKitti2015(root='/is/ps2/aranjan/AllFlowData/kitti/kitti2015', transform=valid_transform, compression=args.compression)
        val_set = ValidationFlowKitti2015(
            root='/ps/project/datasets/AllFlowData/kitti/kitti2015',
            transform=valid_transform,
            compression=args.compression,
            raw_root='/is/rg/avg/jjanai/data/Kitti_2012_2015/Raw',
            example=args.example,
            true_motion=args.true_motion)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        # val_set = ValidationFlowKitti2012(root='/is/ps2/aranjan/AllFlowData/kitti/kitti2012', transform=valid_transform, compression=args.compression)
        val_set = ValidationFlowKitti2012(
            root='/ps/project/datasets/AllFlowData/kitti/kitti2012',
            transform=valid_transform,
            compression=args.compression,
            raw_root='/is/rg/avg/jjanai/data/Kitti_2012_2015/Raw')

    print('{} samples found in valid scenes'.format(len(val_set)))

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in kitti have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    result_file = open(os.path.join(args.save_path, 'results.csv'), 'a')
    result_scene_file = open(os.path.join(args.save_path, 'result_scenes.csv'),
                             'a')

    # create model
    print("=> fetching model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # pwc_net.pth.tar')
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    flow_net = flow_net.cuda()

    cudnn.benchmark = True

    if args.whole_img == 0 and args.compression == 0:
        print("Loading patch from ", args.patch_path)
        patch = torch.load(args.patch_path)
        patch_shape = patch.shape
        if args.mask_path:
            mask_image = load_as_float(args.mask_path)
            mask_image = imresize(mask_image,
                                  (patch_shape[-1], patch_shape[-2])) / 256.
            mask = np.array([mask_image.transpose(2, 0, 1)])
        else:
            if args.patch_type == 'circle':
                mask = createCircularMask(patch_shape[-2],
                                          patch_shape[-1]).astype('float32')
                mask = np.array([[mask, mask, mask]])
            elif args.patch_type == 'square':
                mask = np.ones(patch_shape)
    else:
        # add gaussian noise
        mean = 0
        var = 1
        sigma = var**0.5
        patch = np.random.normal(mean, sigma,
                                 (flow_loader_h, flow_loader_w, 3))
        patch = patch.reshape(3, flow_loader_h, flow_loader_w)
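        # reshape (rather than transpose) scrambles the HWC layout, which is
        # harmless here only because the values are i.i.d. Gaussian noise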
        mask = np.ones(patch.shape) * args.whole_img

    error_names = ['epe', 'adv_epe', 'cos_sim', 'adv_cos_sim']
    errors = AverageMeter(i=len(error_names))

    # header
    result_file.write("{:>10}, {:>10}, {:>10}, {:>10}\n".format(*error_names))
    result_scene_file.write("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}\n".format(
        *(['scene'] + error_names)))

    flow_net.eval()

    # set seed for reproducibility
    np.random.seed(1337)

    for i, (ref_img_past, tgt_img, ref_img, flow_gt, disp_gt, calib,
            poses) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_past_img_var = Variable(ref_img_past.cuda(), volatile=True)
        ref_img_var = Variable(ref_img.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)

        if type(flow_net).__name__ == 'Back2Future':
            flow_fwd = flow_net(ref_past_img_var, tgt_img_var, ref_img_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_img_var)

        data_shape = tgt_img.cpu().numpy().shape

        margin = 0
        if len(calib) > 0:
            margin = int(disp_gt.max())

        random_x = args.fixed_loc_x
        random_y = args.fixed_loc_y
        if args.whole_img == 0:
            if args.patch_type == 'circle':
                patch_full, mask_full, _, random_x, random_y, _ = circle_transform(
                    patch,
                    mask,
                    patch.copy(),
                    data_shape,
                    patch_shape,
                    margin,
                    norotate=args.norotate,
                    fixed_loc=(random_x, random_y))
            elif args.patch_type == 'square':
                patch_full, mask_full, _, _, _ = square_transform(
                    patch,
                    mask,
                    patch.copy(),
                    data_shape,
                    patch_shape,
                    norotate=args.norotate)
            patch_full, mask_full = torch.FloatTensor(
                patch_full), torch.FloatTensor(mask_full)
        else:
            patch_full, mask_full = torch.FloatTensor(
                patch), torch.FloatTensor(mask)

        patch_full, mask_full = patch_full.cuda(), mask_full.cuda()
        patch_var, mask_var = Variable(patch_full), Variable(mask_full)

        patch_var_future = patch_var_past = patch_var
        mask_var_future = mask_var_past = mask_var

        # adverserial flow
        bt, _, h_gt, w_gt = flow_gt_var.shape
        forward_patch_flow = Variable(torch.cat((torch.zeros(
            (bt, 2, h_gt, w_gt)), torch.ones((bt, 1, h_gt, w_gt))), 1).cuda(),
                                      volatile=True)

        # project patch into 3D scene
        if len(calib) > 0:
            # #################################### ONLY WORKS WITH BATCH SIZE 1 ####################################
            imu2vel = calib['imu2vel']["RT"][0].numpy()
            imu2cam = calib['P_imu_cam'][0].numpy()
            imu2img = calib['P_imu_img'][0].numpy()

            pose_past = poses[0][0].numpy()
            pose_ref = poses[1][0].numpy()
            inv_pose_ref = inv(pose_ref)
            pose_fut = poses[2][0].numpy()

            # get point in IMU
            patch_disp = disp_gt[0, random_y:random_y + patch_shape[-2],
                                 random_x:random_x + patch_shape[-1]]
            valid = (patch_disp > 0)
            # set to object or free space disparity
            if False and args.fixed_loc_x > 0 and args.fixed_loc_y > 0:
                # disparity = patch_disp[valid].mean() - 3  # small correction for gps errors
                disparity = patch_disp[valid].mean()
            else:
                subset = patch_disp[valid]
                min_disp = 0
                if len(subset) > 0:
                    min_disp = subset.min()
                max_disp = disp_gt.max()

                disparity = np.random.uniform(min_disp, max_disp)  # disparity

            # print('Disp from ', min_disp, ' to ', max_disp)
            depth = (calib['cam']['focal_length_x'] *
                     calib['cam']['baseline'] / disparity)
            p_cam0 = np.array([[0], [0], [0], [1]])
            p_cam0[0] = depth * (
                random_x - calib['cam']['cx']) / calib['cam']['focal_length_x']
            p_cam0[1] = depth * (
                random_y - calib['cam']['cy']) / calib['cam']['focal_length_y']
            p_cam0[2] = depth

            # transform
            T_p_cam0 = np.eye(4)
            T_p_cam0[0:4, 3:4] = p_cam0

            # transformation to generate patch points
            patch_size = -0.25
            pts = np.array([[0, 0, 0, 1], [0, patch_size, 0, 1],
                            [patch_size, 0, 0, 1],
                            [patch_size, patch_size, 0, 1]]).T
            pts = inv(imu2cam).dot(T_p_cam0.dot(pts))
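            # Corners of a square patch (side |patch_size|, presumably in
            # metres) are placed at the sampled depth in the camera frame,
            # then mapped to IMU coordinates for reprojection into each view.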

            # get points in reference image
            pts_src = pose_ref.dot(pts)
            pts_src = imu2img.dot(pts_src)
            pts_src = pts_src[:3, :] / pts_src[2:3, :].repeat(3, 0)

            # get points in past image
            pts_past = pose_past.dot(pts)
            pts_past = imu2img.dot(pts_past)
            pts_past = pts_past[:3, :] / pts_past[2:3, :].repeat(3, 0)

            # get points in future image
            pts_fut = pose_fut.dot(pts)
            pts_fut = imu2img.dot(pts_fut)
            pts_fut = pts_fut[:3, :] / pts_fut[2:3, :].repeat(3, 0)

            # find homography between points
            H_past, _ = cv2.findHomography(pts_src.T, pts_past.T, cv2.RANSAC)
            H_fut, _ = cv2.findHomography(pts_src.T, pts_fut.T, cv2.RANSAC)
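            # With exactly four coplanar correspondences the homography is
            # determined exactly, so RANSAC reduces to a direct fit here.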

            refMtrx = torch.from_numpy(H_fut).float().cuda()
            refMtrx = refMtrx.repeat(args.batch_size, 1, 1)
            # get pixel origins
            X, Y = np.meshgrid(np.arange(flow_loader_w),
                               np.arange(flow_loader_h))
            X, Y = X.flatten(), Y.flatten()
            XYhom = np.stack([X, Y, np.ones_like(X)], axis=1).T
            XYhom = np.tile(XYhom, [args.batch_size, 1, 1]).astype(np.float32)
            XYhom = torch.from_numpy(XYhom).cuda()
            XHom, YHom, ZHom = torch.unbind(XYhom, dim=1)
            XHom = XHom.resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            YHom = YHom.resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            # warp the canonical coordinates
            XYwarpHom = refMtrx.matmul(XYhom)
            XwarpHom, YwarpHom, ZwarpHom = torch.unbind(XYwarpHom, dim=1)
            Xwarp = (XwarpHom / (ZwarpHom + 1e-8)).resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            Ywarp = (YwarpHom / (ZwarpHom + 1e-8)).resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            # get forward flow
            u = (XHom - Xwarp).unsqueeze(1)
            v = (YHom - Ywarp).unsqueeze(1)
            flow = torch.cat((u, v), 1)
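            # u, v are canonical-minus-warped pixel coordinates: the flow the
            # patch plane induces between the reference and future frames.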
            flow = nn.functional.upsample(flow,
                                          size=(h_gt, w_gt),
                                          mode='bilinear')
            flow[:, 0, :, :] = flow[:, 0, :, :] * (w_gt / flow_loader_w)
            flow[:, 1, :, :] = flow[:, 1, :, :] * (h_gt / flow_loader_h)
            forward_patch_flow[:, :2, :, :] = flow
            # get grid for resampling
            Xwarp = 2 * ((Xwarp / (flow_loader_w - 1)) - 0.5)
            Ywarp = 2 * ((Ywarp / (flow_loader_h - 1)) - 0.5)
            grid = torch.stack([Xwarp, Ywarp], dim=-1)
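            # Xwarp/Ywarp were rescaled to [-1, 1] above because grid_sample
            # expects normalized sampling locations.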
            # sampling with bilinear interpolation
            patch_var_future = torch.nn.functional.grid_sample(patch_var,
                                                               grid,
                                                               mode="bilinear")
            mask_var_future = torch.nn.functional.grid_sample(mask_var,
                                                              grid,
                                                              mode="bilinear")

            # use past homography
            refMtrxP = torch.from_numpy(H_past).float().cuda()
            refMtrxP = refMtrxP.repeat(args.batch_size, 1, 1)
            # warp the canonical coordinates
            XYwarpHomP = refMtrxP.matmul(XYhom)
            XwarpHomP, YwarpHomP, ZwarpHomP = torch.unbind(XYwarpHomP, dim=1)
            XwarpP = (XwarpHomP / (ZwarpHomP + 1e-8)).resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            YwarpP = (YwarpHomP / (ZwarpHomP + 1e-8)).resize_(
                (args.batch_size, flow_loader_h, flow_loader_w))
            # get grid for resampling
            XwarpP = 2 * ((XwarpP / (flow_loader_w - 1)) - 0.5)
            YwarpP = 2 * ((YwarpP / (flow_loader_h - 1)) - 0.5)
            gridP = torch.stack([XwarpP, YwarpP], dim=-1)
            # sampling with bilinear interpolation
            patch_var_past = torch.nn.functional.grid_sample(patch_var,
                                                             gridP,
                                                             mode="bilinear")
            mask_var_past = torch.nn.functional.grid_sample(mask_var,
                                                            gridP,
                                                            mode="bilinear")

        adv_tgt_img_var = torch.mul(
            (1 - mask_var), tgt_img_var) + torch.mul(mask_var, patch_var)
        adv_ref_past_img_var = torch.mul(
            (1 - mask_var_past), ref_past_img_var) + torch.mul(
                mask_var_past, patch_var_past)
        adv_ref_img_var = torch.mul(
            (1 - mask_var_future), ref_img_var) + torch.mul(
                mask_var_future, patch_var_future)

        adv_tgt_img_var = torch.clamp(adv_tgt_img_var, -1, 1)
        adv_ref_past_img_var = torch.clamp(adv_ref_past_img_var, -1, 1)
        adv_ref_img_var = torch.clamp(adv_ref_img_var, -1, 1)

        if type(flow_net).__name__ == 'Back2Future':
            adv_flow_fwd = flow_net(adv_ref_past_img_var, adv_tgt_img_var,
                                    adv_ref_img_var)
        else:
            adv_flow_fwd = flow_net(adv_tgt_img_var, adv_ref_img_var)

        # set patch to zero flow!
        mask_var_res = nn.functional.upsample(mask_var,
                                              size=(h_gt, w_gt),
                                              mode='bilinear')

        # Ignore patch motion if set!
        if args.ignore_mask_flow:
            forward_patch_flow = Variable(torch.cat((torch.zeros(
                (bt, 2, h_gt, w_gt)), torch.zeros((bt, 1, h_gt, w_gt))),
                                                    1).cuda(),
                                          volatile=True)

        flow_gt_var_adv = torch.mul(
            (1 - mask_var_res), flow_gt_var) + torch.mul(
                mask_var_res, forward_patch_flow)

        epe = compute_epe(gt=flow_gt_var, pred=flow_fwd)
        adv_epe = compute_epe(gt=flow_gt_var_adv, pred=adv_flow_fwd)
        cos_sim = compute_cossim(flow_gt_var, flow_fwd)
        adv_cos_sim = compute_cossim(flow_gt_var_adv, adv_flow_fwd)
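        # Attack strength shows up as the gap between the clean and
        # adversarial numbers: EPE should rise and cosine similarity to the
        # ground truth should drop once the patch is pasted in.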

        errors.update([epe, adv_epe, cos_sim, adv_cos_sim])

        if i % 1 == 0:  # log every sample; raise the modulus to log less often
            index = i
            imgs = normalize([tgt_img] + [ref_img_past] + [ref_img])
            norm_tgt_img = imgs[0]
            norm_ref_img_past = imgs[1]
            norm_ref_img = imgs[2]

            patch_cpu = patch_var.data[0].cpu()
            mask_cpu = mask_var.data[0].cpu()

            adv_norm_tgt_img = normalize(
                adv_tgt_img_var.data.cpu()
            )  #torch.mul((1-mask_cpu), norm_tgt_img) + torch.mul(mask_cpu, patch_cpu)
            adv_norm_ref_img_past = normalize(
                adv_ref_past_img_var.data.cpu()
            )  # torch.mul((1-mask_cpu), norm_ref_img_past) + torch.mul(mask_cpu, patch_cpu)
            adv_norm_ref_img = normalize(
                adv_ref_img_var.data.cpu()
            )  #torch.mul((1-mask_cpu), norm_ref_img) + torch.mul(mask_cpu, patch_cpu)

            output_writer.add_image(
                'val flow Input',
                transpose_image(tensor2array(norm_tgt_img[0])), 0)
            flow_to_show = flow_gt[0][:2, :, :].cpu()
            output_writer.add_image(
                'val target Flow',
                transpose_image(flow_to_image(tensor2array(flow_to_show))), 0)

            # set flow to zero
            # zero_flow = Variable(torch.zeros(flow_fwd.shape).cuda(), volatile=True)
            # flow_fwd_masked = torch.mul((1-mask_var[:,:2,:,:]), flow_fwd) + torch.mul(mask_var[:,:2,:,:], zero_flow)
            flow_fwd_masked = flow_fwd

            # get ground truth flow
            val_GT_adv = flow_gt_var_adv.data[0].cpu().numpy().transpose(
                1, 2, 0)
            # val_GT_adv = interp_gt_flow(val_GT_adv[:,:,:2], val_GT_adv[:,:,2])
            val_GT_adv = cv2.resize(val_GT_adv, (flow_loader_w, flow_loader_h),
                                    interpolation=cv2.INTER_NEAREST)
            val_GT_adv[:, :, 0] = val_GT_adv[:, :, 0] * (flow_loader_w / w_gt)
            val_GT_adv[:, :, 1] = val_GT_adv[:, :, 1] * (flow_loader_h / h_gt)

            # gt normalization for visualization
            u = val_GT_adv[:, :, 0]
            v = val_GT_adv[:, :, 1]
            idxUnknow = (abs(u) > 1e7) | (abs(v) > 1e7)
            u[idxUnknow] = 0
            v[idxUnknow] = 0
            rad = np.sqrt(u**2 + v**2)
            maxrad = np.max(rad)

            val_GT_adv_Output = flow_to_image(val_GT_adv, maxrad)
            val_GT_adv_Output = cv2.erode(val_GT_adv_Output,
                                          np.ones((3, 3), np.uint8),
                                          iterations=1)  # make points thicker
            val_GT_adv_Output = transpose_image(val_GT_adv_Output) / 255.
            val_Flow_Output = transpose_image(
                flow_to_image(tensor2array(flow_fwd.data[0].cpu()),
                              maxrad)) / 255.
            val_adv_Flow_Output = transpose_image(
                flow_to_image(tensor2array(adv_flow_fwd.data[0].cpu()),
                              maxrad)) / 255.
            val_Diff_Flow_Output = transpose_image(
                flow_to_image(
                    tensor2array(
                        (adv_flow_fwd - flow_fwd_masked).data[0].cpu()),
                    maxrad)) / 255.

            val_tgt_image = transpose_image(tensor2array(norm_tgt_img[0]))
            val_ref_image = transpose_image(tensor2array(norm_ref_img[0]))
            val_adv_tgt_image = transpose_image(
                tensor2array(adv_norm_tgt_img[0]))
            val_adv_ref_image_past = transpose_image(
                tensor2array(adv_norm_ref_img_past[0]))
            val_adv_ref_image = transpose_image(
                tensor2array(adv_norm_ref_img[0]))
            val_patch = transpose_image(tensor2array(patch_var.data.cpu()[0]))
            # print(adv_norm_tgt_img.shape)
            # print(flow_fwd.data[0].cpu().shape)

            # if type(flow_net).__name__ == 'Back2Future':
            #     val_output_viz = np.concatenate((val_adv_ref_image_past, val_adv_tgt_image, val_adv_ref_image, val_Flow_Output, val_adv_Flow_Output, val_Diff_Flow_Output), 2)
            # else:
            # val_output_viz = np.concatenate((val_adv_tgt_image, val_adv_ref_image, val_Flow_Output, val_adv_Flow_Output, val_Diff_Flow_Output, val_GT_adv_Output), 2)
            val_output_viz = np.concatenate(
                (val_ref_image, val_adv_ref_image, val_Flow_Output,
                 val_adv_Flow_Output, val_Diff_Flow_Output, val_GT_adv_Output),
                2)
            val_output_viz_im = Image.fromarray(
                (255 * val_output_viz.transpose(1, 2, 0)).astype('uint8'))
            val_output_viz_im.save(args.save_path / args.name + 'viz' +
                                   str(i).zfill(3) + '.jpg')
            output_writer.add_image('val Output viz {}'.format(index),
                                    val_output_viz, 0)

            #val_output_viz = np.vstack((val_Flow_Output, val_adv_Flow_Output, val_Diff_Flow_Output, val_adv_tgt_image, val_adv_ref_image))
            #scipy.misc.imsave('outfile.jpg', os.path.join(output_vis_dir, 'vis_{}.png'.format(index)))

            result_scene_file.write(
                "{:10d}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}\n".format(
                    i, epe, adv_epe, cos_sim, adv_cos_sim))

    print("{:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
    print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*errors.avg))
    result_file.write(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}\n".format(*errors.avg))
    result_scene_file.write(
        "{:>10}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}\n".format(
            *(["avg"] + errors.avg)))

    result_file.close()
    result_scene_file.close()
Code example #12
File: test_flownetc.py Project: maxuanquang/cc
def main():
    global args
    args = parser.parse_args()
    save_path = 'checkpoints/test_flownetc'

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    summary_writer = SummaryWriter(save_path)
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[1.0, 1.0, 1.0])
    flow_loader_h, flow_loader_w = 384, 1280
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlowFlowNetC(
            root='/is/ps2/aranjan/AllFlowData/kitti/kitti2015',
            sequence_length=5,
            transform=valid_flow_transform)
    elif args.dataset == "kitti2012":
        val_flow_set = ValidationFlowKitti2012(
            root='/is/ps2/aranjan/AllFlowData/kitti/kitti2012',
            sequence_length=5,
            transform=valid_flow_transform)

    val_flow_loader = torch.utils.data.DataLoader(val_flow_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=2,
                                                  pin_memory=True,
                                                  drop_last=True)

    flow_net = getattr(models, args.flownet)(pretrained=True).cuda()

    flow_net.eval()
    error_names = ['epe']
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            flownet_c_flow, obj_map) in enumerate(val_flow_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        if args.dataset == "kitti2015":
            ref_imgs_var = [
                Variable(img.cuda(), volatile=True) for img in ref_imgs
            ]
            ref_img_var = ref_imgs_var[2]
        elif args.dataset == "kitti2012":
            ref_img_var = Variable(ref_imgs.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        flownet_c_flow = Variable(flownet_c_flow.cuda(), volatile=True)

        # compute output
        flow_fwd = flow_net(tgt_img_var, ref_img_var)
        epe = compute_epe(gt=flownet_c_flow, pred=flow_fwd)
        scale_factor = compute_epe(gt=flownet_c_flow, pred=flow_fwd, op='div')
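        # op='div' presumably makes compute_epe return a per-pixel magnitude
        # ratio between the two flows instead of the end-point error.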
        summary_writer.add_image('Frame 1',
                                 tensor2array(tgt_img_var.data[0].cpu()), i)
        summary_writer.add_image('Frame 2',
                                 tensor2array(ref_img_var.data[0].cpu()), i)
        summary_writer.add_image(
            'Flow Output', flow_to_image(tensor2array(flow_fwd.data[0].cpu())),
            i)
        summary_writer.add_image(
            'UnFlow Output',
            flow_to_image(tensor2array(flownet_c_flow.data[0][:2].cpu())), i)
        summary_writer.add_image(
            'gtFlow Output',
            flow_to_image(tensor2array(flow_gt_var.data[0][:2].cpu())), i)
        summary_writer.add_image('EPE Image w UnFlow',
                                 tensor2array(epe.data.cpu()), i)
        summary_writer.add_scalar('EPE mean w UnFlow',
                                  epe.mean().data.cpu(), i)
        summary_writer.add_scalar('EPE max w UnFlow', epe.max().data.cpu(), i)
        summary_writer.add_scalar('Scale Factor max w UnFlow',
                                  scale_factor.max().data.cpu(), i)
        summary_writer.add_scalar('Scale Factor mean w UnFlow',
                                  scale_factor.mean().data.cpu(), i)
        summary_writer.add_scalar('Flow value max',
                                  flow_fwd.max().data.cpu(), i)
        print(i, "EPE: ", epe.mean().item())

        #print(i, epe)
        #errors.update(epe)

    print('Done')