コード例 #1
0
def load_dataset(dataset, data_root_x, max_step, image_width, data_type):
    """Build the (train, test) dataset pair for ``dataset``.

    Parameters
    ----------
    dataset : str
        One of ``'moving_mnist'``, ``'suncg'`` or ``'kth'``.
    data_root_x
        Forwarded to the dataset constructor as ``data_root``.
    max_step
        Forwarded as ``seq_len``; both splits use the same value.
    image_width
        Forwarded as ``image_size``.
    data_type
        Forwarded to ``KTH`` only; ignored by the other datasets.

    Returns
    -------
    tuple
        ``(train_data, test_data)``.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the supported names. Previously an
        unknown name fell through every branch and the final ``return``
        failed with an ``UnboundLocalError``.
    """
    # Keyword arguments shared by every dataset constructor below.
    common = dict(data_root=data_root_x, seq_len=max_step, image_size=image_width)
    if dataset == 'moving_mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    elif dataset == 'kth':
        train_data = KTH(train=True, data_type=data_type, **common)
        test_data = KTH(train=False, data_type=data_type, **common)
    else:
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    return train_data, test_data
コード例 #2
0
ファイル: utils.py プロジェクト: paidaxing13/DRNET-1
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Parameters
    ----------
    opt : object
        Options namespace. ``opt.dataset`` must be ``'mnist'``, ``'suncg'``
        or ``'kth'``. ``opt.data_root``, ``opt.max_step`` and
        ``opt.image_width`` are forwarded to the dataset constructor as
        ``data_root``, ``seq_len`` and ``image_size``. ``opt.data_type`` is
        read only for ``'kth'``.

    Returns
    -------
    tuple
        ``(train_data, test_data)``; both splits use the same ``seq_len``.

    Raises
    ------
    ValueError
        If ``opt.dataset`` is not a supported name. Previously this case
        failed with an ``UnboundLocalError`` at the final ``return``.
    """
    # Keyword arguments shared by every dataset constructor below.
    common = dict(data_root=opt.data_root,
                  seq_len=opt.max_step,
                  image_size=opt.image_width)
    if opt.dataset == 'mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif opt.dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    elif opt.dataset == 'kth':
        train_data = KTH(train=True, data_type=opt.data_type, **common)
        test_data = KTH(train=False, data_type=opt.data_type, **common)
    else:
        raise ValueError('Unknown dataset: {!r}'.format(opt.dataset))
    return train_data, test_data
コード例 #3
0
def load_dataset(config, train):
    """
    Instantiate one split of the dataset named by ``config.dataset``.

    Parameters
    ----------
    config : DotDict
        Configuration; ``config.dataset`` selects the dataset, and the other
        fields read depend on that choice (``data_dir``, ``nx``, ``seq_len``,
        ``max_speed``, ``deterministic``, ``ndigits``, ``subsampling``).
    train : bool
        ``True`` for the training split, ``False`` for the testing split.

    Returns
    -------
    object
        Whatever the selected class's ``make_dataset`` returns.

    Raises
    ------
    ValueError
        If ``config.dataset`` is not ``'smmnist'``, ``'kth'``, ``'human'``
        or ``'bair'``.
    """
    name = config.dataset
    # Each dataset module is imported only inside its own branch, so only the
    # selected module is ever imported.
    if name == 'smmnist':
        from data.mmnist import MovingMNIST
        dataset = MovingMNIST.make_dataset(
            config.data_dir, config.nx, config.seq_len, config.max_speed,
            config.deterministic, config.ndigits, train,
        )
    elif name == 'kth':
        from data.kth import KTH
        dataset = KTH.make_dataset(config.data_dir, config.nx, config.seq_len, train)
    elif name == 'human':
        from data.human import Human
        dataset = Human.make_dataset(
            config.data_dir, config.nx, config.seq_len, config.subsampling, train,
        )
    elif name == 'bair':
        from data.bair import Bair
        dataset = Bair.make_dataset(config.data_dir, config.seq_len, train)
    else:
        raise ValueError(f'No dataset named `{name}`')
    return dataset
コード例 #4
0
ファイル: base.py プロジェクト: ry85/srvp
def load_dataset(config, train):
    """
    Create the requested split of the video dataset described by ``config``.

    Parameters
    ----------
    config : helper.DotDict
        Configuration to use. ``config.dataset`` picks the dataset; the
        remaining attributes read depend on which dataset was picked.
    train : bool
        Whether to load the training (``True``) or testing (``False``) split.

    Returns
    -------
    data.base.VideoDataset
        Dataset corresponding to the input configuration.

    Raises
    ------
    ValueError
        When ``config.dataset`` names none of ``'smmnist'``, ``'kth'``,
        ``'human'`` or ``'bair'``.
    """
    name = config.dataset
    # Dataset modules are imported on demand inside the matching branch.
    if name == 'smmnist':
        from data.mmnist import MovingMNIST
        result = MovingMNIST.make_dataset(
            config.data_dir,
            config.nx,
            config.seq_len,
            config.max_speed,
            config.deterministic,
            config.ndigits,
            train,
        )
    elif name == 'kth':
        from data.kth import KTH
        result = KTH.make_dataset(
            config.data_dir, config.nx, config.seq_len, train,
        )
    elif name == 'human':
        from data.human import Human
        result = Human.make_dataset(
            config.data_dir, config.nx, config.seq_len, config.subsampling,
            train,
        )
    elif name == 'bair':
        from data.bair import BAIR
        result = BAIR.make_dataset(config.data_dir, config.seq_len, train)
    else:
        raise ValueError(f'No dataset named \'{name}\'')
    return result
コード例 #5
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Supported values of ``opt.dataset`` are ``'smmnist'``, ``'bair'`` and
    ``'KTH'`` (upper case). For ``'smmnist'`` and ``'bair'`` the training
    split uses ``opt.max_step`` as ``seq_len`` and the test split uses
    ``opt.n_eval``.

    Returns
    -------
    tuple
        ``(train_data, test_data)``.

    Raises
    ------
    ValueError
        If ``opt.dataset`` is not a supported name. Previously this case
        failed with an ``UnboundLocalError`` at the final ``return``.
    """
    if opt.dataset == 'smmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
                train=True,
                data_root=opt.data_root,
                seq_len=opt.max_step,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
        test_data = MovingMNIST(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
    elif opt.dataset == 'bair':
        from data.bair import RobotPush
        train_data = RobotPush(
                data_root=opt.data_root,
                train=True,
                seq_len=opt.max_step,
                image_size=opt.image_width)
        test_data = RobotPush(
                data_root=opt.data_root,
                train=False,
                seq_len=opt.n_eval,
                image_size=opt.image_width)
    # NOTE(review): this key is upper-case 'KTH', unlike the lower-case keys
    # above; confirm it matches what the option parser produces.
    elif opt.dataset == 'KTH':
        from data.kth import KTH
        # NOTE(review): unlike the other branches, the test split here uses
        # opt.max_step rather than opt.n_eval; confirm this is intended.
        # The label paths are relative, so they resolve against the current
        # working directory.
        train_data = KTH(
                root=opt.data_root,
                train=True,
                seq_len=opt.max_step,
                label="./label/train.txt")
        test_data = KTH(
                root=opt.data_root,
                train=False,
                seq_len=opt.max_step,
                label="./label/test.txt")
    else:
        raise ValueError('Unknown dataset: {!r}'.format(opt.dataset))

    return train_data, test_data
コード例 #6
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Supported values of ``opt.dataset`` are ``'smmnist'``, ``'bair'`` and
    ``'kth'``. The training split is built with
    ``seq_len = opt.n_past + opt.n_future`` and the test split with
    ``seq_len = opt.n_eval``.

    Returns
    -------
    tuple
        ``(train_data, test_data)``.

    Raises
    ------
    ValueError
        If ``opt.dataset`` is not a supported name. Previously this case
        failed with an ``UnboundLocalError`` at the final ``return``.
    """
    if opt.dataset == 'smmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
                train=True,
                data_root=opt.data_root,
                seq_len=opt.n_past+opt.n_future,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
        test_data = MovingMNIST(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
    elif opt.dataset == 'bair':
        from data.bair import RobotPush
        train_data = RobotPush(
                data_root=opt.data_root,
                train=True,
                seq_len=opt.n_past+opt.n_future,
                image_size=opt.image_width)
        test_data = RobotPush(
                data_root=opt.data_root,
                train=False,
                seq_len=opt.n_eval,
                image_size=opt.image_width)
    elif opt.dataset == 'kth':
        from data.kth import KTH
        train_data = KTH(
                train=True,
                data_root=opt.data_root,
                seq_len=opt.n_past+opt.n_future,
                image_size=opt.image_width)
        test_data = KTH(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width)
    else:
        raise ValueError('Unknown dataset: {!r}'.format(opt.dataset))

    return train_data, test_data
コード例 #7
0
def load_dataset(dataset):
    """Return the training split for ``dataset``.

    ``dataset`` must be ``'mmnist'``, ``'kth'`` or ``'mazes'``. Each
    constructor is called with its own defaults.

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported name. Previously this case failed
        with an ``UnboundLocalError`` at the final ``return``.
    """
    if dataset == 'mmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(train=True)
    elif dataset == 'kth':
        from data.kth import KTH
        train_data = KTH(train=True)
    elif dataset == 'mazes':
        from data.mazes import Mazes
        train_data = Mazes()
    else:
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    return train_data
コード例 #8
0
def load_data(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    ``opt.dataset`` must be ``'moving_mnist'``, ``'suncg'`` or ``'kth'``.
    ``opt.data_root``, ``opt.max_step`` and ``opt.image_width`` are forwarded
    to every constructor as ``data_root``, ``seq_len`` and ``image_size``.
    For ``'kth'``, ``opt.epoch_size`` (as ``epoch_samples``), ``opt.pose``
    and ``opt.data_type`` are also forwarded.

    :return: ``(train_data, test_data)``
    :raises ValueError: if ``opt.dataset`` is not a supported name.
        Previously this case failed with an ``UnboundLocalError`` at the
        final ``return``.
    """
    # Keyword arguments shared by every dataset constructor below.
    common = dict(data_root=opt.data_root,
                  seq_len=opt.max_step,
                  image_size=opt.image_width)
    if opt.dataset == 'moving_mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif opt.dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    elif opt.dataset == 'kth':
        kth_extra = dict(epoch_samples=opt.epoch_size,
                         pose=opt.pose,
                         data_type=opt.data_type)
        train_data = KTH(train=True, **dict(common, **kth_extra))
        test_data = KTH(train=False, **dict(common, **kth_extra))
    else:
        raise ValueError('Unknown dataset: {!r}'.format(opt.dataset))
    return train_data, test_data
コード例 #9
0
def load_dataset(dataset):
    """Return the training split for ``dataset``.

    ``dataset`` must be ``'mmnist'``, ``'kth'`` or ``'mazes'``. The
    ``.npy`` paths below are relative, so they resolve against the current
    working directory.

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported name. Previously this case failed
        with an ``UnboundLocalError`` at the final ``return``.
    """
    if dataset == 'mmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
            train=True, data_root='../data/mmnist/mnist_training_set.npy')
    elif dataset == 'kth':
        from data.kth import KTH
        train_data = KTH(train=True)
    elif dataset == 'mazes':
        from data.mazes import Mazes
        train_data = Mazes(data_root='../data/mazes/np_mazes_train.npy')
    else:
        raise ValueError('Unknown dataset: {!r}'.format(dataset))
    return train_data
