예제 #1
0
def load_dataset(dataset, data_root_x, max_step, image_width, data_type):
    """Build the (train, test) dataset pair for ``dataset``.

    Args:
        dataset: Dataset name; one of ``'moving_mnist'``, ``'suncg'`` or
            ``'kth'``.
        data_root_x: Passed to the dataset class as ``data_root``.
        max_step: Passed as ``seq_len`` to both the train and the test split.
        image_width: Passed as ``image_size``.
        data_type: Passed as ``data_type``; only the ``'kth'`` dataset uses it.

    Returns:
        A ``(train_data, test_data)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    supported = ('moving_mnist', 'suncg', 'kth')
    if dataset not in supported:
        # Without this check an unknown name falls through every branch and
        # surfaces as a confusing UnboundLocalError at the return statement.
        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
            dataset, ', '.join(supported)))

    # Keyword arguments shared by every dataset class and both splits.
    common = dict(data_root=data_root_x,
                  seq_len=max_step,
                  image_size=image_width)
    if dataset == 'moving_mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    else:  # 'kth'
        train_data = KTH(train=True, data_type=data_type, **common)
        test_data = KTH(train=False, data_type=data_type, **common)
    return train_data, test_data
예제 #2
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Args:
        opt: Options object. Reads ``dataset`` (one of ``'mnist'``,
            ``'suncg'``, ``'kth'``), ``data_root``, ``max_step`` (passed as
            ``seq_len``) and ``image_width`` (passed as ``image_size``).
            ``data_type`` is read only for ``'kth'``.

    Returns:
        A ``(train_data, test_data)`` tuple.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported name.
    """
    supported = ('mnist', 'suncg', 'kth')
    # Validate before touching other options so an unknown name yields a
    # clear error instead of an UnboundLocalError at the return statement.
    if opt.dataset not in supported:
        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
            opt.dataset, ', '.join(supported)))

    # Keyword arguments shared by every dataset class and both splits.
    common = dict(data_root=opt.data_root,
                  seq_len=opt.max_step,
                  image_size=opt.image_width)
    if opt.dataset == 'mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif opt.dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    else:  # 'kth'
        train_data = KTH(train=True, data_type=opt.data_type, **common)
        test_data = KTH(train=False, data_type=opt.data_type, **common)
    return train_data, test_data
예제 #3
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Each dataset class is imported lazily inside its branch, so only the
    selected dataset's module needs to be importable.

    Args:
        opt: Options object. ``dataset`` must be ``'smmnist'``, ``'bair'`` or
            ``'KTH'`` (note the uppercase key for KTH). Also reads
            ``data_root``, ``max_step``, and for smmnist/bair ``n_eval`` and
            ``image_width``; smmnist additionally reads ``num_digits``.

    Returns:
        A ``(train_data, test_data)`` tuple.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported name.
    """
    if opt.dataset == 'smmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
                train=True,
                data_root=opt.data_root,
                seq_len=opt.max_step,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
        # The test split uses n_eval frames rather than max_step.
        test_data = MovingMNIST(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
    elif opt.dataset == 'bair':
        from data.bair import RobotPush
        train_data = RobotPush(
                data_root=opt.data_root,
                train=True,
                seq_len=opt.max_step,
                image_size=opt.image_width)
        test_data = RobotPush(
                data_root=opt.data_root,
                train=False,
                seq_len=opt.n_eval,
                image_size=opt.image_width)
    elif opt.dataset == 'KTH':
        from data.kth import KTH
        # Label file paths are relative — presumably resolved against the
        # current working directory; confirm against the KTH loader.
        train_data = KTH(
                root=opt.data_root,
                train=True,
                seq_len=opt.max_step,
                label="./label/train.txt")
        # NOTE(review): unlike smmnist/bair, the KTH test split uses
        # opt.max_step rather than opt.n_eval — confirm this is intended.
        test_data = KTH(
                root=opt.data_root,
                train=False,
                seq_len=opt.max_step,
                label="./label/test.txt")
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(
            'Unknown dataset {!r}; expected one of: smmnist, bair, KTH'
            .format(opt.dataset))

    return train_data, test_data
예제 #4
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Training sequences are ``opt.n_past + opt.n_future`` frames long; test
    sequences are ``opt.n_eval`` frames long. Each dataset class is imported
    lazily inside its branch.

    Args:
        opt: Options object. ``dataset`` must be ``'smmnist'``, ``'bair'`` or
            ``'kth'``. Also reads ``data_root``, ``n_past``, ``n_future``,
            ``n_eval`` and ``image_width``; smmnist additionally reads
            ``num_digits``.

    Returns:
        A ``(train_data, test_data)`` tuple.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported name.
    """
    supported = ('smmnist', 'bair', 'kth')
    # Validate first so an unknown name yields a clear error instead of an
    # UnboundLocalError at the return statement.
    if opt.dataset not in supported:
        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
            opt.dataset, ', '.join(supported)))

    train_len = opt.n_past + opt.n_future
    if opt.dataset == 'smmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
                train=True,
                data_root=opt.data_root,
                seq_len=train_len,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
        test_data = MovingMNIST(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width,
                deterministic=False,
                num_digits=opt.num_digits)
    elif opt.dataset == 'bair':
        from data.bair import RobotPush
        train_data = RobotPush(
                data_root=opt.data_root,
                train=True,
                seq_len=train_len,
                image_size=opt.image_width)
        test_data = RobotPush(
                data_root=opt.data_root,
                train=False,
                seq_len=opt.n_eval,
                image_size=opt.image_width)
    else:  # 'kth'
        from data.kth import KTH
        train_data = KTH(
                train=True,
                data_root=opt.data_root,
                seq_len=train_len,
                image_size=opt.image_width)
        test_data = KTH(
                train=False,
                data_root=opt.data_root,
                seq_len=opt.n_eval,
                image_size=opt.image_width)

    return train_data, test_data
