import os
import numpy as np
import matplotlib.pyplot as plt

# ConvVAE and reset_graph are expected to come from the WorldModelsExperiments
# vae module (e.g. `from vae.vae import ConvVAE, reset_graph`); DriveDataGenerator
# is this project's own image loader.


def sample_vae2(args):
    """ For a VAE from https://github.com/hardmaru/WorldModelsExperiments.git """
    z_size = 64  # This needs to match the size of the trained VAE
    batch_size = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5
    model_path_name = "tf_vae"

    reset_graph()
    vae = ConvVAE(z_size=z_size,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  kl_tolerance=kl_tolerance,
                  is_training=False,
                  reuse=False,
                  gpu_mode=False)  # use GPU on a batch size of 1000 -> much faster
    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    # Row 1: decode random latent vectors
    z = np.random.normal(size=(args.count, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]

    # Rows 2 and 3: real images and their reconstructions
    gen = DriveDataGenerator(args.dirs, image_size=(64, 64), batch_size=args.count,
                             shuffle=True, max_load=10000, images_only=True)
    orig = gen[0].astype(float) / 255.0  # np.float is deprecated; plain float keeps the same dtype
    #mu, logvar = vae.encode_mu_logvar(orig)
    #recon = vae.decode(mu)
    recon = vae.decode(vae.encode(orig))

    n = args.count
    plt.figure(figsize=(20, 6), tight_layout=False)
    plt.title('VAE samples')
    for i in range(n):
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(samples[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if 0 == i:
            ax.set_title("Random")

    for i in range(n):
        ax = plt.subplot(3, n, n + i + 1)
        plt.imshow(orig[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if 0 == i:
            ax.set_title("Real")

        ax = plt.subplot(3, n, (2 * n) + i + 1)
        plt.imshow(recon[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if 0 == i:
            ax.set_title("Reconstructed")

    plt.savefig("samples_vae.png")
    plt.show()
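# A minimal, hypothetical sketch of how sample_vae2 might be driven from the
# command line. The argument names (--count, --dirs) mirror the args.count and
# args.dirs fields used above; the parser itself is an assumption, not part of
# the original script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Sample and reconstruct images with a trained ConvVAE")
    parser.add_argument("--count", type=int, default=10,
                        help="number of images per row (also the VAE batch size)")
    parser.add_argument("--dirs", nargs="+", required=True,
                        help="directories of training images for DriveDataGenerator")
    sample_vae2(parser.parse_args())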
output_dir = "vae_test_result"
z_size = 32
# DATA_DIR and model_path_name (the "tf_vae" weights directory) are assumed to be
# defined earlier in the script; imsave is e.g. scipy.misc.imsave or imageio.imwrite.

filelist = os.listdir(DATA_DIR)
filelist = [f for f in filelist if '.npz' in f]
obs = np.load(os.path.join(DATA_DIR, random.choice(filelist)))["obs"]
obs = np.expand_dims(obs, axis=-1)       # add a channel axis: (N, 64, 64, 1)
obs = obs.astype(np.float32) / 255.0     # scale pixels to [0, 1]
n = len(obs)

vae = ConvVAE(z_size=z_size, batch_size=1, is_training=False, reuse=False, gpu_mode=False)
vae.load_json(os.path.join(model_path_name, 'vae.json'))

if not os.path.exists(output_dir):
    os.mkdir(output_dir)
print(n, "images loaded")

for i in range(n):
    frame = obs[i].reshape(1, 64, 64, 1)
    batch_z = vae.encode(frame)          # (1, 32) latent vector
    reconstruct = vae.decode(batch_z)    # (1, 64, 64, 1) reconstruction
    imsave(output_dir + '/%s.png' % pad_num(i), 255. * frame[0].reshape(64, 64))
    imsave(output_dir + '/%s_vae.png' % pad_num(i), 255. * reconstruct[0].reshape(64, 64))
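# pad_num is called above but not defined in this excerpt. A minimal sketch,
# assuming it only zero-pads the frame index so the written PNGs sort in frame
# order (the 4-digit width is an assumption):
def pad_num(i, width=4):
    return str(i).zfill(width)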
print("All model loaded.") # Fifth, run the evaluation. -> We have no predictions about the first frame. start = time.time() state = rnn_init_state(rnn) # initialize the state. pz = None for i in range(steps): ob = obs[i:i+1] # (1, 64, 64, 1) action = oh_actions[i:i+1] # (1, n) z = vae.encode(ob) # (1, 32) VAE done! rnn_z = np.expand_dims(z, axis=0) # (1, 1, 32) action = np.expand_dims(action, axis=0) # (1, 1, n) input_x = np.concatenate([rnn_z, action], axis=2) # (1, 1, 32+n) feed = {rnn.input_x: input_x, rnn.initial_state: state} # predict the next state and next z. if pz is not None: # decode from the z frame = vae.decode(pz[None]) frame2 = vae.decode(z) #neglogp = neg_likelihood(logmix, mean, logstd, z.reshape(32,1)) #imsave(output_dir + '/%s_origin_%.2f.png' % (pad_num(i), np.exp(-neglogp)), 255.*ob.reshape(64, 64)) #imsave(output_dir + '/%s_reconstruct.png' % pad_num(i), 255. * frame[0].reshape(64, 64)) img = concat_img(255.*ob, 255*frame2, 255.*frame) imsave(output_dir + '/%s.png' % pad_num(i), img)