def test_bernoulli_vector_default_output_layer(): """ BernoulliVector's default output layer is compatible with its required output space """ mlp = MLP(layers=[Linear(layer_name="h", dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name="conditional") vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5)
def test_convolutional_compatible(): """ VAE allows convolutional encoding networks """ encoding_model = MLP( layers=[ SpaceConverter( layer_name='conv2d_converter', output_space=Conv2DSpace(shape=[4, 4], num_channels=1) ), ConvRectifiedLinear( layer_name='h', output_channels=2, kernel_shape=[2, 2], kernel_stride=[1, 1], pool_shape=[1, 1], pool_stride=[1, 1], pool_type='max', irange=0.01) ] ) decoding_model = MLP(layers=[Linear(layer_name='h', dim=16, irange=0.01)]) prior = DiagonalGaussianPrior() conditional = BernoulliVector(mlp=decoding_model, name='conditional') posterior = DiagonalGaussian(mlp=encoding_model, name='posterior') vae = VAE(nvis=16, prior=prior, conditional=conditional, posterior=posterior, nhid=16) X = T.matrix('X') lower_bound = vae.log_likelihood_lower_bound(X, num_samples=10) f = theano.function(inputs=[X], outputs=lower_bound) rng = make_np_rng(default_seed=11223) f(as_floatX(rng.uniform(size=(10, 16))))
def test_bernoulli_vector_conditional_expectation(): """ BernoulliVector.conditional_expectation doesn't crash """ mlp = MLP(layers=[Linear(layer_name='h', dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name='conditional') vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix('mu') conditional.conditional_expectation([mu])
def test_bernoulli_vector_sample_from_conditional(): """ BernoulliVector.sample_from_conditional works when num_samples is provided """ mlp = MLP(layers=[Linear(layer_name='h', dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name='conditional') vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix('mu') conditional.sample_from_conditional([mu], num_samples=2)
def test_vae_automatically_finds_kl_integrator(): """ VAE automatically finds the right KLIntegrator """ encoding_model = MLP(layers=[Linear(layer_name='h', dim=10, irange=0.01)]) decoding_model = MLP(layers=[Linear(layer_name='h', dim=10, irange=0.01)]) prior = DiagonalGaussianPrior() conditional = BernoulliVector(mlp=decoding_model, name='conditional') posterior = DiagonalGaussian(mlp=encoding_model, name='posterior') vae = VAE(nvis=10, prior=prior, conditional=conditional, posterior=posterior, nhid=5) assert (vae.kl_integrator is not None and isinstance(vae.kl_integrator, DiagonalGaussianPriorPosteriorKL))
def test_bernoulli_vector_reparametrization_trick(): """ BernoulliVector.sample_from_conditional raises an error when asked to sample using the reparametrization trick """ mlp = MLP(layers=[Linear(layer_name='h', dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name='conditional') vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix('mu') epsilon = T.tensor3('epsilon') conditional.sample_from_conditional([mu], epsilon=epsilon)
def test_bernoulli_vector_conditional_expectation(): """ BernoulliVector.conditional_expectation doesn't crash """ mlp = MLP(layers=[Linear(layer_name="h", dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name="conditional") vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix("mu") conditional.conditional_expectation([mu])
def test_bernoulli_vector_sample_from_conditional(): """ BernoulliVector.sample_from_conditional works when num_samples is provided """ mlp = MLP(layers=[Linear(layer_name="h", dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name="conditional") vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix("mu") conditional.sample_from_conditional([mu], num_samples=2)
def test_bernoulli_vector_reparametrization_trick(): """ BernoulliVector.sample_from_conditional raises an error when asked to sample using the reparametrization trick """ mlp = MLP(layers=[Linear(layer_name="h", dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name="conditional") vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5) mu = T.matrix("mu") epsilon = T.tensor3("epsilon") conditional.sample_from_conditional([mu], epsilon=epsilon)
def test_multiple_samples_allowed(): """ VAE allows multiple samples per data point """ encoding_model = MLP(layers=[Linear(layer_name='h', dim=10, irange=0.01)]) decoding_model = MLP(layers=[Linear(layer_name='h', dim=10, irange=0.01)]) prior = DiagonalGaussianPrior() conditional = BernoulliVector(mlp=decoding_model, name='conditional') posterior = DiagonalGaussian(mlp=encoding_model, name='posterior') vae = VAE(nvis=10, prior=prior, conditional=conditional, posterior=posterior, nhid=5) X = T.matrix('X') lower_bound = vae.log_likelihood_lower_bound(X, num_samples=10) f = theano.function(inputs=[X], outputs=lower_bound) rng = make_np_rng(default_seed=11223) f(as_floatX(rng.uniform(size=(10, 10))))
def test_bernoulli_vector_default_output_layer(): """ BernoulliVector's default output layer is compatible with its required output space """ mlp = MLP(layers=[Linear(layer_name='h', dim=5, irange=0.01, max_col_norm=0.01)]) conditional = BernoulliVector(mlp=mlp, name='conditional') vae = DummyVAE() conditional.set_vae(vae) input_space = VectorSpace(dim=5) conditional.initialize_parameters(input_space=input_space, ndim=5)