all_sigma = fixed_sigma * np.ones((n_classes, fixed_sigma.shape[1])).astype(
        "float32")
    all_classes = np.eye(n_classes).astype('int32')
    all_recs, = np.dot(decode_function(all_mu + np.exp(all_sigma), all_classes),
                       pca_tf) + mean_norm
    for j in range(1, n_classes + 1):
        axarr[n, j].matshow(all_recs[j - 1].reshape(width, height), cmap="gray")
        axarr[n, j].axis('off')
f.suptitle("Style variation by changing conditional")
plt.savefig('vae_style.png')
plt.close()

# Calculate a linear path between the encoded points in latent space:
# interpolate both the means and the log standard deviations.
mus, log_sigmas = encode_function(sample_X)
n_steps = 20
mu_path = interpolate_between_points(mus, n_steps=n_steps)
log_sigma_path = interpolate_between_points(log_sigmas, n_steps=n_steps)

# Latent codes along the path: interpolated means offset by exp(log_sigma).
# No random draw happens here, so the offset is deterministic.
path_X = mu_path + np.exp(log_sigma_path)

# Conditional label for every frame. Segment i (n_steps consecutive rows)
# carries sample_y[i]; this assumes interpolate_between_points emits
# n_steps rows per sample -- TODO confirm against its implementation.
path_y = np.zeros((len(path_X), n_classes), dtype="int32")
for i in range(n_plot_samples):
    path_y[i * n_steps:(i + 1) * n_steps] = sample_y[i]

# decode_function returns a one-element sequence: unpack it first, then map
# the decoded values back through pca_tf and re-add mean_norm.
decoded, = decode_function(path_X, path_y)
out = np.dot(decoded, pca_tf) + mean_norm

# Per-frame caption: the class index of each row's label vector.
text_y = [str(np.argmax(row)) for row in path_y]
color_y = ["white"] * len(text_y)
make_gif(out, "vae_code.gif", width, height, list_text_per_frame=text_y,
         list_text_per_frame_color=color_y, delay=1, grayscale=True)
# Example #2
# 0
def gen_samples(arr):
    """Encode ``arr`` and decode it back, returning the reconstruction.

    ``encode_function`` yields a (mean, log-sigma) pair. The decoder is fed
    ``mean + exp(log_sigma)``, and its single output is returned.
    No random sampling is performed.
    """
    code_mean, code_log_sigma = encode_function(arr)
    # No noise at test time: the latent code is computed deterministically.
    # NOTE(review): decoding the posterior mean alone would use `code_mean`;
    # confirm the `+ exp(log_sigma)` offset is intended.
    latent = code_mean + np.exp(code_log_sigma)
    (reconstruction,) = decode_function(latent)
    return reconstruction

# VAE specific plotting
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so the
# figures can be written to files without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

samples = gen_samples(sample_X)
# squeeze=False keeps axarr 2-D even when n_plot_samples == 1; otherwise
# subplots returns a 1-D array and axarr[n, col] raises IndexError.
f, axarr = plt.subplots(n_plot_samples, 2, squeeze=False)
for n, (X_i, s_i) in enumerate(zip(sample_X, samples)):
    # Column 0: original input, column 1: its reconstruction.
    for col, img in enumerate((X_i, s_i)):
        ax = axarr[n, col]
        ax.matshow(img.reshape(width, height), cmap="gray")
        ax.axis('off')
plt.savefig('vae_reconstruction.png')
plt.close()

# Calculate linear path between points in space
mus, log_sigmas = encode_function(sample_X)
mu_path = interpolate_between_points(mus)
log_sigma_path = interpolate_between_points(log_sigmas)

# Path across space from one point to another: interpolated means offset
# by exp(log_sigma) (deterministic, no random draw).
path = mu_path + np.exp(log_sigma_path)
out, = decode_function(path)
make_gif(out, "vae_code.gif", width, height, delay=1, grayscale=True)