示例#1
0
def train_tfdata(x_train, learning_rate, epochs=10):
    """Build a five-layer convolutional VAE, train it on ``x_train`` and return it.

    The model input is ``(hop, 3 * shape, 1)``; ``hop``, ``shape`` and
    ``VECTOR_DIM`` are values defined outside this function.

    Args:
        x_train: training data handed unchanged to ``VAE.train``.
        learning_rate: forwarded to ``VAE.compile``.
        epochs: forwarded to ``VAE.train`` as ``num_epochs`` (default 10).

    Returns:
        The compiled and trained ``VAE`` instance.
    """
    layer_count = 5
    model = VAE(
        input_shape=(hop, 3 * shape, 1),
        conv_filters=(512, 256, 128, 64, 32),
        conv_kernels=(3,) * layer_count,
        # Stride 2 on both axes for the first four layers; the last layer
        # uses stride 2 on the first axis and 1 on the second.
        conv_strides=(2,) * (layer_count - 1) + ((2, 1),),
        latent_space_dim=VECTOR_DIM,
    )
    model.summary()
    model.compile(learning_rate)
    model.train(x_train, num_epochs=epochs)
    return model
示例#2
0
def train(x_train, learning_rate, batch_size, epochs):
    """Build a five-layer convolutional VAE, train it and return it.

    The model input is ``(hop, shape * spec_split, 1)``; ``hop``, ``shape``,
    ``spec_split`` and ``VECTOR_DIM`` are values defined outside this function.

    Args:
        x_train: training data passed to ``VAE.train``.
        learning_rate: forwarded to ``VAE.compile``.
        batch_size: forwarded positionally to ``VAE.train``.
        epochs: forwarded positionally to ``VAE.train``.

    Returns:
        The compiled and trained ``VAE`` instance.
    """
    # Collect the architecture in one place, then unpack it into the VAE.
    architecture = {
        "input_shape": (hop, shape * spec_split, 1),
        "conv_filters": (512, 256, 128, 64, 32),
        "conv_kernels": (3, 3, 3, 3, 3),
        "conv_strides": (2, 2, 2, 2, (2, 1)),
        "latent_space_dim": VECTOR_DIM,
    }
    vae = VAE(**architecture)
    vae.summary()
    vae.compile(learning_rate)
    vae.train(x_train, batch_size, epochs)
    return vae
示例#3
0
def train(x_train, learning_rate, batch_size, epochs):
    """Build a small four-layer convolutional VAE, train it and return it.

    The network takes ``(106, 100, 1)`` inputs and has a 2-dimensional
    latent space.

    Args:
        x_train: training data passed to ``VAE.train``.
        learning_rate: forwarded to ``VAE.compile``.
        batch_size: forwarded positionally to ``VAE.train``.
        epochs: forwarded positionally to ``VAE.train``.

    Returns:
        The compiled and trained ``VAE`` instance.
    """
    filters = (32, 64, 64, 64)
    kernels = (3, 3, 3, 3)
    # Stride 1 everywhere except the final conv layer, which uses stride 2.
    strides = (1, 1, 1, 2)
    model = VAE(
        input_shape=(106, 100, 1),
        conv_filters=filters,
        conv_kernels=kernels,
        conv_strides=strides,
        latent_space_dim=2,
    )
    model.summary()
    model.compile(learning_rate)
    model.train(x_train, batch_size, epochs)
    return model
示例#4
0
# TensorBoard logging for training runs, written under LOG_DIR.
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)

# Keep only the best checkpoint, judged by the lowest "total_loss".
# NOTE(review): assumes the VAE reports a metric named "total_loss" in its
# training logs — confirm. Keras interprets an integer save_freq as a number
# of batches (not epochs), so saving is considered every 5 batches.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    CHECKPOINT_DIR,
    save_best_only=True,
    monitor='total_loss',
    mode='min',
    save_freq=5,
)

# Get dataset
train_dataset, test_dataset, shape = get_mnist()

vae = VAE(shape, LATENT_SHAPE)
tb_callback.set_model(vae)

# Resume from an earlier run only when a checkpoint is actually on disk:
# an unconditional load_weights() raises on the very first run, before
# ModelCheckpoint has written anything. The trailing "*" matches both a
# plain file/directory at CHECKPOINT_DIR and TF-format checkpoint files
# that use it as a prefix (e.g. "<CHECKPOINT_DIR>.index").
if tf.io.gfile.glob(str(CHECKPOINT_DIR) + "*"):
    vae.load_weights(CHECKPOINT_DIR)
vae.compile(optimizer="adam")
vae.fit(train_dataset,
        epochs=20,
        callbacks=[tb_callback, model_checkpoint_callback])

# Validate
vae.evaluate(test_dataset)

# Save samples in TensorBoard.
# tf.summary.image expects a rank-4 [k, h, w, c] tensor; assumes
# random_sample() returns a batch in that layout — confirm.
predictions = vae.random_sample()
file_writer = tf.summary.create_file_writer(LOG_DIR)
with file_writer.as_default():
    tf.summary.image("Images", predictions, max_outputs=10, step=0)
# Summary writers buffer events; flush now so the sample images reach
# LOG_DIR before the GUI below starts (it may keep the process alive
# for a long time).
file_writer.flush()

# Visual experiment
start_gui(vae)
示例#5
0
from vae import VAE

# Model / training configuration.
img_shape = (28, 28, 1)
batch_size = 32
latent_dim = 2

# Load MNIST, scale pixel values into [0, 1] as float32 and append a
# trailing channel axis of size 1. Training labels are discarded.
(x_train, _), (x_test, y_test) = mnist.load_data()
x_train, x_test = (
    (images.astype('float32') / 255.0).reshape(images.shape + (1, ))
    for images in (x_train, x_test)
)

vae = VAE(img_shape, latent_dim)
decoder = vae.get_decoder()
# From here on `vae` refers to the model object returned by get_model(),
# not the VAE wrapper.
vae = vae.get_model()
# loss=None: presumably the VAE objective is attached inside the model
# itself (e.g. via add_loss) — confirm in vae.py.
vae.compile(optimizer=RMSprop(), loss=None)
history = vae.fit(x=x_train,
                  y=None,
                  shuffle=True,
                  epochs=10,
                  batch_size=batch_size)

# Append each recorded loss value to loss.txt, one per line. Use '\n' rather
# than a bare '\r', which most tools do not treat as a line break.
with open('loss.txt', 'a', encoding='utf-8') as f:
    for loss in history.history['loss']:
        f.write(str(loss) + '\n')

# An n x n canvas of digit_size x digit_size tiles.
n = 15
digit_size = 28
figure = np.zeros((n * digit_size, n * digit_size))
# n evenly spaced probabilities in [0.05, 0.95], mapped through norm.ppf
# (assuming `norm` is scipy.stats.norm, the inverse normal CDF). Both axes
# use the same points; grid_y is an independent copy of grid_x.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = grid_x.copy()