Beispiel #1
0
def guide(batch, hidden_dim=400, z_dim=100):
    """Variational guide: encode a batch and sample the latent site ``z``.

    :param batch: array whose first axis indexes the items of the batch; all
        remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``encoder`` as its first argument.
    :param z_dim: forwarded to ``encoder`` as its second argument.
    :return: the value drawn at the ``'z'`` sample site.
    """
    # Collapse everything but the leading axis so each item is one row.
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    encoder_net = numpyro.module(
        'encoder', encoder(hidden_dim, z_dim), (n_items, n_features)
    )
    # The registered encoder yields a (location, scale) pair for the Normal.
    loc, scale = encoder_net(flat)
    return numpyro.sample('z', dist.Normal(loc, scale))
Beispiel #2
0
def guide(batch: np.ndarray, hidden_dim: int = 400, z_dim: int = 100) -> None:
    """Register the encoder and the latent sample site ``"z"``.

    The function is called for its effect of declaring the ``"encoder"``
    module and the ``"z"`` sample site; it returns nothing.

    :param batch: array whose first axis indexes the items of the batch; all
        remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``encoder`` as its first argument.
    :param z_dim: forwarded to ``encoder`` as its second argument.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    encoder_net = numpyro.module(
        "encoder",
        encoder(hidden_dim, z_dim),
        (n_items, n_features),
    )
    # The encoder output is unpacked into the two Normal parameters.
    loc, scale = encoder_net(flat)
    posterior = dist.Normal(loc, scale)
    numpyro.sample("z", posterior)
Beispiel #3
0
def guide(batch, hidden_dim=400, z_dim=100):
    """Variational guide that samples ``"z"`` inside a ``"batch"`` plate.

    :param batch: array whose first axis indexes the items of the batch; all
        remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``encoder`` as its first argument.
    :param z_dim: forwarded to ``encoder`` as its second argument.
    :return: the value drawn at the ``"z"`` sample site.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    encoder_net = numpyro.module(
        "encoder", encoder(hidden_dim, z_dim), (n_items, n_features)
    )
    loc, scale = encoder_net(flat)
    # to_event(1) reinterprets the last axis as an event dimension, leaving
    # the leading axis to be handled by the plate below.
    posterior = dist.Normal(loc, scale).to_event(1)
    with numpyro.plate("batch", n_items):
        latent = numpyro.sample("z", posterior)
    return latent
Beispiel #4
0
def model(batch, hidden_dim=400, z_dim=100):
    """Generative model: sample ``z`` from a standard Normal and decode it.

    :param batch: observed array whose first axis indexes the items of the
        batch; all remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``decoder`` as its first argument.
    :param z_dim: length of the latent vector drawn at site ``'z'``.
    :return: the value of the ``'obs'`` site, which is conditioned on the
        flattened batch.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    decoder_net = numpyro.module(
        'decoder', decoder(hidden_dim, n_features), (n_items, z_dim)
    )
    # Prior over the latent code: zero location, unit scale per dimension.
    prior = dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,)))
    latent = numpyro.sample('z', prior)
    # The decoder output parameterises the Bernoulli likelihood.
    likelihood = dist.Bernoulli(decoder_net(latent))
    return numpyro.sample('obs', likelihood, obs=flat)
Beispiel #5
0
def model(batch: np.ndarray, hidden_dim: int = 400, z_dim: int = 100) -> None:
    """Register the decoder plus the ``"z"`` and ``"obs"`` sample sites.

    The function is called for its effect of declaring the ``"decoder"``
    module and the two sample sites; it returns nothing.

    :param batch: observed array whose first axis indexes the items of the
        batch; all remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``decoder`` as its first argument.
    :param z_dim: length of the latent vector drawn at site ``"z"``.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    decoder_net = numpyro.module(
        "decoder",
        decoder(hidden_dim, n_features),
        (n_items, z_dim),
    )
    # Prior over the latent code: zero location, unit scale per dimension.
    prior = dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,)))
    latent = numpyro.sample("z", prior)
    decoded = decoder_net(latent)
    # Condition the Bernoulli likelihood on the flattened observations.
    numpyro.sample("obs", dist.Bernoulli(decoded), obs=flat)