예제 #5
0
def load_dataset(dataset):
    """Build the training dataset for ``dataset``.

    Each dataset class is imported lazily and constructed with its own
    defaults apart from ``train=True`` where applicable.

    Args:
        dataset: Dataset name; one of ``'mmnist'``, ``'kth'`` or ``'mazes'``.

    Returns:
        The training dataset object.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset == 'mmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(train=True)
    elif dataset == 'kth':
        from data.kth import KTH
        train_data = KTH(train=True)
    elif dataset == 'mazes':
        from data.mazes import Mazes
        train_data = Mazes()
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(
            'Unknown dataset {!r}; expected one of: mmnist, kth, mazes'
            .format(dataset))
    return train_data
예제 #6
0
def load_data(opt):
    """Build the (train, test) dataset pair selected by ``opt.dataset``.

    Args:
        opt: Options object. ``dataset`` must be ``'moving_mnist'``,
            ``'suncg'`` or ``'kth'``. Also reads ``data_root``, ``max_step``
            (passed as ``seq_len``) and ``image_width`` (passed as
            ``image_size``). For ``'kth'`` it additionally reads
            ``epoch_size`` (passed as ``epoch_samples``), ``pose`` and
            ``data_type``.

    Returns:
        A ``(train_data, test_data)`` tuple of raw dataset objects.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported name.
    """
    supported = ('moving_mnist', 'suncg', 'kth')
    # Validate before touching other options so an unknown name yields a
    # clear error instead of an UnboundLocalError at the return statement.
    if opt.dataset not in supported:
        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
            opt.dataset, ', '.join(supported)))

    # Keyword arguments shared by every dataset class and both splits.
    common = dict(data_root=opt.data_root,
                  seq_len=opt.max_step,
                  image_size=opt.image_width)
    if opt.dataset == 'moving_mnist':
        train_data = MovingMNIST(train=True, num_digits=2, **common)
        test_data = MovingMNIST(train=False, num_digits=2, **common)
    elif opt.dataset == 'suncg':
        train_data = suncg.SUNCG(train=True, **common)
        test_data = suncg.SUNCG(train=False, **common)
    else:  # 'kth'
        kth_extra = dict(epoch_samples=opt.epoch_size,
                         pose=opt.pose,
                         data_type=opt.data_type)
        train_data = KTH(train=True, **kth_extra, **common)
        test_data = KTH(train=False, **kth_extra, **common)
    return train_data, test_data
예제 #7
0
def load_dataset(dataset):
    """Build the training dataset for ``dataset``.

    Each dataset class is imported lazily inside its branch. The data paths
    are hard-coded relative paths — presumably resolved against the current
    working directory; confirm against the dataset loaders.

    Args:
        dataset: Dataset name; one of ``'mmnist'``, ``'kth'`` or ``'mazes'``.

    Returns:
        The training dataset object.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset == 'mmnist':
        from data.moving_mnist import MovingMNIST
        train_data = MovingMNIST(
            train=True, data_root='../data/mmnist/mnist_training_set.npy')
    elif dataset == 'kth':
        from data.kth import KTH
        train_data = KTH(train=True)
    elif dataset == 'mazes':
        from data.mazes import Mazes
        train_data = Mazes(data_root='../data/mazes/np_mazes_train.npy')
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(
            'Unknown dataset {!r}; expected one of: mmnist, kth, mazes'
            .format(dataset))
    return train_data
예제 #8
0
def load_dataset(opt):
    """Build the (train, test) dataset pair selected by ``opt.data``.

    Args:
        opt: Options object. ``data`` must be ``'moving_mnist'``,
            ``'suncg'``, ``'suncg_dual'`` or ``'kth'``. Also reads
            ``max_step`` and ``image_width``.

    Returns:
        A ``(train_data, test_data, load_workers)`` tuple. ``load_workers`` is
        5 for every dataset except ``'kth'``, where it is 0 — presumably a
        worker count for a data loader; confirm with the caller.

    Raises:
        ValueError: If ``opt.data`` is not a supported name.
    """
    if opt.data == 'moving_mnist':
        train_data = MovingMNIST(train=True,
                                 seq_len=opt.max_step,
                                 image_size=opt.image_width,
                                 num_digits=2)
        test_data = MovingMNIST(train=False,
                                seq_len=opt.max_step,
                                image_size=opt.image_width,
                                num_digits=2)
        load_workers = 5
    elif opt.data == 'suncg':
        train_data = suncg.SUNCG(True, opt.max_step, opt.image_width)
        test_data = suncg.SUNCG(False, opt.max_step, opt.image_width)
        load_workers = 5
    elif opt.data == 'suncg_dual':
        # DualSUNCG takes no train/test flag, so both splits are built
        # identically here.
        train_data = suncg.DualSUNCG(opt.max_step, opt.image_width)
        test_data = suncg.DualSUNCG(opt.max_step, opt.image_width)
        load_workers = 5
    elif opt.data == 'kth':
        train_data = KTH(True, opt.max_step, opt.image_width)
        test_data = KTH(False, opt.max_step, opt.image_width)
        load_workers = 0
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(
            'Unknown dataset {!r}; expected one of: moving_mnist, suncg, '
            'suncg_dual, kth'.format(opt.data))
    return train_data, test_data, load_workers