def sample_vae2(args, z_size=32, model_path_name="tf_vae"):
    """Decode random latent vectors with a trained ConvVAE and display them.

    For vae from https://github.com/hardmaru/WorldModelsExperiments.git

    Args:
        args: Namespace-like object; only ``args.count`` is read. It is both
            the number of samples drawn and the VAE batch size.
        z_size: Latent dimension. Must match the size the VAE stored in
            ``model_path_name`` was trained with (default 32, as before).
        model_path_name: Directory containing ``vae.json``.

    NOTE(review): this file also contains a later ``def sample_vae2`` (64-dim,
    three rows); if both live in one module, the later one shadows this one.
    """
    n = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5

    reset_graph()
    vae = ConvVAE(
        z_size=z_size,
        batch_size=n,
        learning_rate=learning_rate,
        kl_tolerance=kl_tolerance,
        is_training=False,
        reuse=False,
        gpu_mode=False,  # run on CPU
    )
    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    # Draw latent codes from a standard normal and decode them to images.
    z = np.random.normal(size=(n, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]

    fig = plt.figure(figsize=(20, 4))
    # suptitle titles the whole figure. Calling plt.title() before any subplot
    # exists would create a stray full-figure Axes underneath the panels.
    fig.suptitle('VAE samples')
    for i in range(n):
        # A single row of n panels; the former 2-row grid left the bottom half empty.
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(samples[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
if __name__ == '__main__':
    arglist = parse_args()
    # exist_ok=True replaces the racy "check exists, then create" pattern.
    os.makedirs(arglist.series_dir, exist_ok=True)

    # Process files in a deterministic (sorted) order, capped at the first 10000.
    filelist = sorted(os.listdir(arglist.data_dir))[:10000]
    dataset, action_dataset, oppo_action_dataset = load_raw_data_list(
        filelist, arglist)

    reset_graph()
    if arglist.use_vae:
        vae = ConvVAE(
            z_size=arglist.z_size,
            batch_size=arglist.batch_size,
            learning_rate=arglist.lr,
            kl_tolerance=arglist.kl_tolerance,
            is_training=False,
            reuse=False,
            gpu_mode=True)  # use GPU on batchsize of 1000 -> much faster
        vae.load_json(os.path.join(arglist.vae_path, 'vae.json'))

    # NOTE(review): these accumulators are not filled or read within this view;
    # presumably the code that populates them follows elsewhere — confirm.
    mu_dataset = []
    logvar_dataset = []
    action_dataset_real = []
def main(dirs, z_size=32, batch_size=100, learning_rate=0.0001,
         kl_tolerance=0.5, epochs=100, save_model=False, verbose=True,
         optimizer="Adam"):
    """Train a ConvVAE on 64x64 images from ``dirs`` with Keras-style early stopping.

    Args:
        dirs: Data directories handed to ``DriveDataGenerator``.
        z_size: Latent dimension of the VAE.
        batch_size: Images per training step.
        learning_rate: Forwarded to ``ConvVAE``.
        kl_tolerance: Forwarded to ``ConvVAE``.
        epochs: Maximum number of passes over the generator.
        save_model: If True, write ``tf_vae/vae.json`` every 5000 global steps
            and once more when training finishes.
        verbose: Print one progress line per epoch; also forwarded to
            ``EarlyStopping``.
        optimizer: Optimizer name forwarded to ``ConvVAE``.

    Returns:
        The lowest single-batch training loss observed (``float('inf')`` when
        ``epochs`` is 0).

    Raises:
        ValueError: If the generator yields no batches.
    """
    model_save_path = "tf_vae"
    model_json = os.path.join(model_save_path, "vae.json")
    if save_model:
        os.makedirs(model_save_path, exist_ok=True)

    gen = DriveDataGenerator(dirs, image_size=(64, 64), batch_size=batch_size,
                             shuffle=True, max_load=10000, images_only=True)
    num_batches = len(gen)
    if num_batches == 0:
        # Without batches, the epoch loop below would report an unassigned
        # loss to EarlyStopping.
        raise ValueError(
            "DriveDataGenerator produced no batches for {}".format(dirs))

    reset_graph()
    vae = ConvVAE(z_size=z_size, batch_size=batch_size,
                  learning_rate=learning_rate, kl_tolerance=kl_tolerance,
                  is_training=True, reuse=False, gpu_mode=True,
                  optimizer=optimizer)

    # A Keras EarlyStopping callback is driven by hand here, since there is no
    # Keras fit() loop. It requests a stop by setting ``stop_training`` on the
    # model it was given.
    early = EarlyStopping(monitor='loss', min_delta=0.1, patience=5,
                          verbose=verbose, mode='auto')
    early.set_model(vae)
    early.on_train_begin()

    best_loss = float("inf")
    if verbose:
        print("epoch\tstep\tloss\trecon_loss\tkl_loss")

    for epoch in range(epochs):
        sum_loss = sum_r = sum_kl = 0.0
        for idx in range(num_batches):
            # np.float was removed in NumPy 1.24; it was an alias of float64.
            obs = gen[idx].astype(np.float64) / 255.0
            train_loss, r_loss, kl_loss, train_step, _ = vae.sess.run(
                [vae.loss, vae.r_loss, vae.kl_loss, vae.global_step,
                 vae.train_op],
                {vae.x: obs})
            best_loss = min(best_loss, train_loss)
            sum_loss += train_loss
            sum_r += r_loss
            sum_kl += kl_loss
            if save_model and (train_step + 1) % 5000 == 0:
                vae.save_json(model_json)
        gen.on_epoch_end()

        # Give EarlyStopping the epoch-average loss, as Keras' fit() does.
        # The last batch's loss alone is too noisy to compare against min_delta.
        mean_loss = sum_loss / num_batches
        if verbose:
            print("{} of {}\t{}\t{:.2f}\t{:.2f}\t{:.2f}".format(
                epoch, epochs, train_step + 1, mean_loss,
                sum_r / num_batches, sum_kl / num_batches))
        early.on_epoch_end(epoch, logs={"loss": mean_loss})
        # getattr guards against ConvVAE not defining stop_training before
        # EarlyStopping sets it (not verifiable from here).
        if getattr(vae, "stop_training", False):
            break
    early.on_train_end()

    # finished, final model:
    if save_model:
        vae.save_json(model_json)
    return best_loss
def _plot_image_row(images, row, n_rows, n, input_dim, label):
    """Draw ``n`` images into 0-based ``row`` of an ``n_rows`` x ``n`` subplot
    grid, hide the axes, and title the first panel with ``label``."""
    for i in range(n):
        ax = plt.subplot(n_rows, n, row * n + i + 1)
        plt.imshow(images[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 0:
            ax.set_title(label)


def sample_vae2(args, z_size=64, model_path_name="tf_vae"):
    """Plot random VAE samples, real images, and their reconstructions.

    For vae from https://github.com/hardmaru/WorldModelsExperiments.git

    Builds a figure with 3 rows and ``args.count`` columns:
      row 1 "Random"        - decodes of z ~ N(0, I)
      row 2 "Real"          - one batch of images from ``args.dirs``
      row 3 "Reconstructed" - decode(encode(real))
    The figure is saved to ``samples_vae.png`` and then shown.

    Args:
        args: Namespace-like object. Reads ``args.count`` (columns and VAE
            batch size) and ``args.dirs`` (passed to DriveDataGenerator).
        z_size: Latent dimension. This needs to match the size of the trained
            VAE (default 64, as before).
        model_path_name: Directory containing ``vae.json``.
    """
    n = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5

    reset_graph()
    vae = ConvVAE(
        z_size=z_size,
        batch_size=n,
        learning_rate=learning_rate,
        kl_tolerance=kl_tolerance,
        is_training=False,
        reuse=False,
        gpu_mode=False,  # run on CPU
    )
    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    # Random samples from the prior.
    z = np.random.normal(size=(n, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]

    # Real images and their round-trip reconstructions.
    # NOTE(review): the VAE was built with batch_size=n; this assumes gen[0]
    # holds exactly n images — confirm for small datasets.
    gen = DriveDataGenerator(args.dirs, image_size=(64, 64), batch_size=n,
                             shuffle=True, max_load=10000, images_only=True)
    # np.float was removed in NumPy 1.24; it was an alias of float64.
    orig = gen[0].astype(np.float64) / 255.0
    recon = vae.decode(vae.encode(orig))

    fig = plt.figure(figsize=(20, 6), tight_layout=False)
    # suptitle titles the whole figure. plt.title() before any subplot exists
    # would create a stray full-figure Axes underneath the panels.
    fig.suptitle('VAE samples')
    _plot_image_row(samples, 0, 3, n, input_dim, "Random")
    _plot_image_row(orig, 1, 3, n, input_dim, "Real")
    _plot_image_row(recon, 2, 3, n, input_dim, "Reconstructed")
    plt.savefig("samples_vae.png")
    plt.show()