def test_encode_image_rVAE(latent_dim):
    """Check the output shapes of ``rVAE.encode_images`` on a 2-image stack.

    A model is fitted for two training cycles on random 16x16 patches and
    then used to encode a random stack of two 64x64 images. Both returned
    arrays are expected to have spatial size ``64 - 16 + 1``, and the
    encoded array is expected to have ``latent_dim + 3`` trailing channels.
    """
    image_side = 64
    window = (16, 16)

    # Keep the draw order (image stack first, then patches) so the random
    # stream is consumed exactly as before.
    image_stack = np.random.randn(2, image_side, image_side)
    train_patches = np.random.randn(32, *window)

    model = rVAE(window, latent_dim)
    model.fit(train_patches, training_cycles=2)
    images_out, encoded_out = model.encode_images(image_stack)

    # Expected spatial size after encoding: image side minus window plus one.
    out_side = image_side - window[0] + 1
    expected_spatial = (2, out_side, out_side)
    assert_equal(images_out.shape, expected_spatial)
    # NOTE(review): the "+ 3" presumably accounts for extra non-latent
    # outputs of the encoder -- confirm against rVAE.
    assert_equal(encoded_out.shape, expected_spatial + (latent_dim + 3,))
def test_encoding_rVAE(translation, latent_dim, encoded_dim):
    """Check the structure of what ``rVAE.encode`` returns.

    A small model (16 hidden units in encoder and decoder) is fitted for
    two training cycles on generated 28x28 image data. ``encode`` is then
    expected to return exactly two items, each with ``encoded_dim`` as the
    size of its last axis.
    """
    image_shape = (28, 28)
    images = gen_image_data()

    model = rVAE(
        image_shape,
        latent_dim,
        translation=translation,
        numhidden_encoder=16,
        numhidden_decoder=16,
    )
    model.fit(images, training_cycles=2)

    encoded = model.encode(images)
    assert_equal(len(encoded), 2)
    # Both returned items must share the same trailing dimension.
    for idx in (0, 1):
        assert_equal(encoded[idx].shape[-1], encoded_dim)
def test_decoding_rVAE(conv_encoder, conv_decoder, translation, latent_dim):
    """Check that ``rVAE.decode`` maps a zero latent vector to an image.

    A small model (16 hidden units in encoder and decoder) is fitted for
    two training cycles on generated 28x28 image data. Decoding a single
    all-zero latent vector must give an output whose shape, after the
    leading axis, equals the input image shape and whose sum is non-zero.

    NOTE(review): ``conv_encoder`` and ``conv_decoder`` are accepted but
    not used anywhere in this body -- confirm whether they were meant to
    be forwarded to the model.
    """
    image_shape = (28, 28)
    images = gen_image_data()

    model = rVAE(
        image_shape,
        latent_dim,
        translation=translation,
        numhidden_encoder=16,
        numhidden_decoder=16,
    )
    model.fit(images, training_cycles=2)

    # One latent sample of zeros, with a leading axis of length 1.
    latent_batch = np.zeros(latent_dim)[np.newaxis]
    reconstruction = model.decode(latent_batch)

    assert_equal(reconstruction.shape[1:], image_shape)
    # A trained decoder should not produce an all-cancelling output.
    total = np.sum(reconstruction)
    assert_(total != 0)