Exemplo n.º 1
0
def main():
    """Run the demo visualisation with DispResNet, EgoPoseNet and ObjPoseNet.

    Parses the command line, optionally creates a timestamped output
    directory under ``outputs/<name>/``, builds the demo data loader and the
    three networks, then hands everything to ``demo_visualize``.

    Each network loads its ``state_dict`` from the matching
    ``args.pretrained_*`` checkpoint when one is given (non-strict loading);
    otherwise ``init_weights()`` is called on it.
    """
    # CUDA_VISIBLE_DEVICES is optional: report it without crashing when the
    # variable has not been exported.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
    print('=> PyTorch version: ' + torch.__version__ + ' || CUDA_VISIBLE_DEVICES: ' + visible_devices)

    global device
    args = parser.parse_args()
    if args.save_fig:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        args.save_path = 'outputs'/Path(args.name)/timestamp
        print('=> will save everything to {}'.format(args.save_path))
        args.save_path.makedirs_p()

    print("=> fetching scenes in '{}'".format(args.data))
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    demo_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    demo_set = SequenceFolder(
        args.data,
        transform=demo_transform,
        max_num_instances=args.mni
    )
    print('=> {} samples found'.format(len(demo_set)))
    demo_loader = torch.utils.data.DataLoader(
        demo_set,
        batch_size=args.batch_size,
        shuffle=False,  # keep the samples in dataset order
        num_workers=args.workers,
        pin_memory=True,
    )

    # create model
    print("=> creating model")
    disp_net = models.DispResNet().to(device)
    ego_pose_net = models.EgoPoseNet().to(device)
    obj_pose_net = models.ObjPoseNet().to(device)

    # Same "load checkpoint or initialise" policy for the three networks,
    # processed in the original order: ego pose, object pose, disparity.
    networks = (
        ('EgoPoseNet', ego_pose_net, args.pretrained_ego_pose),
        ('ObjPoseNet', obj_pose_net, args.pretrained_obj_pose),
        ('DispResNet', disp_net, args.pretrained_disp),
    )
    for label, net, checkpoint in networks:
        if checkpoint:
            print("=> using pre-trained weights for {}".format(label))
            weights = torch.load(checkpoint)
            # strict=False tolerates keys that differ between checkpoint and model
            net.load_state_dict(weights['state_dict'], strict=False)
        else:
            net.init_weights()

    cudnn.benchmark = True
    # Wrap only after the weights are loaded, so checkpoint keys are matched
    # against the bare modules (no DataParallel "module." prefix).
    disp_net = torch.nn.DataParallel(disp_net)
    ego_pose_net = torch.nn.DataParallel(ego_pose_net)
    obj_pose_net = torch.nn.DataParallel(obj_pose_net)

    demo_visualize(args, demo_loader, disp_net, ego_pose_net, obj_pose_net)
Exemplo n.º 2
0
def main():
    """Run DispResNet over a set of images, save the results and time it.

    The images are the lines of ``args.dataset_list`` when it is given,
    otherwise every ``*.png`` directly inside ``args.dataset_dir`` (sorted).
    The inverse of each predicted disparity map is stacked into one array
    and written to ``<output_dir>/predictions.npy``. The average forward-pass
    time and the corresponding frame rate are printed at the end.

    Raises:
        ValueError: if there is no image to process.
    """
    args = parser.parse_args()

    disp_net = models.DispResNet(args.resnet_layers, False).to(device)
    weights = torch.load(args.pretrained_dispnet)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)

    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = f.read().splitlines()
    else:
        test_files = sorted(dataset_dir.files('*.png'))

    print('{} files to test'.format(len(test_files)))
    if not test_files:
        # Without this guard the run ends with an unbound `predictions`
        # at save time.
        raise ValueError('no test files found, nothing to predict')

    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    predictions = None
    total_time = 0
    # Inference only: disable autograd so no graph is kept and the output can
    # be converted to NumPy.
    with torch.no_grad():
        for j, test_file in enumerate(tqdm(test_files)):
            tgt_img = load_tensor_image(test_file, args)

            # compute speed: synchronise before and after the forward pass so
            # that queued GPU work is included in the measurement
            torch.cuda.synchronize()
            t_start = time.time()

            output = disp_net(tgt_img)

            torch.cuda.synchronize()
            total_time += time.time() - t_start

            pred_disp = output.cpu().numpy()[0, 0]

            if predictions is None:
                # allocated lazily: the map size is only known after the
                # first prediction
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1 / pred_disp

    np.save(output_dir / 'predictions.npy', predictions)

    avg_time = total_time / len(test_files)
    print('Avg Time: ', avg_time, ' seconds.')
    print('Avg Speed: ', 1.0 / avg_time, ' fps')
