def load_data(dataset_name, data_path, batch_size):
    """Load a dataset and wrap its splits in batch generators.

    Args:
        dataset_name: one of 'mnist' or 'cifar10'.
        data_path: directory handed to the dataset loader.
        batch_size: number of samples per generated batch.

    Returns:
        Tuple `(train_gen, test_gen, batches_per_epoch, data_shape)` where
        `data_shape` is `(batch_size,) + per_sample_shape`.

    Raises:
        ValueError: if `dataset_name` is not a recognized dataset.
    """
    if dataset_name == 'mnist':
        train_data, _, test_data, _ = data.load_mnist(data_path)
    elif dataset_name == 'cifar10':
        train_data, _, test_data, _ = data.load_cifar10(data_path)
    else:
        # `assert` is stripped under `python -O`; raise explicitly instead.
        raise ValueError("Must specify a valid dataset_name")
    data_shape = (batch_size, ) + train_data.shape[1:]
    # Integer division: any final partial batch is dropped.
    batches_per_epoch = train_data.shape[0] // batch_size
    train_gen = data.data_generator(train_data, batch_size)
    test_gen = data.data_generator(test_data, batch_size)
    return train_gen, test_gen, batches_per_epoch, data_shape
Exemplo n.º 2
0
from random import shuffle

import sounds_deep.contrib.data.data as data
import sounds_deep.contrib.util.util as util

# Short aliases for the TF 1.x contrib distributions/bijectors namespaces.
tfd = tf.contrib.distributions
tfb = tfd.bijectors

# Training hyper-parameters are supplied on the command line.
parser = argparse.ArgumentParser(description='Train a VAE model.')
for _flag, _type, _default in (('--batch_size', int, 64),
                               ('--epochs', int, 100),
                               ('--learning_rate', float, 5e-6)):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# Load the training split (alternate loaders kept for easy swapping):
# train_data, _, _, _ = data.load_cifar10('./data/')
# train_data, _, _, _ = data.load_sudoku('./data')
train_data, _, _, _ = data.load_mnist('./data/')
# Restore image layout (NHWC) and cast once to float32.
train_data = np.reshape(train_data, [-1, 28, 28, 1]).astype(np.float32)
data_shape = (args.batch_size,) + train_data.shape[1:]
# Integer division: the final partial batch is dropped.
batches_per_epoch = train_data.shape[0] // args.batch_size
train_gen = data.data_generator(train_data, args.batch_size)


def fnfn(i):
    def _fn(x, output_units):
        first = snt.Linear(512)
        net = snt.Sequential([
            first, tf.nn.relu,
            snt.Linear(512), tf.nn.relu,
            snt.Linear(output_units * 2,
Exemplo n.º 3
0
        output_directory = args.output_dir
        os.mkdir(output_directory)


def unison_shuffled_copies(arrays):
    """Shuffle several equal-length arrays with one shared permutation.

    Corresponding rows stay aligned across the returned arrays (e.g. data
    and labels remain paired after the shuffle).

    Args:
        arrays: non-empty sequence of indexable arrays (e.g. np.ndarray),
            all with the same length along axis 0.

    Returns:
        List of the arrays, each reindexed by the same random permutation.

    Raises:
        ValueError: if `arrays` is empty or the lengths differ.
    """
    # `assert` is stripped under `python -O`; validate explicitly.
    if len(arrays) == 0:
        raise ValueError("arrays must be a non-empty sequence")
    length = len(arrays[0])
    if any(len(a) != length for a in arrays):
        raise ValueError("all arrays must have the same length")
    p = np.random.permutation(length)
    return [a[p] for a in arrays]


# load the data (dataset chosen on the command line)
if args.dataset == 'cifar10':
    train_data, train_labels, test_data, test_labels = data.load_cifar10(
        './data/')
elif args.dataset == 'mnist':
    train_data, train_labels, test_data, test_labels = data.load_mnist(
        './data/')
    # The loader returns flat vectors; restore NHWC image layout.
    train_data = np.reshape(train_data, [-1, 28, 28, 1])
    test_data = np.reshape(test_data, [-1, 28, 28, 1])
else:
    # Previously fell through silently, leaving train_data undefined and
    # producing a confusing NameError below; fail fast instead.
    raise ValueError("Unrecognized dataset: {}".format(args.dataset))
# Per-sample shapes (no batch dimension here).
data_shape = train_data.shape[1:]
label_shape = train_labels.shape[1:]
# Integer division: final partial batches are dropped.
train_batches_per_epoch = train_data.shape[0] // args.unlabeled_batch_size
test_batches_per_epoch = test_data.shape[0] // args.labeled_batch_size

# Choose the labeled training subset: shuffle first (keeping data/label
# rows aligned) so the labeled examples are a random sample, then take
# the first num_labeled_data rows.
train_data, train_labels = unison_shuffled_copies([train_data, train_labels])
n_labeled = args.num_labeled_data
labeled_train_data, labeled_train_labels = (
    train_data[:n_labeled], train_labels[:n_labeled])

# Generator yielding aligned (data, label) batches.
labeled_train_gen = data.parallel_data_generator(
    [labeled_train_data, labeled_train_labels], args.labeled_batch_size)
Exemplo n.º 4
0

def apply_temp(a, temperature=1.0):
    """Rescale rows of a probability array by a softmax temperature.

    Divides the log-probabilities by `temperature` and renormalizes each
    row (axis 1) to sum to one: temperature < 1 sharpens the distribution,
    temperature > 1 flattens it. Note this only reweights probabilities;
    it does not itself sample an index.
    """
    scaled_logits = tf.log(a) / temperature
    exp_logits = tf.exp(scaled_logits)
    return exp_logits / tf.reduce_sum(exp_logits, axis=1, keepdims=True)


# celebA data
# idxable, train_idxs, test_idxs, attributes = data.load_celeba('./sounds_deep/contrib/data/')
# batches_per_epoch = train_idxs.shape[0] // args.batch_size
# train_gen = data.idxable_data_generator(idxable, train_idxs, args.batch_size)
# data_shape = next(train_gen).shape

# MNIST images and labels for training; the test split is unused here.
train_data, train_labels, _, _ = data.load_mnist('./data/')
# Integer division: the final partial batch is dropped.
batches_per_epoch = train_data.shape[0] // args.batch_size
# Full batch shapes, batch dimension included.
data_shape = (args.batch_size,) + train_data.shape[1:]
label_shape = (args.batch_size,) + train_labels.shape[1:]
# Generator yielding aligned (image, label) batches.
train_gen = data.parallel_data_generator(
    [train_data, train_labels], args.batch_size)

# build the model
# Temperature is fed at run time via this placeholder (likely annealed
# over training — confirm against the training loop).
temp_ph = tf.placeholder(tf.float32)
encoder_module = snt.nets.MLP([200, 200])
# Output width 784 presumably matches flattened 28x28 MNIST images.
decoder_module = snt.nets.MLP([200, 200, 784])
model = hvae.HVAE(
    args.latent_dimension, encoder_module, decoder_module,
    hvar_shape=10, temperature=temp_ph)