def sample_vae2(args):
    """Plot random VAE samples beside real images and their reconstructions.

    For the VAE from https://github.com/hardmaru/WorldModelsExperiments.git

    Builds a ``ConvVAE`` with ``is_training=False``, loads its parameters
    from ``tf_vae/vae.json`` and draws a figure with three rows of
    ``args.count`` images each:

    * "Random": images decoded from latents drawn from a standard normal,
    * "Real": the first batch returned by a ``DriveDataGenerator`` over
      ``args.dirs``, scaled by 1/255,
    * "Reconstructed": those real images encoded and decoded by the VAE.

    The figure is saved to ``samples_vae.png`` and then shown.

    Args:
        args: Namespace-like object providing ``count`` (images per row, also
            used as the VAE and generator batch size) and ``dirs`` (passed to
            ``DriveDataGenerator``).
    """
    z_size = 64  # This needs to match the size of the trained vae
    batch_size = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5
    model_path_name = "tf_vae"

    reset_graph()
    vae = ConvVAE(
        z_size=z_size,
        batch_size=batch_size,
        learning_rate=learning_rate,
        kl_tolerance=kl_tolerance,
        is_training=False,
        reuse=False,
        gpu_mode=False)  # use GPU on batchsize of 1000 -> much faster

    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    # Decode latents sampled from the prior to get purely generated images.
    z = np.random.normal(size=(args.count, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]
    image_shape = (input_dim[0], input_dim[1], input_dim[2])

    gen = DriveDataGenerator(args.dirs,
                             image_size=(64, 64),
                             batch_size=args.count,
                             shuffle=True,
                             max_load=10000,
                             images_only=True)
    # np.float was only an alias for the builtin float and no longer exists in
    # current NumPy releases; np.float64 is the same dtype spelled explicitly.
    orig = gen[0].astype(np.float64) / 255.0
    recon = vae.decode(vae.encode(orig))

    rows = (("Random", samples), ("Real", orig), ("Reconstructed", recon))
    n = args.count
    plt.figure(figsize=(20, 6), tight_layout=False)
    # Figure-level title: plt.title() here would create an extra full-figure
    # axes that the subplot grid below then overlaps.
    plt.suptitle('VAE samples')
    for row, (label, images) in enumerate(rows):
        for col in range(n):
            ax = plt.subplot(len(rows), n, row * n + col + 1)
            plt.imshow(images[col].reshape(image_shape))
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if col == 0:
                # Label each row once, on its left-most image.
                ax.set_title(label)

    plt.savefig("samples_vae.png")
    plt.show()
# Example #2
# 0
# Reconstruct every frame of one randomly chosen recording with the VAE and
# write each original/reconstruction pair to `output_dir` as PNG files.
# NOTE(review): relies on DATA_DIR, model_path_name, ConvVAE, imsave and
# pad_num being defined elsewhere in the file.
output_dir = "vae_test_result"

z_size = 32

# Match on the extension only; a substring test would also pick up names such
# as "episode.npz.bak".
filelist = [f for f in os.listdir(DATA_DIR) if f.endswith('.npz')]
if not filelist:
    # Fail with a clear message instead of random.choice's bare IndexError.
    raise FileNotFoundError("no .npz files found in %s" % DATA_DIR)

# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; the context manager closes it once the array has been read out.
with np.load(os.path.join(DATA_DIR, random.choice(filelist))) as data:
    obs = data["obs"]
# Add a trailing channel axis and scale by 1/255
# (presumably the stored frames are 0-255 pixel values - TODO confirm).
obs = np.expand_dims(obs, axis=-1)
obs = obs.astype(np.float32) / 255.0

n = len(obs)

vae = ConvVAE(z_size=z_size,
              batch_size=1,
              is_training=False,
              reuse=False,
              gpu_mode=False)

vae.load_json(os.path.join(model_path_name, 'vae.json'))

# Creates the directory if needed without a separate, racy existence check.
os.makedirs(output_dir, exist_ok=True)

print(n, "images loaded")
for i in range(n):
    # The VAE was built with batch_size=1, so feed one (1, 64, 64, 1) frame
    # at a time.
    frame = obs[i].reshape(1, 64, 64, 1)
    batch_z = vae.encode(frame)
    reconstruct = vae.decode(batch_z)
    stem = pad_num(i)
    imsave(os.path.join(output_dir, '%s.png' % stem),
           255. * frame[0].reshape(64, 64))
    imsave(os.path.join(output_dir, '%s_vae.png' % stem),
           255. * reconstruct[0].reshape(64, 64))

print("All model loaded.")
# Fifth, run the evaluation. -> We have no predictions about the first frame.
#
# Walks through the first `steps` observations one at a time: each frame is
# encoded by the VAE, its latent is concatenated with the matching action to
# form the RNN input, and - whenever a latent `pz` is available - the
# observation, its VAE reconstruction and the decoded `pz` are passed to
# concat_img and the result is saved as a PNG in `output_dir`.
# NOTE(review): `rnn`, `steps`, `oh_actions`, `obs`, `vae`, `output_dir`,
# `rnn_init_state`, `concat_img`, `imsave` and `pad_num` are defined outside
# the visible code.

# Wall-clock start time; not read anywhere in the visible part of the script.
start = time.time()

state = rnn_init_state(rnn) # initialize the state.
# `pz` is never assigned in the visible part of the loop - presumably it is set
# from the RNN output further down (TODO confirm). While it is None nothing is
# saved, which matches the "no predictions about the first frame" note above.
pz = None

for i in range(steps):

  # Slicing (rather than indexing) keeps the leading batch axis of size 1.
  ob = obs[i:i+1] # (1, 64, 64, 1)
  action = oh_actions[i:i+1] # (1, n)

  z = vae.encode(ob) # (1, 32) VAE done!
  # Insert a new leading axis on both arrays so they can be joined along axis 2.
  rnn_z = np.expand_dims(z, axis=0) # (1, 1, 32)
  action = np.expand_dims(action, axis=0) # (1, 1, n)


  input_x = np.concatenate([rnn_z, action], axis=2) # (1, 1, 32+n)
  # `feed` is built here but not used in the visible code - presumably it is
  # passed to a session run below (TODO confirm).
  feed = {rnn.input_x: input_x, rnn.initial_state: state} # predict the next state and next z.

  if pz is not None: # decode from the z
    # `frame` is decoded from `pz` (with a leading axis added by pz[None]);
    # `frame2` is the VAE reconstruction of the current observation.
    frame = vae.decode(pz[None])
    frame2 = vae.decode(z)
    #neglogp = neg_likelihood(logmix, mean, logstd, z.reshape(32,1))
    #imsave(output_dir + '/%s_origin_%.2f.png' % (pad_num(i), np.exp(-neglogp)), 255.*ob.reshape(64, 64))
    #imsave(output_dir + '/%s_reconstruct.png' % pad_num(i), 255. * frame[0].reshape(64, 64))
    # Argument order: observation, reconstruction, decoded `pz`, each scaled
    # by 255 before being handed to concat_img.
    img = concat_img(255.*ob, 255*frame2, 255.*frame)
    imsave(output_dir + '/%s.png' % pad_num(i), img)