Exemplo n.º 3
0
def main():
    """Run DispResNet over the images of a list file and save the results.

    Every line of ``args.dataset_list`` is appended to ``args.dataset_dir``
    to form an image path. The inverse of each predicted disparity map is
    stacked into one array and written to ``<output_dir>/predictions.npy``.

    Raises:
        ValueError: if the list file contains no entry.
    """
    args = parser.parse_args()

    print("=> Tested at {}".format(datetime.datetime.now().strftime("%m-%d-%H:%M")))

    print('=> Load dispnet model from {}'.format(args.pretrained_dispnet))

    # `models` is imported from the parent of the script directory, which has
    # to be put on the import path first.
    sys.path.insert(1, os.path.join(sys.path[0], '..'))
    import models

    disp_net = models.DispResNet().to(device)
    weights = torch.load(args.pretrained_dispnet)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    with open(args.dataset_list, 'r') as f:
        test_files = f.read().splitlines()
    print('=> {} files to test'.format(len(test_files)))
    if not test_files:
        # Without this guard the run ends with an unbound `predictions`
        # at save time.
        raise ValueError(
            "no test files listed in '{}'".format(args.dataset_list))

    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    predictions = None
    # Inference only: disable autograd so no graph is kept and the output can
    # be converted to NumPy.
    with torch.no_grad():
        for j, test_file in enumerate(tqdm(test_files)):
            # NOTE(review): `+` concatenates the two strings, it does not
            # insert a path separator -- dataset_dir presumably ends with
            # one; confirm.
            tgt_img, _ = load_tensor_image(dataset_dir + test_file, args)
            pred_disp = disp_net(tgt_img).cpu().numpy()[0, 0]

            if predictions is None:
                # allocated lazily: the map size is only known after the
                # first prediction
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1 / pred_disp

    np.save(output_dir / 'predictions.npy', predictions)