コード例 #10
0
def load_dataset(opt):
    """Build the datasets selected by ``opt.data`` plus a worker count.

    ``opt.data`` must be ``'moving_mnist'``, ``'suncg'``, ``'suncg_dual'``
    or ``'kth'``. All constructors receive ``opt.max_step`` and
    ``opt.image_width``. ``MovingMNIST`` receives them by keyword; the
    others receive them positionally.

    Returns
    -------
    tuple
        ``(train_data, test_data, load_workers)``. ``load_workers`` is 5 for
        every dataset except ``'kth'``, where it is 0. It is presumably the
        ``num_workers`` for a data loader; confirm at the call site.

    Raises
    ------
    ValueError
        If ``opt.data`` is not a supported name. Previously this case failed
        with an ``UnboundLocalError`` at the final ``return``.
    """
    if opt.data == 'moving_mnist':
        train_data = MovingMNIST(train=True,
                                 seq_len=opt.max_step,
                                 image_size=opt.image_width,
                                 num_digits=2)
        test_data = MovingMNIST(train=False,
                                seq_len=opt.max_step,
                                image_size=opt.image_width,
                                num_digits=2)
        load_workers = 5
    elif opt.data == 'suncg':
        train_data = suncg.SUNCG(True, opt.max_step, opt.image_width)
        test_data = suncg.SUNCG(False, opt.max_step, opt.image_width)
        load_workers = 5
    elif opt.data == 'suncg_dual':
        # NOTE(review): train and test are built with identical arguments
        # (DualSUNCG takes no train flag here); confirm this is intended.
        train_data = suncg.DualSUNCG(opt.max_step, opt.image_width)
        test_data = suncg.DualSUNCG(opt.max_step, opt.image_width)
        load_workers = 5
    elif opt.data == 'kth':
        train_data = KTH(True, opt.max_step, opt.image_width)
        test_data = KTH(False, opt.max_step, opt.image_width)
        load_workers = 0
    else:
        raise ValueError('Unknown dataset: {!r}'.format(opt.data))
    return train_data, test_data, load_workers