Example #1
import os
import sys
import warnings


def _get_maf_original(data_name):
    warnings.warn(
        "This function should generally not be called because it "
        "requires special setup but is kept here in order to reproduce functions if "
        "needed.")
    if sys.version_info < (3, ):
        # Load datasets from the MAF code
        maf_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                "..", "..", "maf")
        sys.path.append(maf_path)
        # noinspection PyPackageRequirements
        import datasets  # maf/datasets/*

        # Reset datasets root directory relative to this file
        datasets.root = os.path.join(maf_path, "data") + "/"

        # Copied from maf/experiments.py
        if data_name == "mnist":
            data = datasets.MNIST(logit=True, dequantize=True)
        elif data_name == "bsds300":
            data = datasets.BSDS300()
        elif data_name == "cifar10":
            data = datasets.CIFAR10(logit=True, flip=True, dequantize=True)
        elif data_name == "power":
            data = datasets.POWER()
        elif data_name == "gas":
            data = datasets.GAS()
        elif data_name == "hepmass":
            data = datasets.HEPMASS()
        elif data_name == "miniboone":
            data = datasets.MINIBOONE()
        else:
            raise ValueError("Unknown dataset")

        # Make a dictionary instead of pickled object for better compatibility
        if hasattr(data.trn, "labels"):
            data_dict = dict(
                X_train=data.trn.x,
                y_train=data.trn.labels,
                X_validation=data.val.x,
                y_validation=data.val.labels,
                X_test=data.tst.x,
                y_test=data.tst.labels,
                data_name=data_name,
            )
        else:
            data_dict = dict(
                X_train=data.trn.x,
                X_validation=data.val.x,
                X_test=data.tst.x,
                data_name=data_name,
            )
    else:
        raise RuntimeError(
            "Must create data using Python 2 to load data since MAF is written for "
            "Python 2")
    return data_dict
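
Since the point of returning a plain dictionary is cross-version compatibility, here is a minimal sketch of how the result might be persisted from Python 2 and reloaded later from Python 3. The file name is illustrative, and NumPy is assumed to be available; npz archives, unlike pickles, load fine across Python versions:

import numpy as np

# Hypothetical usage: run under Python 2 with the MAF repo checked out,
# then persist the plain dict so Python 3 code can load it later.
data_dict = _get_maf_original("miniboone")
np.savez("miniboone_maf.npz", **data_dict)

# Later, from Python 3:
loaded = dict(np.load("miniboone_maf.npz"))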
Example #2
def load_data(name):
    if name == 'bsds300':
        return datasets.BSDS300()

    elif name == 'power':
        return datasets.POWER()

    elif name == 'gas':
        return datasets.GAS()

    elif name == 'hepmass':
        return datasets.HEPMASS()

    elif name == 'miniboone':
        return datasets.MINIBOONE()

    else:
        raise ValueError('Unknown dataset')
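
The returned object exposes trn, val, and tst splits, as in the MAF loader above. A minimal usage sketch, assuming the MAF datasets package is importable (the dataset name is illustrative):

# Hypothetical usage: load a dataset and inspect its three splits.
data = load_data('power')
print(data.trn.x.shape, data.val.x.shape, data.tst.x.shape)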
Example #3
def load_data(name):
    """
    Loads the dataset. Has to be called before anything else.
    :param name: string, the dataset's name
    """

    assert isinstance(name, str), 'Name must be a string'
    datasets.root = root_data
    global data, data_name

    if data_name == name:
        return

    if name == 'mnist':
        data = datasets.MNIST(logit=True, dequantize=True)
        data_name = name

    elif name == 'bsds300':
        data = datasets.BSDS300()
        data_name = name

    elif name == 'cifar10':
        data = datasets.CIFAR10(logit=True, flip=True, dequantize=True)
        data_name = name

    elif name == 'power':
        data = datasets.POWER()
        data_name = name

    elif name == 'gas':
        data = datasets.GAS()
        data_name = name

    elif name == 'hepmass':
        data = datasets.HEPMASS()
        data_name = name

    elif name == 'miniboone':
        data = datasets.MINIBOONE()
        data_name = name

    else:
        raise ValueError('Unknown dataset')
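
Because this loader caches through the module-level data and data_name globals, repeated calls with the same name return immediately. A minimal sketch of that behavior, assuming datasets and root_data are defined at module scope:

# Hypothetical usage: the second call is a no-op because data_name
# already matches, so the dataset is not reloaded.
load_data('gas')
load_data('gas')
print(data.trn.x.shape)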
Example #4
import numpy as np


def load_data(name, logit=False, dequantize=False, flip=False):
    """
    Loads the dataset. Has to be called before anything else.
    :param name: string, the dataset's name
    :param logit: bool, apply a logit transform (image datasets)
    :param dequantize: bool, dequantize pixel values (image datasets)
    :param flip: bool, augment with horizontally flipped images (CIFAR-10)
    """
    
    assert isinstance(name, str), 'Name must be a string'

    if name == 'mnist':
        data = datasets.MNIST(logit=logit, dequantize=dequantize)
    elif name == 'bsds300':
        data = datasets.BSDS300()
    elif name == 'cifar10':
        data = datasets.CIFAR10(logit=logit, flip=flip, dequantize=dequantize)
    elif name == 'power':
        data = datasets.POWER()
    elif name == 'gas':
        data = datasets.GAS()
    elif name == 'hepmass':
        data = datasets.HEPMASS()
    elif name == 'miniboone':
        data = datasets.MINIBOONE()
    else:
        raise ValueError('Unknown dataset')

    # get data splits
    X_train = data.trn.x
    X_val = data.val.x
    X_test = data.tst.x
    
    # Convert to float32
    X_train = X_train.astype(np.float32)
    X_val = X_val.astype(np.float32)
    X_test = X_test.astype(np.float32)
    
    return data, X_train, X_val, X_test
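
A minimal usage sketch, again assuming the MAF datasets package is importable (the dataset name is illustrative):

# Hypothetical usage: load MINIBOONE and check the float32 splits.
data, X_train, X_val, X_test = load_data('miniboone')
print(X_train.dtype, X_train.shape)  # float32, (num_train, n_dims)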
Example #5
import jax
import jax.numpy as jnp


def init_data():
    """
    Initialize data.
    """
    data = datasets.MINIBOONE()

    num_train = data.trn.N
    # num_test = data.trn.N
    num_test = data.val.N

    if float_64:
        convert = jnp.float64
    else:
        convert = jnp.float32

    data.trn.x = convert(data.trn.x)
    data.val.x = convert(data.val.x)
    data.tst.x = convert(data.tst.x)

    num_batches = num_train // parse_args.batch_size + 1 * (
        num_train % parse_args.batch_size != 0)
    num_test_batches = num_test // parse_args.test_batch_size + 1 * (
        num_test % parse_args.test_batch_size != 0)

    # make sure we always save the model on the last iteration
    assert num_batches * parse_args.nepochs % parse_args.save_freq == 0

    def gen_train_data():
        """
        Generator for train data.
        """
        key = rng
        inds = jnp.arange(num_train)

        while True:
            key, subkey = jax.random.split(key)
            epoch_inds = jax.random.permutation(subkey, inds)
            for i in range(num_batches):
                batch_inds = epoch_inds[i * parse_args.batch_size:min(
                    (i + 1) * parse_args.batch_size, num_train)]
                yield data.trn.x[batch_inds]

    def gen_val_data():
        """
        Generator for validation data.
        """
        inds = jnp.arange(num_test)
        while True:
            for i in range(num_test_batches):
                batch_inds = inds[i * parse_args.test_batch_size:min(
                    (i + 1) * parse_args.test_batch_size, num_test)]
                yield data.val.x[batch_inds]

    def gen_test_data():
        """
        Generator for test data.
        """
        inds = jnp.arange(num_test)
        while True:
            for i in range(num_test_batches):
                batch_inds = inds[i * parse_args.test_batch_size:min(
                    (i + 1) * parse_args.test_batch_size, num_test)]
                yield data.tst.x[batch_inds]

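    # Note: gen_test_data is defined but not used below; ds_test iterates
    # the validation split via gen_val_data.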
    ds_train = gen_train_data()
    ds_test = gen_val_data()

    meta = {
        "dims": data.n_dims,
        "num_batches": num_batches,
        "num_test_batches": num_test_batches
    }

    return ds_train, ds_test, meta
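
A minimal sketch of driving these generators, assuming the module-level rng, parse_args, and float_64 that init_data relies on are defined:

# Hypothetical usage: pull one batch from each infinite generator.
ds_train, ds_test, meta = init_data()
batch = next(ds_train)     # shape: (batch_size, meta['dims'])
val_batch = next(ds_test)  # drawn from the validation split, as noted above
print(batch.shape, meta['num_batches'], meta['num_test_batches'])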