Code example #1
def get_test_dataloader(config, train=False):
    data_dir = osp.join('..', 'data', config.dataset)
    stats_filename = osp.join(data_dir, config.scene, 'stats.txt')
    stats = np.loadtxt(stats_filename)  # stats.txt: per-channel mean (row 0) and variance (row 1)
    # image transforms
    data_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=stats[0], std=np.sqrt(stats[1]))
    ])
    target_transform = transforms.Lambda(lambda x: torch.from_numpy(x).float())

    kwargs = dict(scene=config.scene,
                  data_path=config.dataset_path,
                  train=train,
                  transform=data_transform,
                  target_transform=target_transform,
                  seed=config.seed,
                  data_dir=config.preprocessed_data_path,
                  config=config)
    if config.dataset == '7Scenes':
        dataset = SevenScenes(**kwargs)
    elif config.dataset == 'RobotCar':
        dataset = RobotCar(**kwargs)
    elif config.dataset == 'SenseTime':
        dataset = SenseTime(**kwargs)
    elif config.dataset == 'Cambridge' or config.dataset == 'NewCambridge':
        dataset = Cambridge(**kwargs)
    else:
        raise NotImplementedError

    loader = DataLoader(dataset,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.num_workers,
                        pin_memory=True)
    return loader
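As a point of reference, the loader above could be exercised with a minimal, hypothetical config object whose fields simply mirror the attributes the function reads; the project's real configuration class is presumably richer (the dataset constructors also receive it via config=config), so treat this purely as a sketch.

from types import SimpleNamespace

# Hypothetical config stand-in; the field names mirror the attributes used above.
config = SimpleNamespace(
    dataset='7Scenes',
    scene='chess',
    image_size=(256, 256),
    batch_size=32,
    num_workers=4,
    dataset_path='path/to/7Scenes',
    seed=7,
    preprocessed_data_path='path/to/preprocessed/7Scenes',
)
test_loader = get_test_dataloader(config)
for images, targets in test_loader:
    pass  # run the pose model on each batch here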
Code example #2
def get_posenet_train_dataloader(config):
    data_transform, target_transform = get_train_transforms(config)
    kwargs = dict(
        scene=config.scene,
        data_path=config.dataset_path,
        transform=data_transform,
        target_transform=target_transform,
        seed=config.seed,
        data_dir=config.preprocessed_data_path,
        config=config
    )
    if config.dataset == '7Scenes':
        train_data = SevenScenes(train=True, **kwargs)
        valid_data = SevenScenes(train=False, **kwargs)
    elif config.dataset == 'RobotCar':
        train_data = RobotCar(train=True, **kwargs)
        valid_data = RobotCar(train=False, **kwargs)
    elif config.dataset == 'SenseTime':
        train_data = SenseTime(train=True, **kwargs)
        valid_data = SenseTime(train=False, **kwargs)
    elif config.dataset == 'Cambridge' or config.dataset == 'NewCambridge':
        train_data = Cambridge(train=True, **kwargs)
        valid_data = Cambridge(train=False, **kwargs)
    else:
        raise NotImplementedError

    dataloader_kwargs = dict(
        batch_size=config.batch_size,
        shuffle=config.shuffle,
        num_workers=config.num_workers,
        pin_memory=True,
        collate_fn=safe_collate
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        **dataloader_kwargs
    )
    if config.do_val:
        val_dataloader = torch.utils.data.DataLoader(
            valid_data,
            **dataloader_kwargs
        )
    else:
        val_dataloader = None

    return train_dataloader, val_dataloader
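safe_collate is referenced but not defined in these excerpts. In comparable camera-relocalization codebases it usually just filters out samples that failed to load before delegating to PyTorch's default collate; the following is a sketch under that assumption, not the project's actual implementation.

from torch.utils.data.dataloader import default_collate

def safe_collate(batch):
    # Drop samples the dataset returned as None (e.g. an image that failed
    # to load), then fall back to the standard collation.
    batch = [item for item in batch if item is not None]
    return default_collate(batch)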
Code example #3
def ResizeImages(train=True):
    data_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.Lambda(lambda x: np.asarray(x))
    ])
    dset = SenseTime(scene=args.scene,
                     data_path=args.data_dir,
                     train=train,
                     transform=data_transform)
    loader = DataLoader(dset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        collate_fn=safe_collate)
    base_dir = osp.join(args.data_dir, args.scene)
    if train:
        split_filename = osp.join(base_dir, 'train_split.txt')
    else:
        split_filename = osp.join(base_dir, 'test_split.txt')
    with open(split_filename, 'r') as f:
        seqs = [l.rstrip() for l in f if not l.startswith('#')]
    im_fns = []
    for seq in seqs:
        seq_dir = osp.join(base_dir, seq)
        count = int(np.loadtxt(osp.join(seq_dir, 'images', 'count.txt')))  # np.loadtxt returns a float; range() below needs an int
        print('count = {}'.format(count))
        im_fns.extend([
            osp.join(seq_dir, 'resized_images', 'Frame{:06d}.jpg'.format(idx))
            for idx in range(count)
        ])
        np.savetxt(osp.join(seq_dir, 'resized_images', 'count.txt'), [count],
                   fmt='%d')
    assert len(dset) == len(im_fns)

    for batch_idx, (imgs, _) in enumerate(loader):
        for idx, im in enumerate(imgs):
            im_fn = im_fns[batch_idx * batch_size + idx]
            im = Image.fromarray(im.numpy())
            try:
                im.save(im_fn)
            except IOError:
                print('IOError while saving {:s}'.format(im_fn))
        if batch_idx % 50 == 0:
            print('Processed {:d}/{:d}'.format(batch_idx * batch_size,
                                               len(dset)))
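ResizeImages depends on module-level globals (args, image_size, batch_size, num_workers) that are assumed to be set up by the surrounding script; a minimal driver under that assumption would be:

if __name__ == '__main__':
    # args, image_size, batch_size and num_workers must already be defined
    # at module level, as the function body above expects.
    ResizeImages(train=True)
    ResizeImages(train=False)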
Code example #4
File: train.py  Project: zju3dv/RVL-Dynamic
    data_transform, target_transform = get_transforms(configuration)

    kwargs = dict(scene=configuration.scene,
                  data_path=configuration.dataset_path,
                  transform=data_transform,
                  target_transform=target_transform,
                  seed=configuration.seed,
                  data_dir=configuration.preprocessed_data_path)
    if configuration.dataset == '7Scenes':
        train_data = SevenScenes(train=True, **kwargs)
        valid_data = SevenScenes(train=False, **kwargs)
    elif configuration.dataset == 'RobotCar':
        train_data = RobotCar(train=True, **kwargs)
        valid_data = RobotCar(train=False, **kwargs)
    elif configuration.dataset == 'SenseTime':
        train_data = SenseTime(train=True, **kwargs)
        valid_data = SenseTime(train=False, **kwargs)
    else:
        raise NotImplementedError

    # Trainer
    print("Setup trainer...")
    pose_stats_file = osp.join(configuration.preprocessed_data_path,
                               configuration.scene, 'pose_stats.txt')

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        train_criterion=train_criterion,
        val_criterion=val_criterion,
        result_criterion=AbsoluteCriterion(),
Code example #5
# In[6]:


# get statistical information from training set only

data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor()
])
kwargs = dict(
    scene=args.scene,
    data_path=args.data_dir,
    train=True,
    transform=data_transform
)
dset = SenseTime(**kwargs)

loader = DataLoader(
    dset,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=safe_collate
)
acc = np.zeros((3, image_size[0], image_size[1]))
sq_acc = np.zeros((3, image_size[0], image_size[1]))
for batch_idx, (imgs, _) in enumerate(loader):
    imgs = imgs.numpy()
    acc += np.sum(imgs, axis=0)
    sq_acc += np.sum(imgs**2, axis=0)

    if batch_idx % 50 == 0: