Example #1
    def test_posterior_mode(self):
        mix_dist_probs = jnp.array([[0.5, 0.5], [0.01, 0.99]])
        locs = jnp.array([[-1., 1.], [-1., 1.]])
        scale = jnp.array([1.])

        gm = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
            components_distribution=distrax.Normal(loc=locs, scale=scale))

        mode = gm.posterior_mode(jnp.array([[1.], [-1.], [-6.]]))

        self.assertEqual((3, 2), mode.shape)
        self.assertAllClose(jnp.array([[1, 1], [0, 1], [0, 0]]), mode)
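For reference, a minimal sketch (not part of the test above) of the quantity posterior_mode is expected to return here: the per-observation argmax over mixture components of mixing log-probability plus component log-likelihood.

import distrax
import jax.numpy as jnp

mix_dist_probs = jnp.array([[0.5, 0.5], [0.01, 0.99]])
locs = jnp.array([[-1., 1.], [-1., 1.]])
scale = jnp.array([1.])
x = jnp.array([[1.], [-1.], [-6.]])

# log p(k) + log p(x | k), broadcast to shape [3, 2, 2]: observations x batch x components.
log_prior = jnp.log(mix_dist_probs)
log_lik = distrax.Normal(locs, scale).log_prob(x[..., None])
manual_mode = jnp.argmax(log_prior + log_lik, axis=-1)  # shape [3, 2], matching the test above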
Example #2
    def make_distribution(self,
                          net_output: jnp.ndarray) -> distrax.Distribution:
        if self.learned_sigma:
            init = hk.initializers.Constant(-jnp.log(2.0) / 2.0)
            log_scale = hk.get_parameter("log_scale",
                                         shape=(),
                                         dtype=net_output.dtype,
                                         init=init)
            scale = jnp.full_like(net_output, jnp.exp(log_scale))
        else:
            scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0))

        return distrax.Normal(net_output, scale)
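The same pattern in a self-contained form (the function name and shapes below are hypothetical): a single scalar log_scale learned with hk.get_parameter and shared across all outputs, wrapped in hk.transform so it can be initialized and applied.

import distrax
import haiku as hk
import jax
import jax.numpy as jnp

def _normal_log_prob(targets: jnp.ndarray, net_output: jnp.ndarray) -> jnp.ndarray:
    init = hk.initializers.Constant(-jnp.log(2.0) / 2.0)
    log_scale = hk.get_parameter("log_scale", shape=(), dtype=net_output.dtype, init=init)
    dist = distrax.Normal(net_output, jnp.full_like(net_output, jnp.exp(log_scale)))
    return dist.log_prob(targets)

log_prob_fn = hk.transform(_normal_log_prob)
params = log_prob_fn.init(jax.random.PRNGKey(0), jnp.zeros((4,)), jnp.zeros((4,)))
log_probs = log_prob_fn.apply(params, None, jnp.ones((4,)), jnp.zeros((4,)))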
Example #3
    def make_distribution(self,
                          net_output: jnp.ndarray) -> distrax.Distribution:
        if self.distribution_name is None:
            return net_output
        elif self.distribution_name == "diagonal_normal":
            if self.aggregation_type is None:
                split_axis, num_axes = self.data_format.index("C"), 3
            else:
                split_axis, num_axes = 1, 1
            # Shift the split axis to account for any extra leading batch dimensions.
            split_axis += net_output.ndim - num_axes - 1
            loc, log_scale = jnp.split(net_output, 2, axis=split_axis)
            return distrax.Normal(loc, jnp.exp(log_scale))
        else:
            raise NotImplementedError()
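A small worked check of the "diagonal_normal" branch, assuming data_format="NHWC", no aggregation, and a single batch dimension: the network output carries 2*C channels, which are split into loc and log_scale along the channel axis.

import distrax
import jax.numpy as jnp

net_output = jnp.zeros((8, 16, 16, 6))            # [N, H, W, 2*C]
split_axis, num_axes = "NHWC".index("C"), 3       # split_axis == 3
split_axis += net_output.ndim - num_axes - 1      # no extra batch dimensions, so still 3
loc, log_scale = jnp.split(net_output, 2, axis=split_axis)
dist = distrax.Normal(loc, jnp.exp(log_scale))    # elementwise Normal over [8, 16, 16, 3]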
Example #4
    def prior(self) -> distrax.Distribution:
        """Given the parameters, returns the prior distribution of the model."""
        # Allow running with either the full parameters or only the prior.
        if self.prior_type == "standard_normal":
            # assert self.prior_nets is None and self.gated_made is None
            if self.latent_system_net_type == "mlp":
                event_shape = (self.latent_system_dim,)
            elif self.latent_system_net_type == "conv":
                if self.data_format == "NHWC":
                    event_shape = self.latent_spatial_shape + (
                        self.latent_system_dim,)
                else:
                    event_shape = (
                        self.latent_system_dim,) + self.latent_spatial_shape
            else:
                raise NotImplementedError()
            return distrax.Normal(jnp.zeros(event_shape),
                                  jnp.ones(event_shape))
        else:
            raise ValueError(f"Unrecognized prior_type='{self.prior_type}'.")
Example #5
print(
    f'Loglikelihood found by HMM General Log Space Version: {loglikelihood_log}'
)

assert np.allclose(jnp.log(alphas), alphas_log)
assert np.allclose(loglikelihood, loglikelihood_log)
assert np.allclose(jnp.log(gammas), gammas_log)

# Test for the hmm_viterbi_log. This test is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm_test.py
loc = jnp.array([0.0, 1.0, 2.0, 3.0])
scale = jnp.array(0.25)
initial = jnp.array([0.25, 0.25, 0.25, 0.25])
trans = jnp.array([[0.9, 0.1, 0.0, 0.0], [0.1, 0.8, 0.1, 0.0],
                   [0.0, 0.1, 0.8, 0.1], [0.0, 0.0, 0.1, 0.9]])

observations = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 3.0, 2.9, 2.8, 2.7, 2.6])

model = HMM(init_dist=distrax.Categorical(probs=initial),
            trans_dist=distrax.Categorical(probs=trans),
            obs_dist=distrax.Normal(loc, scale))

inferred_states = hmm_viterbi_log(model, observations)
expected_states = [0, 0, 0, 0, 1, 2, 3, 3, 3, 3]

assert np.allclose(inferred_states, expected_states)

length = 7
inferred_states = hmm_viterbi_log(model, observations, length)
expected_states = [0, 0, 0, 0, 1, 2, 3, -1, -1, -1]
assert np.allclose(inferred_states, expected_states)
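For context, a minimal log-space Viterbi sketch for the full-length case (an illustrative reimplementation, not the hmm_viterbi_log tested above, and it does not handle the `length` argument):

import distrax
import jax
import jax.numpy as jnp

def viterbi_log(init_dist, trans_dist, obs_dist, observations):
    log_init = jnp.log(init_dist.probs)                    # [K]
    log_trans = jnp.log(trans_dist.probs)                  # [K, K]
    log_lik = obs_dist.log_prob(observations[:, None])     # [T, K]

    def forward(delta, log_lik_t):
        scores = delta[:, None] + log_trans                # [K, K]: previous state x next state
        return jnp.max(scores, axis=0) + log_lik_t, jnp.argmax(scores, axis=0)

    delta_last, backpointers = jax.lax.scan(forward, log_init + log_lik[0], log_lik[1:])

    def backtrack(state, bp_t):
        prev = bp_t[state]
        return prev, prev

    last_state = jnp.argmax(delta_last)
    _, path = jax.lax.scan(backtrack, last_state, backpointers, reverse=True)
    return jnp.concatenate([path, last_state[None]])

# e.g. viterbi_log(distrax.Categorical(probs=initial), distrax.Categorical(probs=trans),
#                  distrax.Normal(loc, scale), observations) should recover expected_states.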
Example #6
    def __call__(
        self,
        inputs,
        sample_rng,
        context_vectors=None,
        encoder_outputs=None,
        temperature=1.,
    ):
        """Evaluates the DecoderBlock.

    Args:
      inputs: a batch of input images of shape [B, H, W, C], where H=W is the
        resolution, and C matches the number of channels of the DecoderBlock.
      sample_rng: random key for sampling.
      context_vectors: optional batch of shape [B, D]. These are typically used
        to condition the VDVAE.
      encoder_outputs: a mapping from resolution to encoded images corresponding
        to the output of an Encoder. This mapping should contain the resolution
        of `inputs`. For each resolution R in encoder_outputs, the corresponding
        value has shape [B, R, R, C].
      temperature: when encoder outputs are not provided, the decoder block
        samples a latent unconditionnally using the mean of the prior
        distribution, and its log_std + log(temperature).

    Returns:
      A DecoderBlockOutput object holding the outputs of the decoder block,
      which have the same shape as the in
      puts, as well as the KL divergence
      between the prior and posterior.

    Raises:
      ValueError: if the inputs are not square images, or they have a number
      of channels incompatible with the settings of the DecoderBlock.
    """
        chex.assert_rank(inputs, 4)
        if inputs.shape[1] != inputs.shape[2]:
            raise ValueError(
                'VDVAE only works with square images, but got '
                f'rectangular images of shape {inputs.shape[1:3]}.')
        if inputs.shape[3] != self.num_channels:
            raise ValueError('inputs have incompatible number of channels: '
                             f'got {inputs.shape[3]} channels but expected '
                             f'{self.num_channels}.')

        if self.upsampling_rate > 1:
            current_res = inputs.shape[1]
            target_res = current_res * self.upsampling_rate
            target_shape = (inputs.shape[0], target_res, target_res,
                            inputs.shape[3])
            inputs = jax.image.resize(inputs,
                                      shape=target_shape,
                                      method='nearest')

        prior_mean, prior_log_std, features = self._compute_prior_and_features(
            inputs, context_vectors)
        if encoder_outputs is not None:
            posterior_mean, posterior_log_std = self._compute_posterior(
                inputs, encoder_outputs, context_vectors)
        else:
            posterior_mean = prior_mean
            posterior_log_std = prior_log_std + jnp.log(temperature)

        posterior_distribution = distrax.Independent(
            distrax.Normal(posterior_mean, jnp.exp(posterior_log_std)),
            reinterpreted_batch_ndims=3)
        prior_distribution = distrax.Independent(
            distrax.Normal(prior_mean, jnp.exp(prior_log_std)),
            reinterpreted_batch_ndims=3)
        latent = posterior_distribution.sample(seed=sample_rng)
        kl = posterior_distribution.kl_divergence(prior_distribution)

        outputs = self._compute_outputs(inputs, features, latent)
        return DecoderBlockOutput(outputs=outputs, kl=kl)
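A small standalone sketch (shapes are hypothetical) of the distrax.Independent pattern used above: reinterpreted_batch_ndims=3 folds the H, W, C axes into the event, so kl_divergence returns one value per batch element.

import distrax
import jax.numpy as jnp

mean = jnp.zeros((2, 8, 8, 4))
posterior = distrax.Independent(distrax.Normal(mean + 0.1, jnp.ones_like(mean)),
                                reinterpreted_batch_ndims=3)
prior = distrax.Independent(distrax.Normal(mean, jnp.ones_like(mean)),
                            reinterpreted_batch_ndims=3)
kl = posterior.kl_divergence(prior)   # shape (2,): KL summed over all 8 * 8 * 4 event dims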
Example #7
    def value(self):
        return jax.tree_map(lambda x: x / self._num_samples, self._obj)

    def max(self):
        return jax.tree_map(float, self._obj_max)

    def min(self):
        return jax.tree_map(float, self._obj_min)

    def sum(self):
        return self._obj


register_pytree_node(
    distrax.Normal,
    lambda instance: ([instance.loc, instance.scale], None),
    lambda _, args: distrax.Normal(*args))
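Assuming the installed distrax version does not already register Normal as a pytree (in which case the call above would raise on duplicate registration), this lets JAX tree utilities operate directly on distribution parameters; a hypothetical usage:

dist = distrax.Normal(jnp.zeros(3), jnp.ones(3))
shifted = jax.tree_map(lambda leaf: leaf + 1.0, dist)   # a Normal with loc + 1 and scale + 1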


def inner_product(x: Any, y: Any) -> jnp.ndarray:
    products = jax.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
    return sum(jax.tree_leaves(products))


get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
                   static_broadcasted_argnums=1)


def wrap_if_pmap(p_func):