def test_posterior_mode(self):
  mix_dist_probs = jnp.array([[0.5, 0.5], [0.01, 0.99]])
  locs = jnp.array([[-1., 1.], [-1., 1.]])
  scale = jnp.array([1.])
  gm = MixtureSameFamily(
      mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
      components_distribution=distrax.Normal(loc=locs, scale=scale))
  mode = gm.posterior_mode(jnp.array([[1.], [-1.], [-6.]]))
  self.assertEqual((3, 2), mode.shape)
  self.assertAllClose(jnp.array([[1, 1], [0, 1], [0, 0]]), mode)
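# A minimal sketch (not the project's MixtureSameFamily implementation) of what
# the posterior mode above computes: for each observation, pick the component k
# maximising log pi_k + log N(x | mu_k, sigma). Names and shapes are illustrative.
import jax.numpy as jnp
import distrax

mix_logits = jnp.log(jnp.array([[0.5, 0.5], [0.01, 0.99]]))  # [B, K]
components = distrax.Normal(loc=jnp.array([[-1., 1.], [-1., 1.]]), scale=1.)
x = jnp.array([[1.], [-1.], [-6.]])                          # [S, 1]
log_joint = mix_logits + components.log_prob(x[..., None])   # [S, B, K]
mode = jnp.argmax(log_joint, axis=-1)                        # [[1, 1], [0, 1], [0, 0]]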
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
  if self.learned_sigma:
    init = hk.initializers.Constant(-jnp.log(2.0) / 2.0)
    log_scale = hk.get_parameter(
        "log_scale", shape=(), dtype=net_output.dtype, init=init)
    scale = jnp.full_like(net_output, jnp.exp(log_scale))
  else:
    scale = jnp.full_like(net_output, 1 / jnp.sqrt(2.0))
  return distrax.Normal(net_output, scale)
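# Hedged aside: the fixed scale of 1/sqrt(2) (and the matching Constant init of
# -log(2)/2 for the learned log-scale) makes the Gaussian negative log-likelihood
# a plain squared error up to an additive constant. A quick check of that identity:
import jax.numpy as jnp
import distrax

mu, x = jnp.array(0.3), jnp.array(1.2)
nll = -distrax.Normal(mu, 1 / jnp.sqrt(2.0)).log_prob(x)
assert jnp.allclose(nll, (x - mu) ** 2 + 0.5 * jnp.log(jnp.pi))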
def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distribution:
  if self.distribution_name is None:
    return net_output
  elif self.distribution_name == "diagonal_normal":
    if self.aggregation_type is None:
      split_axis, num_axes = self.data_format.index("C"), 3
    else:
      split_axis, num_axes = 1, 1
    # Add an extra axis if the input has more than 1 batch dimension.
    split_axis += net_output.ndim - num_axes - 1
    loc, log_scale = jnp.split(net_output, 2, axis=split_axis)
    return distrax.Normal(loc, jnp.exp(log_scale))
  else:
    raise NotImplementedError()
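# Hedged sketch of the "diagonal_normal" branch for a plain NHWC output with a
# single batch dimension: the channel axis carries 2*C values that are split
# into a mean and a log-scale per channel. The shapes below are illustrative.
import jax.numpy as jnp
import distrax

net_output = jnp.zeros((4, 8, 8, 6))                # [N, H, W, 2*C] with C=3
loc, log_scale = jnp.split(net_output, 2, axis=-1)  # each [4, 8, 8, 3]
dist = distrax.Normal(loc, jnp.exp(log_scale))
assert dist.batch_shape == (4, 8, 8, 3) and dist.event_shape == ()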
def prior(self) -> distrax.Distribution:
  """Returns the prior distribution of the model given its parameters."""
  # Allow running with both the full parameters and only the priors.
  if self.prior_type == "standard_normal":
    # assert self.prior_nets is None and self.gated_made is None
    if self.latent_system_net_type == "mlp":
      event_shape = (self.latent_system_dim,)
    elif self.latent_system_net_type == "conv":
      if self.data_format == "NHWC":
        event_shape = self.latent_spatial_shape + (self.latent_system_dim,)
      else:
        event_shape = (self.latent_system_dim,) + self.latent_spatial_shape
    else:
      raise NotImplementedError()
    return distrax.Normal(jnp.zeros(event_shape), jnp.ones(event_shape))
  else:
    raise ValueError(f"Unrecognized prior_type='{self.prior_type}'.")
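# Hedged note: distrax.Normal(zeros, ones) treats the event dimensions above as
# batch dimensions, so log_prob is returned per dimension. If a single density
# per latent is wanted, one common pattern (an assumption, not necessarily what
# this model does downstream) is to wrap the prior in distrax.Independent:
import jax
import jax.numpy as jnp
import distrax

event_shape = (4,)
prior = distrax.Normal(jnp.zeros(event_shape), jnp.ones(event_shape))
z = prior.sample(seed=jax.random.PRNGKey(0))
per_dim = prior.log_prob(z)                                            # shape (4,)
per_latent = distrax.Independent(prior, reinterpreted_batch_ndims=1).log_prob(z)  # scalar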
print(
    f'Loglikelihood found by HMM General Log Space Version: {loglikelihood_log}'
)
assert np.allclose(jnp.log(alphas), alphas_log, atol=1e-8)
assert np.allclose(loglikelihood, loglikelihood_log)
assert np.allclose(jnp.log(gammas), gammas_log, atol=1e-8)

# Test for hmm_viterbi_log. This test is based on
# https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm_test.py
loc = jnp.array([0.0, 1.0, 2.0, 3.0])
scale = jnp.array(0.25)
initial = jnp.array([0.25, 0.25, 0.25, 0.25])
trans = jnp.array([[0.9, 0.1, 0.0, 0.0],
                   [0.1, 0.8, 0.1, 0.0],
                   [0.0, 0.1, 0.8, 0.1],
                   [0.0, 0.0, 0.1, 0.9]])
observations = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 3.0, 2.9, 2.8, 2.7, 2.6])

model = HMM(init_dist=distrax.Categorical(probs=initial),
            trans_dist=distrax.Categorical(probs=trans),
            obs_dist=distrax.Normal(loc, scale))

inferred_states = hmm_viterbi_log(model, observations)
expected_states = [0, 0, 0, 0, 1, 2, 3, 3, 3, 3]
assert np.allclose(inferred_states, expected_states)

# With an explicit sequence length, states beyond the length are marked -1.
length = 7
inferred_states = hmm_viterbi_log(model, observations, length)
expected_states = [0, 0, 0, 0, 1, 2, 3, -1, -1, -1]
assert np.allclose(inferred_states, expected_states)
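# Hedged sketch (illustrative, not part of the test above) of the emission
# log-likelihood table that a log-space Viterbi pass consumes: one row per time
# step, one column per hidden state, built by broadcasting Normal.log_prob.
log_lik = distrax.Normal(loc, scale).log_prob(observations[:, None])  # [T, K]
assert log_lik.shape == (len(observations), len(initial))
assert jnp.argmax(log_lik[0]) == 0  # obs 0.1 is most plausible under state 0 (mean 0.0)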
def __call__(
    self,
    inputs,
    sample_rng,
    context_vectors=None,
    encoder_outputs=None,
    temperature=1.,
):
  """Evaluates the DecoderBlock.

  Args:
    inputs: a batch of input images of shape [B, H, W, C], where H=W is the
      resolution, and C matches the number of channels of the DecoderBlock.
    sample_rng: random key for sampling.
    context_vectors: optional batch of shape [B, D]. These are typically used
      to condition the VDVAE.
    encoder_outputs: a mapping from resolution to encoded images corresponding
      to the output of an Encoder. This mapping should contain the resolution
      of `inputs`. For each resolution R in encoder_outputs, the corresponding
      value has shape [B, R, R, C].
    temperature: when encoder outputs are not provided, the decoder block
      samples a latent unconditionally using the mean of the prior
      distribution, and its log_std + log(temperature).

  Returns:
    A DecoderBlockOutput object holding the outputs of the decoder block,
    which have the same shape as the inputs, as well as the KL divergence
    between the prior and posterior.

  Raises:
    ValueError: if the inputs are not square images, or they have a number of
      channels incompatible with the settings of the DecoderBlock.
  """
  chex.assert_rank(inputs, 4)
  if inputs.shape[1] != inputs.shape[2]:
    raise ValueError(
        'VDVAE only works with square images, but got '
        f'rectangular images of shape {inputs.shape[1:3]}.')
  if inputs.shape[3] != self.num_channels:
    raise ValueError('inputs have incompatible number of channels: '
                     f'got {inputs.shape[3]} channels but expected '
                     f'{self.num_channels}.')

  if self.upsampling_rate > 1:
    current_res = inputs.shape[1]
    target_res = current_res * self.upsampling_rate
    target_shape = (inputs.shape[0], target_res, target_res, inputs.shape[3])
    inputs = jax.image.resize(inputs, shape=target_shape, method='nearest')

  prior_mean, prior_log_std, features = self._compute_prior_and_features(
      inputs, context_vectors)

  if encoder_outputs is not None:
    posterior_mean, posterior_log_std = self._compute_posterior(
        inputs, encoder_outputs, context_vectors)
  else:
    posterior_mean = prior_mean
    posterior_log_std = prior_log_std + jnp.log(temperature)

  posterior_distribution = distrax.Independent(
      distrax.Normal(posterior_mean, jnp.exp(posterior_log_std)),
      reinterpreted_batch_ndims=3)
  prior_distribution = distrax.Independent(
      distrax.Normal(prior_mean, jnp.exp(prior_log_std)),
      reinterpreted_batch_ndims=3)

  latent = posterior_distribution.sample(seed=sample_rng)
  kl = posterior_distribution.kl_divergence(prior_distribution)
  outputs = self._compute_outputs(inputs, features, latent)
  return DecoderBlockOutput(outputs=outputs, kl=kl)
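# Hedged illustration of the temperature trick above: adding log(temperature)
# to the log-std simply rescales the prior's standard deviation, so
# temperature < 1 draws latents closer to the prior mean. Names are illustrative.
import jax.numpy as jnp
import distrax

prior_mean = jnp.zeros(3)
prior_log_std = jnp.zeros(3)
temperature = 0.7
tempered = distrax.Normal(prior_mean, jnp.exp(prior_log_std + jnp.log(temperature)))
assert jnp.allclose(tempered.scale, temperature * jnp.exp(prior_log_std))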
def value(self):
  return jax.tree_map(lambda x: x / self._num_samples, self._obj)

def max(self):
  return jax.tree_map(float, self._obj_max)

def min(self):
  return jax.tree_map(float, self._obj_min)

def sum(self):
  return self._obj


# Register distrax.Normal as a pytree so that jax tree utilities (and the
# pmapped helpers below) can traverse its loc/scale parameters.
register_pytree_node(
    distrax.Normal,
    lambda instance: ([instance.loc, instance.scale], None),
    lambda _, args: distrax.Normal(*args))


def inner_product(x: Any, y: Any) -> jnp.ndarray:
  products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
  return sum(jax.tree_leaves(products))


get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch

p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
                   static_broadcasted_argnums=1)


def wrap_if_pmap(p_func):