예제 #1
0
def get_dataset(all_cfg):
    """Build the KITTI_15 train/validation dataset pair described by ``all_cfg``.

    Args:
        all_cfg: configuration object exposing ``data`` (dataset options read
            here: ``type``, ``root``, ``train_file``, ``train_shape``,
            ``train_stereo``, ``flow_data``, ``test_shape``, ``test_stereo``)
            and ``data_aug`` (forwarded to ``get_co_transforms``).

    Returns:
        A ``(train_set, valid_set)`` tuple: a ``KITTIRawFile`` training set and
        a ``KITTIFlow`` validation set.

    Raises:
        NotImplementedError: if ``all_cfg.data.type`` is not ``'KITTI_15'``.
    """
    data_cfg = all_cfg.data

    # Shared base pipeline: array -> tensor, then Normalize with mean 0 and
    # std 255 (presumably rescaling 0-255 pixel values into [0, 1] — confirm).
    base_pipeline = transforms.Compose([
        sep_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
    ])

    # Built before the dataset-type check so that any error from
    # get_co_transforms surfaces first.
    joint_aug = get_co_transforms(aug_args=all_cfg.data_aug)

    if data_cfg.type != 'KITTI_15':
        raise NotImplementedError(data_cfg.type)

    def zoomed_copy(shape):
        # Deep-copy so each split owns its pipeline and the base one is never
        # mutated, then put a Zoom to ``shape`` in front of it.
        pipeline = copy.deepcopy(base_pipeline)
        pipeline.transforms.insert(0, sep_transforms.Zoom(*shape))
        return pipeline

    train_pipeline = zoomed_copy(data_cfg.train_shape)
    train_set = KITTIRawFile(
        data_cfg.root,
        data_cfg.train_file,
        with_stereo=data_cfg.train_stereo,
        transform=train_pipeline,
        co_transform=joint_aug,  # no target here
    )

    valid_pipeline = zoomed_copy(data_cfg.test_shape)
    valid_set = KITTIFlow(
        data_cfg.flow_data,
        with_stereo=data_cfg.test_stereo,
        transform=valid_pipeline,
    )
    return train_set, valid_set
예제 #2
0
 def __init__(self, cfg):
     """Store the config, pick a device, build the model and the test-time input pipeline.

     Args:
         cfg: configuration mapping; wrapped in ``EasyDict`` and kept as
             ``self.cfg``. Its ``test_shape`` entry feeds the Zoom step.
     """
     self.cfg = EasyDict(cfg)
     # Use CUDA whenever torch reports it as available, otherwise the CPU.
     cuda_ok = torch.cuda.is_available()
     self.device = torch.device("cuda" if cuda_ok else "cpu")
     # init_model() is defined elsewhere on the class. It is called after cfg
     # and device are set, presumably because it reads them — confirm.
     self.model = self.init_model()
     test_steps = [
         # presumably resizes the input to the configured test shape
         sep_transforms.Zoom(*self.cfg.test_shape),
         sep_transforms.ArrayToTensor(),
         # mean 0 / std 255: presumably rescales 0-255 pixels into [0, 1]
         transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
     ]
     self.input_transform = transforms.Compose(test_steps)
예제 #3
0
def _build_sintel_train_set(cfg, input_transform, co_transform):
    """Build the Sintel clean+final training set used by the 'Sintel_Flow' config.

    Args:
        cfg: the ``data`` config section; reads ``at_cfg``, ``run_at``,
            ``root_sintel``, ``train_n_frames`` and ``train_subsplit``.
        input_transform: per-image transform pipeline, shared (not copied).
        co_transform: joint augmentation pipeline, shared (not copied).

    Returns:
        ConcatDataset of the 'clean' and 'final' Sintel training-split datasets.
    """
    # Optional extra transform, only built when cfg.run_at is truthy.
    ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None
    return ConcatDataset([
        Sintel(cfg.root_sintel, n_frames=cfg.train_n_frames, type=pass_type,
               split='training', subsplit=cfg.train_subsplit,
               with_flow=False,  # presumably: do not load ground-truth flow
               ap_transform=ap_transform,
               transform=input_transform,
               co_transform=co_transform)
        for pass_type in ('clean', 'final')
    ])


def _build_sintel_valid_set(cfg, input_transform):
    """Build the Sintel clean+final validation set shared by every config type.

    Args:
        cfg: the ``data`` config section; reads ``test_shape``,
            ``root_sintel``, ``val_n_frames`` and ``val_subsplit``.
        input_transform: base transform pipeline. It is deep-copied before a
            Zoom to ``cfg.test_shape`` is prepended, so the caller's pipeline
            is left unchanged.

    Returns:
        ConcatDataset of the 'clean' and 'final' Sintel training-split datasets,
        each with an ArrayToTensor target transform for 'flow'.
    """
    valid_input_transform = copy.deepcopy(input_transform)
    valid_input_transform.transforms.insert(0, sep_transforms.Zoom(*cfg.test_shape))
    # The comprehension builds a fresh target_transform dict per dataset,
    # matching the original one-literal-per-call construction.
    return ConcatDataset([
        Sintel(cfg.root_sintel, n_frames=cfg.val_n_frames, type=pass_type,
               split='training', subsplit=cfg.val_subsplit,
               transform=valid_input_transform,
               target_transform={'flow': sep_transforms.ArrayToTensor()})
        for pass_type in ('clean', 'final')
    ])


def get_dataset(all_cfg):
    """Build the Sintel train/validation dataset pair described by ``all_cfg``.

    Supported ``all_cfg.data.type`` values:
      * ``'Sintel_Flow'``: train on Sintel clean+final (``with_flow=False``).
      * ``'Sintel_Raw'``: train on ``SintelRaw`` from ``root_sintel_raw``.
    Both validate on Sintel clean+final with a Zoom to ``test_shape``.

    Args:
        all_cfg: configuration object exposing ``data`` (dataset options) and
            ``data_aug`` (forwarded to ``get_co_transforms``).

    Returns:
        A ``(train_set, valid_set)`` tuple.

    Raises:
        NotImplementedError: if ``all_cfg.data.type`` is not supported.
    """
    cfg = all_cfg.data

    # mean 0 / std 255: presumably rescales 0-255 pixel values into [0, 1].
    input_transform = transforms.Compose([
        sep_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
    ])

    co_transform = get_co_transforms(aug_args=all_cfg.data_aug)

    if cfg.type == 'Sintel_Flow':
        train_set = _build_sintel_train_set(cfg, input_transform, co_transform)
    elif cfg.type == 'Sintel_Raw':
        train_set = SintelRaw(cfg.root_sintel_raw, n_frames=cfg.train_n_frames,
                              transform=input_transform, co_transform=co_transform)
    else:
        raise NotImplementedError(cfg.type)

    # Identical validation setup for both supported types (previously duplicated).
    valid_set = _build_sintel_valid_set(cfg, input_transform)
    return train_set, valid_set