def __init__(self,
             ds_path,
             seed=0,
             sequence_length=3,
             num_scales=4):

    np.random.seed(seed)
    random.seed(seed)

    self.num_scales = num_scales

    seq_list = get_lists(ds_path)
    self.samples = get_samples(seq_list, sequence_length)

    # obtained with calib.py
    self.intrinsics = np.array([
        [1.14183754e+03, 0.00000000e+00, 6.28283670e+02],
        [0.00000000e+00, 1.13869492e+03, 3.56277189e+02],
        [0.00000000e+00, 0.00000000e+00, 1.00000000e+00],
    ]).astype(np.float32)

    # resize 1280 x 720 -> 416 x 128: scale fx/cx by the width ratio,
    # fy/cy by the height ratio
    self.intrinsics[0] = self.intrinsics[0] * (416.0 / 1280.0)
    self.intrinsics[1] = self.intrinsics[1] * (128.0 / 720.0)

    self.ms_k = get_multi_scale_intrinsics(self.intrinsics, self.num_scales)
    self.ms_inv_k = get_multi_scale_inv_intrinsics(self.intrinsics, self.num_scales)

    self.to_tensor = custom_transforms.Compose([custom_transforms.ArrayToTensor()])
    self.to_tensor_norm = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    ])
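The per-scale intrinsics halve the focal lengths and principal point at each pyramid level. A minimal sketch of what get_multi_scale_intrinsics might compute (a hypothetical reimplementation, not this project's own helper):

import numpy as np

def get_multi_scale_intrinsics(intrinsics, num_scales):
    # Scale fx, cx (row 0) and fy, cy (row 1) by 1/2**s per level;
    # the homogeneous row [0, 0, 1] is unaffected.
    ms_k = []
    for s in range(num_scales):
        k = intrinsics.copy()
        k[0] /= 2 ** s
        k[1] /= 2 ** s
        ms_k.append(k)
    return np.stack(ms_k).astype(np.float32)

The inverse variant would return np.linalg.inv of each level's matrix.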
Example #2
def dataflow_test():
    from DataFlow.sequence_folders import SequenceFolder
    import custom_transforms
    from torch.utils.data import DataLoader
    from DataFlow.validation_folders import ValidationSet
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([custom_transforms.RandomHorizontalFlip(),
                                                 custom_transforms.RandomScaleCrop(),
                                                 custom_transforms.ArrayToTensor(), normalize])
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])
    datapath = 'G:/data/KITTI/KittiRaw_formatted'
    seed = 8964
    train_set = SequenceFolder(datapath, transform=train_transform, seed=seed, train=True,
                               sequence_length=3)

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4,
                              pin_memory=True)

    val_set = ValidationSet(datapath, transform=valid_transform)
    print("length of train loader is %d" % len(train_loader))
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
    print("length of val loader is %d" % len(val_loader))

    dataiter = iter(train_loader)
    imgs, intrinsics = next(dataiter)
    print(len(imgs))
    print(intrinsics.shape)

Example #3
def main():
    global args, best_error, n_iter
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=4,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True,
                                               drop_last=True)
    print(len(train_loader))
Example #4
def main():
    args = parser.parse_args()
    config = load_config(args.config)

    # directory for storing output
    output_dir = os.path.join(config.save_path, "results", args.seq)

    # load configuration from checkpoints
    if args.use_latest_not_best:
        config.posenet = os.path.join(config.save_path,
                                      "posenet_checkpoint.pth.tar")
        output_dir = output_dir + "-latest"
    else:
        config.posenet = os.path.join(config.save_path, "posenet_best.pth.tar")

    os.makedirs(output_dir)

    # define transformations
    transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # load dataset loader
    dataset = SequenceFolder(
        config.data,
        cameras=config.cameras,
        gray=config.gray,
        sequence_length=config.sequence_length,
        shuffle=False,
        train=False,
        transform=transform,
        sequence=args.seq,
    )

    input_channels = dataset[0][1].shape[0]

    # define posenet
    pose_net = PoseNet(in_channels=input_channels,
                       nb_ref_imgs=2,
                       output_exp=False).to(device)
    weights = torch.load(config.posenet)
    pose_net.load_state_dict(weights['state_dict'])
    pose_net.eval()

    # prediction
    poses = []
    for i, (tgt, tgt_lf, ref, ref_lf, k, kinv, pose_gt) in enumerate(dataset):
        print("{:03d}/{:03d}".format(i + 1, len(dataset)), end="\r")
        tgt = tgt.unsqueeze(0).to(device)
        ref = [r.unsqueeze(0).to(device) for r in ref]
        tgt_lf = tgt_lf.unsqueeze(0).to(device)
        ref_lf = [r.unsqueeze(0).to(device) for r in ref_lf]
        exp, pose = pose_net(tgt_lf, ref_lf)
        poses.append(pose[0, 1, :].cpu().numpy())

    # save poses
    outdir = os.path.join(output_dir, "poses.npy")
    np.save(outdir, poses)
    print("\nok")
Example #5
def main():
    print('=> PyTorch version: ' + torch.__version__ + ' || CUDA_VISIBLE_DEVICES: ' + os.environ["CUDA_VISIBLE_DEVICES"])

    global device
    args = parser.parse_args()
    if args.save_fig:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        args.save_path = 'outputs'/Path(args.name)/timestamp
        print('=> will save everything to {}'.format(args.save_path))
        args.save_path.makedirs_p()

    print("=> fetching scenes in '{}'".format(args.data))
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    demo_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    demo_set = SequenceFolder(
        args.data,
        transform=demo_transform,
        max_num_instances=args.mni
    )
    print('=> {} samples found'.format(len(demo_set)))
    demo_loader = torch.utils.data.DataLoader(demo_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet().to(device)
    ego_pose_net = models.EgoPoseNet().to(device)
    obj_pose_net = models.ObjPoseNet().to(device)

    if args.pretrained_ego_pose:
        print("=> using pre-trained weights for EgoPoseNet")
        weights = torch.load(args.pretrained_ego_pose)
        ego_pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        ego_pose_net.init_weights()

    if args.pretrained_obj_pose:
        print("=> using pre-trained weights for ObjPoseNet")
        weights = torch.load(args.pretrained_obj_pose)
        obj_pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        obj_pose_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    ego_pose_net = torch.nn.DataParallel(ego_pose_net)
    obj_pose_net = torch.nn.DataParallel(obj_pose_net)

    demo_visualize(args, demo_loader, disp_net, ego_pose_net, obj_pose_net)
Example #6
def save_mask_all():
    args = parser.parse_args()

    weights = torch.load(args.pretrained)
    seq_length = int(weights['state_dict']['conv1.0.weight'].size(1) / 3)
    mask_net = getattr(models, 'MaskNet6')(nb_ref_imgs=5 - 1,
                                           output_exp=True).cuda()
    mask_net.load_state_dict(weights['state_dict'], strict=False)
    mask_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    trace_dir = output_dir / args.trace_dir  # trajectory
    trace_dir.makedirs_p()
    # data prepare

    normalize = custom_transforms.NormalizeLocally()
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    val_set = SequenceFolder(  # images only
        args.dataset_dir,
        transform=valid_transform,
        seed=None,
        train=False,
        sequence_length=5,
        target_transform=None)
    if len(val_set) == 0:
        print('read error')
        return
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)
    mask_all = None
    print(len(val_loader))
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        explainability_mask = mask_net(tgt_img, ref_imgs)

        #explainability_mask = torch.cat([explainability_mask[:, :len(ref_imgs) // 2, :],
        #                   torch.zeros(1, 1, 6).float().to(device),
        #                   explainability_mask[:, len(ref_imgs) // 2:, :]], dim=1)  # add 0
        if i == 0:
            mask_all = explainability_mask
        else:
            mask_all = torch.cat([mask_all, explainability_mask])

    np.save('mask_all.npy', mask_all.detach().cpu().numpy())
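One caveat with the loop above: concatenating onto mask_all every iteration re-allocates the growing tensor, and the absence of a no-grad guard keeps autograd state alive. A sketch of the same accumulation under those two fixes (reusing val_loader, mask_net and device from above):

masks = []
for tgt_img, ref_imgs, intrinsics, intrinsics_inv in val_loader:
    tgt_img = tgt_img.to(device)
    ref_imgs = [img.to(device) for img in ref_imgs]
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        masks.append(mask_net(tgt_img, ref_imgs).cpu())
mask_all = torch.cat(masks)  # concatenate once at the end
np.save('mask_all.npy', mask_all.numpy())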
Example #7
def get_train_transforms(normalize):
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  return train_transforms
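A short usage sketch; the mean/std values here are the usual ImageNet statistics, assumed rather than taken from this project:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_transforms = get_train_transforms(normalize)
# tensor_img = train_transforms(pil_img)  # PIL image in, augmented tensor out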
Example #8
def main():

    global global_vars_dict
    args = global_vars_dict['args']


    normalize = custom_transforms.NormalizeLocally()
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = Path('test_out')/ Path(args.sq_name)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    tb_writer = SummaryWriter(args.save_path)

    val_set = SequenceFolder(  # images only
        args.data,
        transform=valid_transform,
        seed=None,
        train=False,
        sequence_length=args.sequence_length,
        target_transform=None
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    weights = torch.load(args.pretrained_disp)
    disp_net.load_state_dict(weights['state_dict'])


    # 1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length - 1).cuda()
    weights = torch.load(args.pretrained_pose)
    pose_net.load_state_dict(weights['state_dict'])

    # 1.3.flow_net
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    weights = torch.load(args.pretrained_flow)
    flow_net.load_state_dict(weights['state_dict'])

    # 1.4 mask_net
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=args.sequence_length - 1, output_exp=True).cuda()
    weights = torch.load(args.pretrained_mask)
    mask_net.load_state_dict(weights['state_dict'])

    disp_list, disp_arr, flow_list, mask_list = test(
        val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer,
        global_vars_dict=global_vars_dict)

    print('over')
Example #9
def __init__(self,
             ds_path,
             AoI,
             seed=0,
             sequence_length=3,
             num_scales=4):
    np.random.seed(seed)
    random.seed(seed)

    self.num_scales = num_scales

    seq_list = get_lists(ds_path)
    self.samples = get_samples(seq_list, sequence_length, AoI)

    # obtained with calib.py
    self.intrinsics = np.array([
        [1.14183754e+03, 0.00000000e+00, 6.28283670e+02],
        [0.00000000e+00, 1.13869492e+03, 3.56277189e+02],
        [0.00000000e+00, 0.00000000e+00, 1.00000000e+00],
    ]).astype(np.float32)

    # The original pictures from my phone camera are 1280 x 720.
    # If your original picture size differs, change the two ratios below.
    # resize 1280 x 720 -> 416 x 128
    self.intrinsics[0] = self.intrinsics[0] * (416.0 / 1280.0)
    self.intrinsics[1] = self.intrinsics[1] * (128.0 / 720.0)

    self.ms_k = get_multi_scale_intrinsics(self.intrinsics, self.num_scales)
    self.ms_inv_k = get_multi_scale_inv_intrinsics(self.intrinsics, self.num_scales)

    self.to_tensor = custom_transforms.Compose([custom_transforms.ArrayToTensor()])
    self.to_tensor_norm = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    ])
Example #10
def make_augmentation_transforms(augmentation, mode):
    if mode == 'train':
        transforms = [
            t.RandomPadToLength(length=config.AUDIO_LENGTH),
            t.Noise(
                length=config.AUDIO_LENGTH,
                noise_waves=load_noise_waves(),
                noise_limit=0.2,
            ).with_prob(0.5),
            t.RandomShift(shift_limit=0.2).with_prob(0.5),
        ]
    else:
        transforms = [t.PadToLength(length=config.AUDIO_LENGTH)]
    transforms.append(augmentations[augmentation])
    transforms += [
        t.Pad(((0, 0), (0, 1)), 'constant'),
        t.ExpandDims(),
        t.ToTensor(),
    ]
    return t.Compose(transforms)
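The .with_prob(0.5) calls wrap a transform so it only fires with the given probability. A minimal sketch of such a wrapper (hypothetical; the project's t module presumably ships its own):

import random

class WithProb:
    """Apply `transform` with probability `p`; otherwise return the input unchanged."""

    def __init__(self, transform, p):
        self.transform = transform
        self.p = p

    def __call__(self, x):
        return self.transform(x) if random.random() < self.p else x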
Example #11
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # for best-model selection

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transpose

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set / loader: build only one
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # val sets / loaders: build one per case
    # if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    #1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None


    # evaluate once up front
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easier to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    #    logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()
    #3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop

        if logger is not None:
            logger.reset_train_bar()
            logger.reset_valid_bar()

        errors = [0]
        error_names = ['no error names depth']

        #3.2 train for one epoch---------
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        #3.3 evaluate on validation set-----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        #3.5 log_terminal
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

        # tensorboard scalars
        # train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        #val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
Example #12
def main():
    global args
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path 
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()


    
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
          # custom_transforms.RandomRotate(),
          # custom_transforms.RandomHorizontalFlip(),
          # custom_transforms.RandomScaleCrop(),
          custom_transforms.ArrayToTensor(),
          normalize  ])

    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array([542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0, 1]).astype(np.float32).reshape((3, 3))
    
    inference_set = SequenceFolder(
        root = args.dataset_dir,
        intrinsics = intrinsics,
        transform=train_transform,
        train=False,
        sequence_length=args.sequence_length
    )

    print('{} samples found in {} train scenes'.format(len(inference_set), len(inference_set.scenes)))
    inference_loader = torch.utils.data.DataLoader(
        inference_set, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    mask_net = MaskResNet6.MaskResNet6().cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()
    mask_net = torch.nn.DataParallel(mask_net)

    masknet_weights = torch.load(args.pretrained_mask)
    posenet_weights = torch.load(args.pretrained_pose)
    mask_net.load_state_dict(masknet_weights['state_dict'])
    # pose_net.load_state_dict(posenet_weights['state_dict'])
    pose_net.eval()
    mask_net.eval()

    # inference

    for i, (rgb_tgt_img, rgb_ref_imgs, intrinsics, intrinsics_inv) in enumerate(tqdm(inference_loader)):
        #print(rgb_tgt_img)
        tgt_img_var = Variable(rgb_tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in rgb_ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        
        after_mask = tensor2array(ref_imgs_var[0][0]*explainability_mask[0,0]).transpose(1,2,0)
        x = Image.fromarray(np.uint8(after_mask*255))
        x.save(args.save_path/str(i).zfill(3)+'multi.png')
        
        explainability_mask = (explainability_mask[0,0].detach().cpu()).numpy()
        # print(explainability_mask.shape)
        y = Image.fromarray(np.uint8(explainability_mask*255))
        y.save(args.save_path/str(i).zfill(3)+'mask.png')
Example #13
def main():
    args = parser.parse_args()
    output_dir = Path(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             sequence_length=args.sequence_length)

    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    dpsnet = PSNet(args.nlabel, args.mindepth).cuda()
    weights = torch.load(args.pretrained_dps)
    dpsnet.load_state_dict(weights['state_dict'])
    dpsnet.eval()

    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    errors = np.zeros((2, 8, int(len(val_loader) / args.print_freq) + 1),
                      np.float32)
    with torch.no_grad():
        for ii, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv,
                 tgt_depth, scale_) in enumerate(val_loader):
            if ii % args.print_freq == 0:
                i = int(ii / args.print_freq)
                tgt_img_var = Variable(tgt_img.cuda())
                ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
                ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
                intrinsics_var = Variable(intrinsics.cuda())
                intrinsics_inv_var = Variable(intrinsics_inv.cuda())
                tgt_depth_var = Variable(tgt_depth.cuda())
                scale = scale_.numpy()[0]

                # compute output
                pose = torch.cat(ref_poses_var, 1)
                start = time.time()
                output_depth = dpsnet(tgt_img_var, ref_imgs_var, pose,
                                      intrinsics_var, intrinsics_inv_var)
                elps = time.time() - start
                mask = (tgt_depth <= args.maxdepth) & (
                    tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)

                tgt_disp = args.mindepth * args.nlabel / tgt_depth
                output_disp = args.mindepth * args.nlabel / output_depth

                output_disp_ = torch.squeeze(output_disp.data.cpu(), 1)
                output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

                errors[0, :,
                       i] = compute_errors_test(tgt_depth[mask] / scale,
                                                output_depth_[mask] / scale)
                errors[1, :,
                       i] = compute_errors_test(tgt_disp[mask] / scale,
                                                output_disp_[mask] / scale)

                print('Elapsed Time {} Abs Error {:.4f}'.format(
                    elps, errors[0, 0, i]))

                if args.output_print:
                    output_disp_n = (output_disp_).numpy()[0]
                    np.save(output_dir / '{:04d}{}'.format(i, '.npy'),
                            output_disp_n)
                    disp = (255 * tensor2array(torch.from_numpy(output_disp_n),
                                               max_value=args.nlabel,
                                               colormap='bone')).astype(
                                                   np.uint8)
                    imsave(output_dir / '{:04d}_disp{}'.format(i, '.png'),
                           disp)

    mean_errors = errors.mean(2)
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3'
    ]
    print("{}".format(args.output_dir))
    print("Depth Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".
          format(*error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*mean_errors[0]))

    print("Disparity Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".
          format(*error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*mean_errors[1]))

    np.savetxt(output_dir / 'errors.csv',
               mean_errors,
               fmt='%1.4f',
               delimiter=',')
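The eight error columns follow the standard monocular-depth metrics. A sketch of how compute_errors_test could produce them, matching the order of error_names above (a hypothetical reimplementation of the usual definitions, not this repo's code):

import numpy as np

def compute_errors_test(gt, pred):
    # Standard metrics: abs_rel, abs_diff, sq_rel, rms, log_rms, a1, a2, a3.
    gt, pred = np.asarray(gt, np.float64), np.asarray(pred, np.float64)
    thresh = np.maximum(gt / pred, pred / gt)
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()
    abs_rel = np.mean(np.abs(gt - pred) / gt)
    abs_diff = np.mean(np.abs(gt - pred))
    sq_rel = np.mean((gt - pred) ** 2 / gt)
    rms = np.sqrt(np.mean((gt - pred) ** 2))
    log_rms = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
    return abs_rel, abs_diff, sq_rel, rms, log_rms, a1, a2, a3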
Example #14
def main():
    global n_iter
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / (args.exp + '_' + save_path)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    for i in range(3):
        output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype,
                               dataset=args.dataset)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             dataset=args.dataset)

    train_set.samples = train_set.samples[:len(train_set) -
                                          len(train_set) % args.batch_size]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    mvdnet = MVDNet(args.nlabel, args.mindepth, no_pool=args.no_pool).cuda()
    mvdnet.init_weights()
    if args.pretrained_mvdn:
        print("=> using pre-trained weights for MVDNet")
        weights = torch.load(args.pretrained_mvdn)
        mvdnet.load_state_dict(weights['state_dict'])

    depth_cons = DepthCons().cuda()
    depth_cons.init_weights()

    if args.pretrained_cons:
        print("=> using pre-trained weights for ConsNet")
        weights = torch.load(args.pretrained_cons)
        depth_cons.load_state_dict(weights['state_dict'])

    cons_loss_ = ConsLoss().cuda()
    print('=> setting adam solver')

    if args.train_cons:
        optimizer = torch.optim.Adam(depth_cons.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
        mvdnet.eval()
    else:
        optimizer = torch.optim.Adam(mvdnet.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)

    cudnn.benchmark = True
    mvdnet = torch.nn.DataParallel(mvdnet)
    depth_cons = torch.nn.DataParallel(depth_cons)

    print(' ==> setting log files')
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'validation_abs_rel', 'validation_abs_diff',
            'validation_sq_rel', 'validation_a1', 'validation_a2',
            'validation_a3', 'mean_angle_error'
        ])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    print(' ==> main Loop')
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        if args.evaluate:
            train_loss = 0
        else:
            train_loss = train(args, train_loader, mvdnet, depth_cons,
                               cons_loss_, optimizer, args.epoch_size,
                               training_writer, epoch)
        if not args.evaluate and (args.skip_v):
            error_names = [
                'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'angle'
            ]
            errors = [0] * 7
        else:
            errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                                   depth_cons, epoch,
                                                   output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[0]
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, decisive_error, errors[1], errors[2], errors[3],
                errors[4], errors[5], errors[6]
            ])
        if args.evaluate:
            break
        if args.train_cons:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': depth_cons.module.state_dict()
            },
                            epoch,
                            file_prefixes=['cons'])
        else:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': mvdnet.module.state_dict()
            },
                            epoch,
                            file_prefixes=['mvdnet'])
Example #15
def main():
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(dim=1).unsqueeze(1).sqrt()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max()
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())
            # rigidity_mask_dir rigidity_mask.numpy()
            # rigidity_census_mask_dir rigidity_mask_census.numpy()

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            # changed the two vstacks to hstacks
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            # manually added code
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
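The IoU formulas above follow from mask_error returning [tp_0, fp_0, fn_0, tp_1, fp_1, fn_1], i.e. true-positive, false-positive and false-negative counts per class, with IoU = tp / (tp + fp + fn). A hedged sketch of that counting for a binary mask (ignoring the semantic-map argument the real helper also takes):

import numpy as np

def mask_error_sketch(gt_mask, pred_mask):
    # Count tp/fp/fn for background (class 0) and foreground (class 1).
    counts = []
    for cls in (0, 1):
        gt = (gt_mask == cls)
        pred = (pred_mask == cls)
        counts += [np.logical_and(gt, pred).sum(),   # tp
                   np.logical_and(~gt, pred).sum(),  # fp
                   np.logical_and(gt, ~pred).sum()]  # fn
    return counts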
Example #16
def main():
    best_error = -1
    n_iter = 0
    torch_device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    # parse training arguments
    args = parse_training_args()
    args.training_output_freq = 100  # resetting the training output frequency here.

    # create a folder to save the output of training
    save_path = make_save_path(args)
    args.save_path = save_path
    # save the current configuration to a pickle file
    dump_config(save_path, args)

    print('=> Saving checkpoints to {}'.format(save_path))
    # set manual seed. WHY??
    torch.manual_seed(args.seed)
    # tensorboard summary
    tb_writer = SummaryWriter(save_path)

    # Data preprocessing
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Load datasets
    print("=> Fetching scenes in '{}'".format(args.data))

    train_set = val_set = None
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform,
                                                     valid_transform)
    elif args.lfformat == 'stack':
        train_set, val_set = get_stacked_lf_loaders(args, train_transform,
                                                    valid_transform)

    print('=> {} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Pull first example from dataset to check number of channels
    input_channels = train_set[0][1].shape[0]
    args.epoch_size = len(train_loader)
    print("=> Using {} input channels, {} total batches".format(
        input_channels, args.epoch_size))

    # create model
    print("=> Creating models")
    disp_net = models.LFDispNet(in_channels=input_channels).to(torch_device)
    output_exp = args.mask_loss_weight > 0
    pose_exp_net = models.LFPoseNet(in_channels=input_channels,
                                    nb_ref_imgs=args.sequence_length -
                                    1).to(torch_device)

    # Load or initialize weights
    if args.pretrained_exp_pose:
        print("=> Using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> Using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    # Set some torch flags
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Define optimizer
    print('=> Setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Logging
    with open(os.path.join(save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(os.path.join(save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # train the network
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, tb_writer,
                           n_iter, torch_device)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                  pose_exp_net, epoch, logger,
                                                  tb_writer, torch_device)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance;
        # careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(os.path.join(save_path, args.log_summary), 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Example #17
def main():
    global args, best_photo_loss, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format(
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight,
        args.mask_loss_weight, args.smooth_loss_weight))
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    train_writer = SummaryWriter(args.save_path / 'train')
    valid_writer = SummaryWriter(args.save_path / 'valid')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    input_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=input_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    val_set = SequenceFolder(args.data,
                             transform=custom_transforms.Compose([
                                 custom_transforms.ArrayToTensor(), normalize
                             ]),
                             seed=args.seed,
                             train=False,
                             sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).cuda()

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, train_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        valid_photo_loss, valid_exp_loss, valid_total_loss = validate(
            val_loader, disp_net, pose_exp_net, epoch, logger, output_writers)
        logger.valid_writer.write(
            ' * Avg Photo Loss : {:.3f}, Exp Loss : {:.3f}, Total Loss : {:.3f}'
            .format(valid_photo_loss, valid_exp_loss, valid_total_loss))
        valid_writer.add_scalar(
            'photometric_error', valid_photo_loss * 4, n_iter
        )  # multiplied by 4 because validation runs at a single scale instead of the 4 used during training
        valid_writer.add_scalar('explainability_loss', valid_exp_loss * 4,
                                n_iter)
        valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter)

        if best_photo_loss < 0:
            best_photo_loss = valid_photo_loss

        # remember lowest error and save checkpoint
        is_best = valid_photo_loss < best_photo_loss
        best_photo_loss = min(valid_photo_loss, best_photo_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, valid_total_loss])
    logger.epoch_bar.finish()
Example #18
File: pose_test.py  Project: zuru/DeepSFM
def main():
	global n_iter
	args = parser.parse_args()

	# Data loading code
	normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
	                                        std=[0.5, 0.5, 0.5])
	train_transform = custom_transforms.Compose([
		# custom_transforms.RandomScaleCrop(),
		custom_transforms.ArrayToTensor(),
		normalize
	])

	print("=> fetching scenes in '{}'".format(args.data))
	train_set = SequenceFolder(
		args.data,
		transform=train_transform,
		seed=args.seed,
		ttype=args.ttype,
		add_geo=args.geo,
		depth_source=args.depth_init,
		sequence_length=args.sequence_length,
		gt_source='g',
		std=args.std_tr,
		pose_init=args.pose_init,
		dataset="",
		get_path=True
	)

	print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
	val_loader = torch.utils.data.DataLoader(
		train_set, batch_size=args.batch_size, shuffle=False,
		num_workers=args.workers, pin_memory=True)

	# create model
	print("=> creating model")
	pose_net = PoseNet(args.nlabel, args.std_tr, args.std_rot, add_geo_cost=args.geo, depth_augment=False).cuda()

	if args.pretrained_dps:
		# freeze feature extra layers
		# for param in pose_net.feature_extraction.parameters():
		#     param.requires_grad = False

		print("=> using pre-trained weights for DPSNet")
		model_dict = pose_net.state_dict()
		weights = torch.load(args.pretrained_dps)['state_dict']
		pretrained_dict = {k: v for k, v in weights.items() if
		                   k in model_dict and weights[k].shape == model_dict[k].shape}

		model_dict.update(pretrained_dict)

		pose_net.load_state_dict(model_dict)

	else:
		pose_net.init_weights()

	cudnn.benchmark = True
	pose_net = torch.nn.DataParallel(pose_net)

	data_time = AverageMeter()

	pose_net.eval()
	end = time.time()

	errors = np.zeros((2, 2, len(val_loader)), np.float32)
	with torch.no_grad():
		for i, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv, tgt_depth, ref_depths,
		        ref_noise_poses, initial_pose, tgt_path, ref_paths) in enumerate(val_loader):

			data_time.update(time.time() - end)
			tgt_img_var = Variable(tgt_img.cuda())
			ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
			ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
			ref_noise_poses_var = [Variable(pose.cuda()) for pose in ref_noise_poses]
			initial_pose_var = Variable(initial_pose.cuda())

			ref_depths_var = [Variable(dep.cuda()) for dep in ref_depths]
			intrinsics_var = Variable(intrinsics.cuda())
			intrinsics_inv_var = Variable(intrinsics_inv.cuda())
			tgt_depth_var = Variable(tgt_depth.cuda())

			pose = torch.cat(ref_poses_var, 1)

			noise_pose = torch.cat(ref_noise_poses_var, 1)

			pose_norm = torch.norm(noise_pose[:, :, :3, 3], dim=-1, keepdim=True)  # b * n* 1

			p_angle, p_trans, rot_c, trans_c = pose_net(tgt_img_var, ref_imgs_var, initial_pose_var, noise_pose,
			                                            intrinsics_var,
			                                            intrinsics_inv_var,
			                                            tgt_depth_var,
			                                            ref_depths_var, trans_norm=pose_norm)

			batch_size = p_angle.shape[0]
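			# soft-argmax: take the expectation of the candidate rotations and
			# translations under the softmax distribution over the bin scores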
			p_angle_v = torch.sum(F.softmax(p_angle, dim=1).view(batch_size, -1, 1) * rot_c, dim=1)
			p_trans_v = torch.sum(F.softmax(p_trans, dim=1).view(batch_size, -1, 1) * trans_c, dim=1)
			p_matrix = Variable(torch.zeros((batch_size, 4, 4)).float()).cuda()
			p_matrix[:, 3, 3] = 1
			p_matrix[:, :3, :] = torch.cat([angle2matrix(p_angle_v), p_trans_v.unsqueeze(-1)], dim=-1)  # 2*3*4

			p_rel_pose = torch.ones_like(noise_pose)
			for bat in range(batch_size):
				path = tgt_path[bat]
				dirname = Path.dirname(path)

				orig_poses = np.genfromtxt(Path.join(dirname, args.pose_init + "_poses.txt"))
				for j in range(len(ref_imgs)):
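					# compose the noisy pose with the inverse of the predicted
					# correction matrix to obtain the refined relative pose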
					p_rel_pose[:, j] = torch.matmul(noise_pose[:, j], inv(p_matrix))

					seq_num = int(Path.basename(ref_paths[bat][j])[:-4])
					orig_poses[seq_num] = p_rel_pose[bat, j, :3, :].data.cpu().numpy().reshape(12, )

					p_aa = mat2axangle(p_rel_pose[bat, j, :3, :3].data.cpu().numpy())
					gt_aa = mat2axangle(pose[bat, j, :3, :3].data.cpu().numpy(), unit_thresh=1e-2)

					n_aa = mat2axangle(noise_pose[bat, j, :3, :3].data.cpu().numpy(), unit_thresh=1e-2)
					p_t = p_rel_pose[bat, j, :3, 3].data.cpu().numpy()
					gt_t = pose[bat, j, :3, 3].data.cpu().numpy()
					n_t = noise_pose[bat, j, :3, 3].data.cpu().numpy()
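					# mat2axangle returns an (axis, angle) pair; their product
					# is the axis-angle rotation vector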
					p_aa = p_aa[0] * p_aa[1]
					n_aa = n_aa[0] * n_aa[1]
					gt_aa = gt_aa[0] * gt_aa[1]
					error = compute_motion_errors(np.concatenate([n_aa, n_t]), np.concatenate([gt_aa, gt_t]), True)
					error_p = compute_motion_errors(np.concatenate([p_aa, p_t]), np.concatenate([gt_aa, gt_t]), True)
					print("%d n r%.6f, t%.6f" % (i, error[0], error[2]))
					print("%d p r%.6f, t%.6f" % (i, error_p[0], error_p[2]))
					errors[0, 0, i] += error[0]
					errors[0, 1, i] += error[2]
					errors[1, 0, i] += error_p[0]
					errors[1, 1, i] += error_p[2]
				errors[:, :, i] /= len(ref_imgs)  # NOTE: for batch_size > 1 this divides the i-th slice once per batch element
				if args.save and not Path.exists(Path.join(dirname, args.save + "_poses.txt")):
					np.savetxt(Path.join(dirname, args.save + "_poses.txt"), orig_poses)

		mean_error = errors.mean(2)
		error_names = ['rot', 'trans']
		print("%s Results : " % args.pose_init)
		print("{:>10}, {:>10}".format(*error_names))
		print("{:10.4f}, {:10.4f}".format(*mean_error[0]))

		print("new Results : ")
		print("{:>10}, {:>10}".format(*error_names))
		print("{:10.4f}, {:10.4f}".format(*mean_error[1]))
Example #19
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    # args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=3,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=2).cuda()
    # mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    # masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    # mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    # mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
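        # the network predicts disparity; depth is its reciprocal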
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        # explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        # print(len(explainability_mask))

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var)
        else:
            # NOTE: only the Back2Future branch produces flow_bwd, which the
            # consensus masking below relies on; other flownets would need a
            # separate backward pass here
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                             intrinsics_inv_var)

        # flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:,1], intrinsics_var, intrinsics_inv_var)
        #---------------------------------------------------------------

        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 0], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flow_fwd_list = [flow_fwd]
        flow_bwd_list = [flow_bwd]
        rigidity_mask_fwd = consensus_exp_masks(flows_cam_fwd,
                                                flows_cam_bwd,
                                                flow_fwd_list,
                                                flow_bwd_list,
                                                tgt_img_var,
                                                ref_imgs_var[1],
                                                ref_imgs_var[0],
                                                wssim=0.85,
                                                wrig=1.0,
                                                ws=0.1)[0]
        del flow_fwd_list
        del flow_bwd_list
        #--------------------------------------------------------------

        #rigidity_mask = 1 - (1-explainability_mask[:,1])*(1-explainability_mask[:,2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        #rigidity_mask_census_u = rigidity_mask_census_soft[:,0] < args.THRESH
        #rigidity_mask_census_v = rigidity_mask_census_soft[:,1] < args.THRESH
        #rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (rigidity_mask_census_v).type_as(flow_fwd)

        # rigidity_mask_census = ( torch.pow( (torch.pow(rigidity_mask_census_soft[:,0],2) + torch.pow(rigidity_mask_census_soft[:,1] , 2)), 0.5) < args.THRESH ).type_as(flow_fwd)
        THRESH_1 = 1
        THRESH_2 = 1
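        # census check: a pixel is deemed rigid when the squared difference
        # between the rigid (camera-induced) flow and the predicted flow stays
        # within a relative bound:
        #   ||f_cam - f_fwd||^2 < T1 * (||f_cam||^2 + ||f_fwd||^2) + T2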
        rigidity_mask_census = (
            (torch.pow(rigidity_mask_census_soft[:, 0], 2) +
             torch.pow(rigidity_mask_census_soft[:, 1], 2)) < THRESH_1 *
            (flow_cam.pow(2).sum(dim=1) + flow_fwd.pow(2).sum(dim=1)) +
            THRESH_2).type_as(flow_fwd)

        # rigidity_mask_census = torch.zeros_like(rigidity_mask_census)
        rigidity_mask_fwd = torch.zeros_like(rigidity_mask_fwd)
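        # 1 - (1 - a) * (1 - b) acts as a soft logical OR of the two masks;
        # with rigidity_mask_fwd zeroed out above, it reduces to the census mask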
        rigidity_mask_combined = 1 - (1 - rigidity_mask_fwd) * (
            1 - rigidity_mask_census)  #
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        # rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            torch.zeros_like(rigidity_mask_combined)) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

            tmp1 = flow_fwd.data[0].permute(1, 2, 0).cpu().numpy()
            tmp1 = flow_2_image(tmp1)
            # scipy.misc.imsave is removed in recent SciPy releases;
            # imageio.imwrite is a drop-in replacement if needed
            scipy.misc.imsave(viz_dir / str(i).zfill(3) + 'flow.png', tmp1)

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
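Example #20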
def main():
    global best_error, n_iter, device
    args = parse_multiwarp_training_args()
    # Some non-optional parameters for training
    args.training_output_freq = 100
    args.tilesize = 8

    save_path = make_save_path(args)
    args.save_path = save_path

    print("Using device: {}".format(device))

    dump_config(save_path, args)
    print('\n\n=> Saving checkpoints to {}'.format(save_path))

    torch.manual_seed(args.seed)                # setting a manual seed for reproducibility
    tb_writer = SummaryWriter(save_path)        # tensorboard summary writer

    # Data pre-processing - convert arrays to tensors and normalize the data to lie largely between -1 and 1
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Create data loader based on the format of the light field
    print("=> Fetching scenes in '{}'".format(args.data))
    train_set, val_set = None, None
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform, valid_transform)
    elif args.lfformat == 'stack':
        is_monocular = False
        if len(args.cameras) == 1 and args.cameras[0] == 8 and args.cameras_stacked == "input":
            is_monocular = True
        train_set, val_set = get_stacked_lf_loaders(args, train_transform, valid_transform, is_monocular=is_monocular)
    elif args.lfformat == 'epi':
        train_set, val_set = get_epi_loaders(args, train_transform, valid_transform)
    else:
        raise ValueError("Unsupported lfformat: {}".format(args.lfformat))

    print('=> {} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} validation scenes'.format(len(val_set), len(val_set.scenes)))

    print('=> Multi-warp training, warping {} sub-apertures'.format(len(args.cameras)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    output_channels = len(args.cameras)  # for multi-warp photometric loss, we request as many depth values as there are cameras
    args.epoch_size = len(train_loader)
    
    # Create models
    print("=> Creating models")

    if args.lfformat == "epi":
        print("=> Using EPI encoders")
        if args.cameras_epi == "vertical":
            disp_encoder = models.EpiEncoder('vertical', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('vertical', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "horizontal":
            disp_encoder = models.EpiEncoder('horizontal', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('horizontal', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "full":
            disp_encoder = models.EpiEncoder('full', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('full', args.tilesize).to(device)
            if args.without_disp_stack:
                dispnet_input_channels = 32  # 16 is the number of output channels of each encoder
            else:
                dispnet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder, 5 from stack
            posenet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder
        else:
            raise ValueError("Incorrect cameras epi format")
    else:
        disp_encoder = None
        pose_encoder = None
        # for stack lfformat channels = num_cameras * num_colour_channels
        # for focalstack lfformat channels = num_focal_planes * num_colour_channels
        dispnet_input_channels = posenet_input_channels = train_set[0]['tgt_lf_formatted'].shape[0]
    
    disp_net = models.LFDispNet(in_channels=dispnet_input_channels,
                                out_channels=output_channels, encoder=disp_encoder).to(device)
    pose_net = models.LFPoseNet(in_channels=posenet_input_channels,
                                nb_ref_imgs=args.sequence_length - 1, encoder=pose_encoder).to(device)

    print("=> [DispNet] Using {} input channels, {} output channels".format(dispnet_input_channels, output_channels))
    print("=> [PoseNet] Using {} input channels".format(posenet_input_channels))

    if args.pretrained_exp_pose:
        print("=> [PoseNet] Using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        print("=> [PoseNet] training from scratch")
        pose_net.init_weights()

    if args.pretrained_disp:
        print("=> [DispNet] Using pre-trained weights for DispNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        print("=> [DispNet] training from scratch")
        disp_net.init_weights()

    # this flag tells cuDNN to search for the optimal set of algorithms for this
    # specific input size, which improves runtime efficiency at the cost of some
    # startup overhead
    cudnn.benchmark = True
    # disp_net = torch.nn.DataParallel(disp_net)
    # pose_net = torch.nn.DataParallel(pose_net)

    print('=> Setting adam solver')

    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr}, 
        {'params': pose_net.parameters(), 'lr': args.lr}
    ]

    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)
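    # step schedule: halve the learning rate every 50 epochs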
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    with open(save_path + "/" + args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(save_path + "/" + args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'smooth_loss', 'pose_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    w1 = torch.tensor(args.photo_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w2 = torch.tensor(args.mask_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    w3 = torch.tensor(args.smooth_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w4 = torch.tensor(args.gt_pose_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
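    # the loss weights are plain tensors created with requires_grad=True, so
    # gradients could flow into them if they were ever added to an optimizer;
    # here they are simply passed through to train/validate below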

    # add some constant parameters to the log for easy visualization
    tb_writer.add_scalar(tag="batch_size", scalar_value=args.batch_size)

    # tb_writer.add_scalar(tag="mask_loss_weight", scalar_value=args.mask_loss_weight)    # this is not used

    # tb_writer.add_scalar(tag="gt_pose_loss_weight", scalar_value=args.gt_pose_loss_weight)

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer, args.epoch_size, logger, tb_writer, w1, w3)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, tb_writer, w1, w3)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # update the learning rate (annealing)
        lr_scheduler.step()

        # add the learning rate to the tensorboard logging
        tb_writer.add_scalar(tag="learning_rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=epoch)

        tb_writer.add_scalar(tag="photometric_loss_weight", scalar_value=w1, global_step=epoch)
        tb_writer.add_scalar(tag="smooth_loss_weight", scalar_value=w3, global_step=epoch)

        # add validation errors to the tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(tag=name, scalar_value=error, global_step=epoch)

        decisive_error = errors[2]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, is_best)

        # save a checkpoint every 20 epochs anyway
        if epoch % 20 == 0:
            save_checkpoint_current(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                                    {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, epoch)

        with open(save_path + "/" + args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Example #21
def prepare_environment():
    env = {}
    args = parser.parse_args()
    if args.dataset_format == 'KITTI':
        from datasets.shifted_sequence_folders import ShiftedSequenceFolder
    elif args.dataset_format == 'StillBox':
        from datasets.shifted_sequence_folders import StillBox as ShiftedSequenceFolder
    elif args.dataset_format == 'TUM':
        from datasets.shifted_sequence_folders import TUM as ShiftedSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    args.test_batch_size = 4 * args.batch_size
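    # (a larger batch is fine for evaluation: no gradients are stored)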
    if args.evaluate:
        args.epochs = 0

    env['training_writer'] = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
    env['output_writers'] = output_writers

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomHorizontalFlip(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(args.data,
                                      transform=train_transform,
                                      seed=args.seed,
                                      train=True,
                                      with_depth_gt=False,
                                      with_pose_gt=args.supervise_pose,
                                      sequence_length=args.sequence_length)
    val_set = ShiftedSequenceFolder(args.data,
                                    transform=valid_transform,
                                    seed=args.seed,
                                    train=False,
                                    sequence_length=args.sequence_length,
                                    with_depth_gt=args.with_gt,
                                    with_pose_gt=args.with_gt)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4 * args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    env['train_set'] = train_set
    env['val_set'] = val_set
    env['train_loader'] = train_loader
    env['val_loader'] = val_loader

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    pose_net = models.PoseNet(seq_length=args.sequence_length,
                              batch_norm=args.bn in ['pose',
                                                     'both']).to(device)

    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    depth_net = models.DepthNet(depth_activation="elu",
                                batch_norm=args.bn in ['depth',
                                                       'both']).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained DepthNet model")
        data = torch.load(args.pretrained_depth)
        depth_net.load_state_dict(data['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    pose_net = torch.nn.DataParallel(pose_net)

    env['depth_net'] = depth_net
    env['pose_net'] = pose_net

    print('=> setting adam solver')

    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    # parameters = chain(depth_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_frequency,
                                                gamma=0.5)
    env['optimizer'] = optimizer
    env['scheduler'] = scheduler

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()
    env['logger'] = logger

    env['args'] = args

    return env
Example #22
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    save_path = make_save_path(args)
    args.save_path = save_path
    dump_config(save_path, args)
    print('=> Saving checkpoints to {}'.format(save_path))
    torch.manual_seed(args.seed)
    tb_writer = SummaryWriter(save_path)

    # Data preprocessing
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create dataloader
    print("=> Fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        gray=args.gray,
        cameras=args.cameras,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )
    
    val_set = SequenceFolder(
        args.data,
        gray=args.gray,
        cameras=args.cameras,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
        shuffle=False
    )

    print('=> {} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # Pull first example from dataset to check number of channels
    input_channels = train_set[0][1].shape[0]   
    args.epoch_size = len(train_loader)
    print("=> Using {} input channels, {} total batches".format(input_channels, args.epoch_size))
    
    # create model
    print("=> Creating models")
    pose_exp_net = models.LFPoseNet(in_channels=input_channels, nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> Using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    cudnn.benchmark = True
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> Setting adam solver')

    optim_params = [
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]

    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)


    with open(save_path + "/" + args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, pose_exp_net, optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
        
        # evaluate on validation set
        logger.reset_valid_bar()
        valid_loss = validate(args, val_loader, pose_exp_net, logger, tb_writer)

        # build this epoch's checkpoint, save it unconditionally, and also
        # mark it as the best one when the validation loss improves
        checkpoint = {
            "epoch": epoch + 1,
            "state_dict": pose_exp_net.module.state_dict()
        }
        if valid_loss < best_error or best_error < 0:
            best_error = valid_loss
            torch.save(checkpoint, save_path + "/" + 'posenet_best.pth.tar')
        torch.save(checkpoint, save_path + "/" + 'posenet_checkpoint.pth.tar')

    logger.epoch_bar.finish()
Example #23
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomRotate(),
        # custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array([[542.822841, 0, 315.593520],
                           [0, 542.576870, 237.756098],
                           [0, 0, 1]]).astype(np.float32)

    train_set = SequenceFolder(root=args.dataset_dir,
                               intrinsics=intrinsics,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print("=> creating model")
    mask_net = MaskResNet6.MaskResNet6().cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()
    mask_net = torch.nn.DataParallel(mask_net)
    pose_net = torch.nn.DataParallel(pose_net)

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'], False)
    else:
        mask_net.init_weights()

    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], False)
    else:
        pose_net.init_weights()

    print('=> setting adam solver')
    parameters = chain(mask_net.parameters(), pose_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # training
    best_error = 100001
    train_loss = 100000
    for epoch in tqdm(range(args.epochs)):

        train_loss = train(train_loader, mask_net, pose_net, optimizer,
                           args.epoch_size, training_writer)
        is_best = train_loss < best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': mask_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': optimizer.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.state_dict()
        }, is_best)
Example #24
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as
    # the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        seg_net.load_state_dict(weights, strict=False)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')

    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            # pass seg_net here as well, matching the per-epoch call below
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, 0, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, epoch, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # It's up to you to choose the most relevant error to measure your model's performance; careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Example #25
def main(args):
  best_acc1 = 0.0

  if args.gpu >= 0:
    torch.cuda.set_device(args.gpu)
    print("Use GPU: {}".format(args.gpu))
  else:
    print('You are using CPU for computing!',
          'Yet we assume you are using a GPU.',
          'You will NOT be able to switch between CPU and GPU training!')

  # fix the random seeds (the best we can)
  fixed_random_seed = 2019
  torch.manual_seed(fixed_random_seed)
  np.random.seed(fixed_random_seed)
  random.seed(fixed_random_seed)

  # set up the model + loss
  if args.use_custom_conv:
    print("Using custom convolutions in the network")
    model = default_model(conv_op=CustomConv2d, num_classes=100)
  elif args.use_resnet18:
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(512, 100)
  elif args.use_adv_training:
    model = AdvSimpleNet(num_classes=100)
  else:
    model = default_model(num_classes=100)
  model_arch = "simplenet"
  criterion = nn.CrossEntropyLoss()
  # put everything on the gpu
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)

  # setup the optimizer
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      checkpoint = torch.load(args.resume)
      args.start_epoch = checkpoint['epoch']
      best_acc1 = checkpoint['best_acc1']
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      # only load the optimizer if necessary
      if (not args.evaluate) and (not args.attack):
        optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))

  # set up transforms for data augmentation
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
  train_transforms = get_train_transforms(normalize)
  # val transforms
  val_transforms = transforms.Compose([
    transforms.Scale(160, interpolations=None),
    transforms.ToTensor(),
    normalize,
  ])
  if (not args.evaluate) and (not args.attack):
    print("Training time data augmentations:")
    print(train_transforms)

  # setup dataset and dataloader
  train_dataset = MiniPlacesLoader(args.data_folder,
                  split='train', transforms=train_transforms)
  val_dataset = MiniPlacesLoader(args.data_folder,
                  split='val', transforms=val_transforms)

  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=100, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

  # evaluate and attack are mutually exclusive
  if args.evaluate and args.attack:
    print("Can't set evaluate and attack to True at the same time!")
    return

  # set up visualizer
  if args.vis:
    visualizer = default_attention(criterion)
  else:
    visualizer = None

  # evaluation
  if args.resume and args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args, visualizer=visualizer)
    return

  # attack
  if args.resume and args.attack:
    print("Generating adversarial samples for the model ..")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args,
             attacker=default_attack(criterion),
             visualizer=visualizer)
    return

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True

  # warmup the training
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, "train", args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, epoch, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
      'epoch': epoch + 1,
      'model_arch': model_arch,
      'state_dict': model.state_dict(),
      'best_acc1': best_acc1,
      'optimizer' : optimizer.state_dict(),
    }, is_best)
def main():
    global opt, best_prec1
    opt = parser.parse_args()
    print(opt)

    # Data loading
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    print('Loading scenes in', opt.data_dir)
    train_set = SequenceFolder(opt.data_dir,
                               transform=train_transform,
                               seed=opt.seed,
                               train=True,
                               sequence_length=opt.sequence_length)

    val_set = ValidationSet(opt.data_dir, transform=valid_transform)

    print(len(train_set), 'train samples found')
    print(len(val_set), 'validation samples found')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.workers,
                                               pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=opt.batch_size,
    #                                             shuffle=False, num_workers=opt.workers,
    #                                             pin_memory=True)
    if opt.epoch_size == 0:
        opt.epoch_size = len(train_loader)
    # Done loading

    disp_model = dispnet.DispNet().cuda()
    pose_model = posenet.PoseNet().cuda()
    disp_model, pose_model, optimizer = init.setup(disp_model, pose_model, opt)
    print(disp_model, pose_model)
    trainer = train.Trainer(disp_model, pose_model, optimizer, opt)
    if opt.resume:
        if os.path.isfile(opt.resume):
            # disp_model, pose_model, optimizer, opt, best_prec1 = init.resumer(opt, disp_model, pose_model, optimizer)
            disp_model, pose_model, optimizer, opt = init.resumer(
                opt, disp_model, pose_model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True
    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])
        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)
        # init.save_checkpoint(opt, disp_model, pose_model, optimizer, best_prec1, epoch)
        init.save_checkpoint(opt, disp_model, pose_model, optimizer, epoch)
def main():
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    adjust_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )  # num_workers is 0 to avoid multiple worker processes modifying the dataset's shifts at the same time
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

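    # jointly optimize both networks with a single Adam instance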
    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

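        # every 5 epochs, let the pose network re-estimate per-sample frame
        # shifts so that apparent displacements stay close to the target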
        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                ['{:.3f}'.format(s) for s in average_shifts])
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Choosing the most relevant error to measure your model's performance is up to you; be careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Example #28
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
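    # KITTI flow pairs are resized to a fixed 256 x 832 before evaluation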
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
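        # Variable(volatile=True) is the legacy (pre-0.4) PyTorch idiom for
        # inference without autograd; torch.no_grad() replaces it in newer code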
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

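        # forward passes: DispNet predicts disparity (depth = 1/disp),
        # PoseNet predicts relative camera poses, MaskNet predicts
        # per-reference explainability masks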
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

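        # pose2flow synthesizes the optical flow induced purely by camera
        # motion from the predicted depth, pose and camera intrinsics
        # (the "rigid" flow)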
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

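        # a pixel is marked rigid when the explainability masks say so, or
        # when the camera-motion flow and the flow network's prediction agree
        # to within args.THRESH in both components (a census-style check)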
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

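        # compose the full flow: the flow network's prediction in non-rigid
        # regions, the camera-motion (rigid) flow everywhere else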
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
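# Side note (an illustration, not part of the example above): the fusion
# pattern used repeatedly here, 1 - (1 - a) * (1 - b), is a soft logical OR;
# it is 0 only where both inputs are 0. A minimal self-contained check:
import torch

a = torch.tensor([0.0, 1.0, 0.2])
b = torch.tensor([0.0, 0.0, 0.9])
combined = 1 - (1 - a) * (1 - b)
print(combined)  # tensor([0.0000, 1.0000, 0.9200])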
Example #29
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)

    if args.DEBUG:
        train_set.samples = train_set.samples[:32]  # truncate the sample list; len(train_set) follows

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # alternative checkpoint: pwc_net.pth.tar
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model paramters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()

    cudnn.benchmark = True
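    # initialize the adversarial patch and its mask, either procedurally
    # (circle/square) or from an image on disk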
    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size,
                                                     args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size,
                                               args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net,
                                                    epoch, logger,
                                                    output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
Example #30
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
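    # channel statistics roughly matching ImageNet (mean 0.45, std 0.225)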
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers,
                                 args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
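    # a single Adam optimizer over both networks, one parameter group (with its own lr) per network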
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_loss', 'smooth_loss',
            'geometry_consistency_loss'
        ])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Choosing the most relevant error to measure your model's performance is up to you; be careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
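# A minimal resume sketch (an illustration, not part of the script above):
# reloading a DispResNet checkpoint written by save_checkpoint. The file name
# 'dispnet_checkpoint.pth.tar' is an assumption, not taken from the code.
import torch
import models

disp_net = models.DispResNet(18, False)
checkpoint = torch.load('dispnet_checkpoint.pth.tar', map_location='cpu')
disp_net.load_state_dict(checkpoint['state_dict'])
disp_net.eval()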