Exemplo n.º 4
0
def main():
    """Train DispResNet and PoseResNet and checkpoint them after every epoch.

    Steps, in order:
      1. create ``checkpoints/<name>/<timestamp>/`` and the summary writers;
      2. build the training and validation datasets and loaders;
      3. build both networks, optionally loading pre-trained weights;
      4. train for ``args.epochs`` epochs, validating after each one, logging
         to the summary writers and CSV files, and saving a checkpoint that
         is flagged as best when the decisive validation error is the lowest
         seen so far.

    Uses the module-level globals ``best_error`` (a negative value means
    "not set yet"), ``n_iter`` and ``device``.
    """
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # NOTE(review): benchmark=True lets cuDNN select algorithms by timing
    # them, which can work against the determinism requested on the line
    # before -- confirm which of the two is intended.
    cudnn.deterministic = True
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # If no ground truth is available, the validation set is the same type as
    # the training set, to measure the photometric loss from warping.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "use the whole training loader"
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers,
                                 args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    # Wrap only after loading so checkpoint keys match the bare modules.
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # newline='' is what the csv module expects from the files it writes to;
    # without it, platforms that translate line endings emit extra '\r'.
    with open(args.save_path / args.log_summary, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_loss', 'smooth_loss',
            'geometry_consistency_loss'
        ])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximised (such as
        # a1, a2, a3).
        decisive_error = errors[1]

        # Remember the lowest error and save a checkpoint. A negative
        # best_error means no error has been recorded yet, so the first
        # epoch is the best one seen so far and must be flagged as such.
        is_best = best_error < 0 or decisive_error < best_error
        if is_best:
            best_error = decisive_error
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a',
                  newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Exemplo n.º 5
0
def main(args):
    """Predict disparity for the images of a split file and save them as PNGs.

    The split file ``<wk_root>/splits/<split>/<txt_file>`` is read with
    ``readlines``. Each of its lines is expected to hold three
    space-separated fields: a folder relative to ``args.dataset_path``, a
    frame index, and a camera side (``l`` or ``r``). Every predicted
    disparity map is written to ``args.output_dir`` with the ``magma``
    colour map, and the average forward-pass time is printed at the end.

    Args:
        args: parsed options; the attributes read here are ``resnet_layers``,
            ``models_path``, ``model_name``, ``dataset_path``, ``wk_root``,
            ``split``, ``txt_file`` and ``output_dir`` (``args`` is also
            forwarded to ``load_tensor_image``).

    Raises:
        ValueError: if a line has an unknown camera side, or if there is no
            file to process (unsupported split or empty split file).
    """
    disp_net = models.DispResNet(args.resnet_layers, False).to(device)

    weights_p = Path(args.models_path) / args.model_name
    dataset_path = Path(args.dataset_path)
    root = Path(args.wk_root)

    # model
    weights = torch.load(weights_p)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    # inputs
    txt = root / 'splits' / args.split / args.txt_file
    print('\n')
    print('-> inference file: ', txt)
    print('-> model_path: ', weights_p)

    rel_paths = readlines(txt)

    # camera side code -> image folder, as used by the split files
    camera_dirs = {'l': 'image_02', 'r': 'image_01'}

    test_files = []
    if args.split in ['custom', 'custom_lite', 'eigen', 'eigen_zhou']:  # kitti
        for line in rel_paths:
            fields = line.split(' ')
            side = fields[2]
            if side not in camera_dirs:
                # Previously an unknown side silently reused the camera of
                # the previous line (or failed on the very first line).
                raise ValueError(
                    "unknown camera side '{}' in line '{}'".format(side, line))
            test_files.append(dataset_path / fields[0] / camera_dirs[side] /
                              'data' / "{:010d}.png".format(int(fields[1])))

    print('-> {} files to test'.format(len(test_files)))
    if not test_files:
        # Covers both an unsupported split and an empty split file; without
        # this guard the run ends with a division by zero.
        raise ValueError(
            "no test files: split '{}' is not supported or '{}' is empty"
            .format(args.split, txt))

    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    total_time = 0
    # Inference only: disable autograd so no graph is kept and the output can
    # be converted to NumPy.
    with torch.no_grad():
        for img_p in tqdm(test_files):
            tgt_img = load_tensor_image(img_p, args)

            # compute speed: synchronise before and after the forward pass so
            # that queued GPU work is included in the measurement
            torch.cuda.synchronize()
            t_start = time.time()

            output = disp_net(tgt_img)

            torch.cuda.synchronize()
            total_time += time.time() - t_start

            pred_disp = output.cpu().numpy()[0, 0]

            # Every test file is built above as <folder>/<camera>/data/<frame>.png,
            # so the fourth component from the end is the last part of
            # <folder>. The name was previously only defined for the 'eigen'
            # split and left unbound for the other supported splits.
            # NOTE(review): the camera is not part of the name, so a left and
            # a right frame of the same folder/index share one output file.
            rel_path = img_p.relpath(dataset_path)
            output_name = str(rel_path).split('/')[-4] + '_{}'.format(
                rel_path.stem)

            plt.imsave(output_dir / (output_name + '.png'), pred_disp,
                       cmap='magma')

    avg_time = total_time / len(test_files)
    print('Avg Time: ', avg_time, ' seconds.')
    print('Avg Speed: ', 1.0 / avg_time, ' fps')
Exemplo n.º 6
0
def main():
    """Train DispResNet and PoseExpNet on stereo sequences.

    Steps, in order:
      1. create the checkpoint directory and the summary writers;
      2. build the training and validation datasets and loaders;
      3. build both networks, loading pre-trained weights or initialising
         them;
      4. optionally evaluate once before training (pre-trained disparity
         network or ``args.evaluate``);
      5. train for ``args.epochs`` epochs (0 when ``args.evaluate`` is set),
         validating after each one, logging, and saving a checkpoint that is
         flagged as best when the decisive validation error is the lowest
         seen so far.

    Uses the module-level globals ``best_error`` (a negative value means
    "not set yet"), ``n_iter`` and ``device``.

    Raises:
        ValueError: if ``args.dataset_format`` is not ``'sequential'``.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    # StereoSequenceFolder is only imported for the sequential format. Any
    # other value used to fail later with an unbound local name, after the
    # checkpoint directory had already been created -- fail early instead.
    if args.dataset_format != 'sequential':
        raise ValueError(
            "dataset_format '{}' is not supported here: only 'sequential' "
            "provides StereoSequenceFolder".format(args.dataset_format))
    from datasets.sequence_folders import StereoSequenceFolder

    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        # evaluation only: skip the training loop entirely
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = StereoSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # If no ground truth is available, the validation set is the same type as
    # the training set, to measure the photometric loss from warping.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(
            args.data,
            transform=valid_transform
        )
    else:
        val_set = StereoSequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Without an epoch_size (== 0), every epoch trains on all samples of
    # train_set; with an epoch_size, every epoch only trains on part of it.
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model: initialise the network architectures
    print("=> creating model")

    disp_net = models.DispResNet(3).to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # With a mask loss, PoseExpNet has to output both the mask and the pose
    # estimate, because the two outputs share the encoder.
    pose_exp_net = models.PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    # Parallelise; done after loading so checkpoint keys match the bare
    # modules.
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Optimiser: Adam, with both networks optimised together.
    print('=> setting adam solver')
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # newline='' is what the csv module expects from the files it writes to;
    # without it, platforms that translate line endings emit extra '\r'.
    with open(args.save_path/args.log_summary, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    # Evaluate the pre-trained model first.
    if args.pretrained_disp or args.evaluate:
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))
        # The summary used to be built and then dropped; report it.
        print(' * Avg {}'.format(error_string))

    # Actual training
    for epoch in range(args.epochs):

        # train for one epoch
        print('\n')
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, training_writer, epoch)

        # evaluate on validation set
        print('\n')
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        # The summary used to be built and then dropped; report it.
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximised (such as
        # a1, a2, a3).
        # Validation outputs several losses (overall final loss, warping loss
        # and mask regularisation loss); pick whichever should define the
        # best model.
        decisive_error = errors[0]

        # Remember the lowest error and save a checkpoint, keeping the model
        # with the best validation result. A negative best_error means no
        # error has been recorded yet, so the first epoch is the best one
        # seen so far and must be flagged as such.
        is_best = best_error < 0 or decisive_error < best_error
        if is_best:
            best_error = decisive_error
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disp_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': pose_exp_net.module.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])