Beispiel #6
0
def model(batch, hidden_dim=400, z_dim=100):
    """Generative model with both sample sites inside a ``"batch"`` plate.

    :param batch: observed array whose first axis indexes the items of the
        batch; all remaining axes are flattened into a single feature axis.
    :param hidden_dim: forwarded to ``decoder`` as its first argument.
    :param z_dim: length of the latent vector drawn at site ``"z"``.
    :return: the value of the ``"obs"`` site, which is conditioned on the
        flattened batch.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_items, n_features = jnp.shape(flat)
    decoder_net = numpyro.module(
        "decoder", decoder(hidden_dim, n_features), (n_items, z_dim)
    )
    # Standard Normal expanded to z_dim entries, treated as a single event.
    prior = dist.Normal(0, 1).expand([z_dim]).to_event(1)
    with numpyro.plate("batch", n_items):
        latent = numpyro.sample("z", prior)
        decoded = decoder_net(latent)
        likelihood = dist.Bernoulli(decoded).to_event(1)
        observed = numpyro.sample("obs", likelihood, obs=flat)
    return observed
Beispiel #7
0
 def _get_posterior(self):
     """Build the posterior as a stack of block neural autoregressive flows.

     ``self.num_flows`` ``BlockNeuralAutoregressiveTransform`` layers are
     applied to ``self.get_base_dist()``. Every layer after the first is
     preceded by a permutation that reverses the order of the latent
     dimensions, and every layer except the last uses a ``"gated"`` residual.

     :return: a ``dist.TransformedDistribution`` over the base distribution.
     :raises ValueError: if ``self.latent_dim`` equals 1.
     """
     if self.latent_dim == 1:
         raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
     transforms = []
     for idx in range(self.num_flows):
         if idx > 0:
             # Reverse the dimension order between consecutive flows.
             reverse_order = jnp.arange(self.latent_dim)[::-1]
             transforms.append(PermuteTransform(reverse_order))
         # Only the final flow is built without a residual connection.
         residual = None if idx >= (self.num_flows - 1) else "gated"
         block_net = BlockNeuralAutoregressiveNN(
             self.latent_dim, self._hidden_factors, residual
         )
         module_name = '{}_arn__{}'.format(self.prefix, idx)
         registered = numpyro.module(module_name, block_net, (self.latent_dim,))
         transforms.append(BlockNeuralAutoregressiveTransform(registered))
     return dist.TransformedDistribution(self.get_base_dist(), transforms)
Beispiel #8
0
 def _get_transform(self):
     """Compose ``self.num_flows`` inverse autoregressive flows.

     Every flow after the first is preceded by a permutation that reverses
     the order of the latent dimensions. Each flow wraps an
     ``AutoregressiveNN`` registered under ``'<prefix>_arn__<index>'``.

     :return: a ``ComposeTransform`` over the collected transforms.
     :raises ValueError: if ``self.latent_size`` equals 1.
     """
     if self.latent_size == 1:
         raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
     hidden_dims = self._hidden_dims
     if hidden_dims is None:
         # Default to two hidden sizes, each equal to the latent size.
         hidden_dims = [self.latent_size, self.latent_size]
     transforms = []
     for idx in range(self.num_flows):
         if idx > 0:
             # Reverse the dimension order between consecutive flows.
             reverse_order = np.arange(self.latent_size)[::-1]
             transforms.append(PermuteTransform(reverse_order))
         autoreg_net = AutoregressiveNN(
             self.latent_size,
             hidden_dims,
             permutation=np.arange(self.latent_size),
             skip_connections=self._skip_connections,
             nonlinearity=self._nonlinearity,
         )
         module_name = '{}_arn__{}'.format(self.prefix, idx)
         registered = numpyro.module(module_name, autoreg_net, (self.latent_size,))
         transforms.append(InverseAutoregressiveTransform(registered))
     return ComposeTransform(transforms)
Beispiel #9
0
 def _get_posterior(self):
     """Build the posterior as a stack of inverse autoregressive flows.

     ``self.num_flows`` ``InverseAutoregressiveTransform`` layers are applied
     to ``self.get_base_dist()``. Every layer after the first is preceded by
     a permutation that reverses the order of the latent dimensions.

     :return: a ``dist.TransformedDistribution`` over the base distribution.
     :raises ValueError: if ``self.latent_dim`` equals 1.
     """
     if self.latent_dim == 1:
         raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
     hidden_dims = self._hidden_dims
     if hidden_dims is None:
         # Default to two hidden sizes, each equal to the latent dimension.
         hidden_dims = [self.latent_dim, self.latent_dim]
     transforms = []
     for idx in range(self.num_flows):
         if idx > 0:
             # Reverse the dimension order between consecutive flows.
             reverse_order = jnp.arange(self.latent_dim)[::-1]
             transforms.append(PermuteTransform(reverse_order))
         autoreg_net = AutoregressiveNN(
             self.latent_dim,
             hidden_dims,
             permutation=jnp.arange(self.latent_dim),
             skip_connections=self._skip_connections,
             nonlinearity=self._nonlinearity,
         )
         module_name = '{}_arn__{}'.format(self.prefix, idx)
         registered = numpyro.module(module_name, autoreg_net, (self.latent_dim,))
         transforms.append(InverseAutoregressiveTransform(registered))
     return dist.TransformedDistribution(self.get_base_dist(), transforms)
Beispiel #10
0
 def _get_transform(self):
     """Compose ``self.num_flows`` block neural autoregressive flows.

     Every flow after the first is preceded by a permutation that reverses
     the order of the latent dimensions, and every flow except the last uses
     a ``"gated"`` residual.

     :return: a ``ComposeTransform`` over the collected transforms.
     :raises ValueError: if ``self.latent_size`` equals 1.
     """
     if self.latent_size == 1:
         raise ValueError(
             'latent dim = 1. Consider using AutoDiagonalNormal instead')
     transforms = []
     for idx in range(self.num_flows):
         if idx > 0:
             # Reverse the dimension order between consecutive flows.
             reverse_order = np.arange(self.latent_size)[::-1]
             transforms.append(PermuteTransform(reverse_order))
         # Only the final flow is built without a residual connection.
         residual = None if idx >= (self.num_flows - 1) else "gated"
         block_net = BlockNeuralAutoregressiveNN(
             self.latent_size, self._hidden_factors, residual)
         module_name = '{}_arn__{}'.format(self.prefix, idx)
         registered = numpyro.module(
             module_name, block_net, (self.latent_size, ))
         transforms.append(BlockNeuralAutoregressiveTransform(registered))
     return ComposeTransform(transforms)
Beispiel #11
0
def numpyro_haiku(name, haiku_fn, input_shape):
    """Converts a Haiku module to a Stax module and calls numpyro.module.

    The Haiku function is transformed into an ``(init, apply)`` pair, which
    is then wrapped in the two-function ``(init_fn, apply_fn)`` form handed
    to ``numpyro.module``.

    :param name: name under which the module is registered with NumPyro.
    :param haiku_fn: function passed to ``hk.transform``.
    :param input_shape: shape of a single input array; forwarded to
        ``numpyro.module`` and used to build a zero-filled dummy input.
    :return: whatever ``numpyro.module`` returns for the wrapped pair.
    """
    init_fn, apply_fn = hk.transform(haiku_fn, apply_rng=True)

    def stax_init_fn(rng, in_shapes):
        # TODO: Extend to trees of in_shapes.
        inputs = np.zeros(in_shapes)
        params = init_fn(rng, inputs)
        # Run one forward pass without an RNG key to discover output shapes.
        output = apply_fn(params, None, inputs)
        # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is
        # deprecated and missing from recent JAX releases, while the
        # tree_util spelling works on both old and new versions.
        out_shapes = jax.tree_util.tree_map(lambda o: o.shape, output)
        return out_shapes, params

    def stax_apply_fn(params, inputs):
        # The apply function is always called with rng=None.
        return apply_fn(params, None, inputs)

    return numpyro.module(name, (stax_init_fn, stax_apply_fn), input_shape)
Beispiel #12
0
def guide(batch, z_dim, hidden_dim, out_dim=None, num_obs_total=None):
    """Defines the variational guide q(z), the approximation to p(z|x).

    :param batch: a batch of observations; must be a 3-dimensional array
        whose first axis is the batch axis
    :param z_dim: forwarded to ``encoder`` as its second argument
    :param hidden_dim: forwarded to ``encoder`` as its first argument
    :param out_dim: ignored; the flattened item size is always derived from
        ``batch``
    :param num_obs_total: forwarded to ``minibatch``
    :return: (named) sampled z from the variational (guide) distribution q(z)
    """
    assert (jnp.ndim(batch) == 3)
    n_items = jnp.shape(batch)[0]
    # Flatten each item to one dimension, keeping only the batch axis.
    flat = jnp.reshape(batch, (n_items, -1))
    n_features = jnp.shape(flat)[1]

    encoder_net = numpyro.module(
        'encoder', encoder(hidden_dim, z_dim), (n_items, n_features)
    )
    with minibatch(n_items, num_obs_total=num_obs_total):
        # The encoder yields the location and scale of q(z) for each item.
        loc, scale = encoder_net(flat)
        # z follows q(z)
        return sample('z', dist.Normal(loc, scale))
Beispiel #13
0
def model(batch_or_batchsize,
          z_dim,
          hidden_dim,
          out_dim=None,
          num_obs_total=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    When a batch of data is passed, the ``'obs'`` site is conditioned on it.
    When an integer batch size is passed instead, ``'obs'`` is given
    ``obs=None``.

    :param batch_or_batchsize: either a 3-dimensional batch of observations
        (first axis is the batch axis) or an integer batch size
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the hidden layers in the VAE
    :param out_dim: number of dimensions in a single output sample
        (flattened); required when only a batch size is given, otherwise
        overwritten by the flattened item size of the batch
    :param num_obs_total: forwarded to ``minibatch``

    :return: (named) sample x from the model observation distribution p(x|z)p(z)
    :raises ValueError: if a batch size is given without ``out_dim``
    """
    if is_int_scalar(batch_or_batchsize):
        # No data supplied: the output size cannot be inferred.
        if out_dim is None:
            raise ValueError("if no batch is provided, out_dim must be given")
        batch_size = batch_or_batchsize
        batch = None
    else:
        assert (jnp.ndim(batch_or_batchsize) == 3)
        batch_size = jnp.shape(batch_or_batchsize)[0]
        # Flatten each item to one dimension, keeping only the batch axis.
        batch = jnp.reshape(batch_or_batchsize, (batch_size, -1))
        out_dim = jnp.shape(batch)[1]

    decoder_net = numpyro.module(
        'decoder', decoder(hidden_dim, out_dim), (batch_size, z_dim)
    )
    with minibatch(batch_size, num_obs_total=num_obs_total):
        # prior on z is N(0,I)
        prior = dist.Normal(jnp.zeros((z_dim, )), jnp.ones((z_dim, )))
        latent = sample('z', prior)
        # The decoder output parameterises the Bernoulli over outputs.
        decoded = decoder_net(latent)
        # x is drawn from that Bernoulli, with the batch (if any) as obs.
        return sample('obs', dist.Bernoulli(decoded), obs=batch)
Beispiel #14
0
 def model(x, y):
     """Regression model whose mean comes from a registered ``Dense(1)`` net.

     :param x: inputs passed to the network registered as ``"nn"``.
     :param y: observations the ``"y"`` sample site is conditioned on.
     """
     dense_net = numpyro.module("nn", Dense(1), (10,))
     net_out = dense_net(x)
     # Drop the trailing axis of the network output.
     mean = net_out.squeeze(-1)
     noise_scale = numpyro.sample("sigma", dist.HalfNormal(1))
     likelihood = dist.Normal(mean, noise_scale)
     numpyro.sample("y", likelihood, obs=y)