def model(self, value):
  mixing_coeffs_logits, probs_logits = value
  self._model = MixtureSameFamily(
      mixture_distribution=distrax.Categorical(logits=mixing_coeffs_logits),
      components_distribution=distrax.Independent(
          distrax.Bernoulli(logits=probs_logits),
          reinterpreted_batch_ndims=1))
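For reference, here is a minimal sketch of how such a Bernoulli mixture can be queried once it has been built; the shapes (K components over D binary dimensions), the random key, and the variable names below are illustrative assumptions, not part of the original class.

import jax
import jax.numpy as jnp
import distrax

# Illustrative shapes: K mixture components over D-dimensional binary data.
K, D = 3, 10
key = jax.random.PRNGKey(0)
mixing_coeffs_logits = jnp.zeros((K,))          # uniform mixture weights
probs_logits = jax.random.normal(key, (K, D))   # per-component Bernoulli logits

model = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(logits=mixing_coeffs_logits),
    components_distribution=distrax.Independent(
        distrax.Bernoulli(logits=probs_logits),
        reinterpreted_batch_ndims=1))

x = jnp.zeros((5, D))          # a batch of 5 binary vectors
log_probs = model.log_prob(x)  # shape [5]: one log-likelihood per example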
def __call__(self, x: jnp.ndarray) -> VAEOutput:
  x = x.astype(jnp.float32)

  # q(z|x) = N(mean(x), covariance(x))
  mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
  variational_distrib = distrax.MultivariateNormalDiag(
      loc=mean, scale_diag=stddev)
  z = variational_distrib.sample(seed=hk.next_rng_key())

  # p(x|z) = \prod_i Bernoulli(logits_i(z))
  logits = Decoder(self._hidden_size, self._output_shape)(z)
  likelihood_distrib = distrax.Independent(
      distrax.Bernoulli(logits=logits),
      reinterpreted_batch_ndims=len(self._output_shape))  # 3 non-batch dims

  # Generate images from the likelihood.
  image = likelihood_distrib.sample(seed=hk.next_rng_key())

  return VAEOutput(variational_distrib, likelihood_distrib, image)
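A standard training objective for this module is the negative ELBO. The sketch below shows one way to assemble it from the returned VAEOutput, assuming the namedtuple fields match the constructor arguments above and that the latent prior is a standard normal; the function name and batch handling are illustrative, not part of the original module.

def elbo_loss(outputs: VAEOutput, batch: jnp.ndarray) -> jnp.ndarray:
  """Negative ELBO averaged over the batch (illustrative sketch)."""
  # Reconstruction term: log p(x|z) under the Bernoulli likelihood.
  log_likelihood = outputs.likelihood_distrib.log_prob(batch)
  # KL[q(z|x) || p(z)] against an assumed standard-normal prior on z.
  event_shape = outputs.variational_distrib.event_shape
  prior = distrax.MultivariateNormalDiag(
      loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape))
  kl = outputs.variational_distrib.kl_divergence(prior)
  return jnp.mean(kl - log_likelihood)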
def make_flow_model(event_shape: Sequence[int],
                    num_layers: int,
                    hidden_sizes: Sequence[int],
                    num_bins: int) -> distrax.Transformed:
  """Creates the flow model."""
  # Alternating binary mask.
  mask = jnp.arange(0, np.prod(event_shape)) % 2
  mask = jnp.reshape(mask, event_shape)
  mask = mask.astype(bool)

  def bijector_fn(params: Array):
    return distrax.RationalQuadraticSpline(params, range_min=0., range_max=1.)

  # Number of parameters for the rational-quadratic spline:
  # - `num_bins` bin widths
  # - `num_bins` bin heights
  # - `num_bins + 1` knot slopes
  # for a total of `3 * num_bins + 1` parameters.
  num_bijector_params = 3 * num_bins + 1

  layers = []
  for _ in range(num_layers):
    layer = distrax.MaskedCoupling(
        mask=mask,
        bijector=bijector_fn,
        conditioner=make_conditioner(event_shape, hidden_sizes,
                                     num_bijector_params))
    layers.append(layer)
    # Flip the mask after each layer.
    mask = jnp.logical_not(mask)

  # We invert the flow so that the `forward` method is called with `log_prob`.
  flow = distrax.Inverse(distrax.Chain(layers))
  base_distribution = distrax.Independent(
      distrax.Uniform(low=jnp.zeros(event_shape), high=jnp.ones(event_shape)),
      reinterpreted_batch_ndims=len(event_shape))

  return distrax.Transformed(base_distribution, flow)
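Because the conditioner builds Haiku modules, the flow has to be constructed and evaluated inside an hk.transform. The following is a minimal usage sketch, assuming make_conditioner is defined as in the surrounding example and using illustrative hyperparameters and dummy data in [0, 1).

import haiku as hk
import jax
import jax.numpy as jnp

EVENT_SHAPE = (2,)  # illustrative event shape

@hk.without_apply_rng
@hk.transform
def log_prob(x: jnp.ndarray) -> jnp.ndarray:
  flow = make_flow_model(
      event_shape=EVENT_SHAPE, num_layers=4, hidden_sizes=[64, 64], num_bins=8)
  return flow.log_prob(x)

x = jnp.full((16,) + EVENT_SHAPE, 0.5)       # dummy batch inside [0, 1)
params = log_prob.init(jax.random.PRNGKey(0), x)
loss = -jnp.mean(log_prob.apply(params, x))  # negative log-likelihood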
def model(self, value):
  mixing_coeffs, probs = value
  self._model = MixtureSameFamily(
      mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
      components_distribution=distrax.Independent(
          distrax.Bernoulli(probs=probs)))
def __call__(
    self,
    inputs,
    sample_rng,
    context_vectors=None,
    encoder_outputs=None,
    temperature=1.,
):
  """Evaluates the DecoderBlock.

  Args:
    inputs: a batch of input images of shape [B, H, W, C], where H=W is the
      resolution, and C matches the number of channels of the DecoderBlock.
    sample_rng: random key for sampling.
    context_vectors: optional batch of shape [B, D]. These are typically used
      to condition the VDVAE.
    encoder_outputs: a mapping from resolution to encoded images corresponding
      to the output of an Encoder. This mapping should contain the resolution
      of `inputs`. For each resolution R in encoder_outputs, the corresponding
      value has shape [B, R, R, C].
    temperature: when encoder outputs are not provided, the decoder block
      samples a latent unconditionally using the mean of the prior
      distribution, and its log_std + log(temperature).

  Returns:
    A DecoderBlockOutput object holding the outputs of the decoder block,
    which have the same shape as the inputs, as well as the KL divergence
    between the prior and posterior.

  Raises:
    ValueError: if the inputs are not square images, or they have a number of
      channels incompatible with the settings of the DecoderBlock.
  """
  chex.assert_rank(inputs, 4)
  if inputs.shape[1] != inputs.shape[2]:
    raise ValueError('VDVAE only works with square images, but got '
                     f'rectangular images of shape {inputs.shape[1:3]}.')
  if inputs.shape[3] != self.num_channels:
    raise ValueError('inputs have incompatible number of channels: '
                     f'got {inputs.shape[3]} channels but expected '
                     f'{self.num_channels}.')

  if self.upsampling_rate > 1:
    current_res = inputs.shape[1]
    target_res = current_res * self.upsampling_rate
    target_shape = (inputs.shape[0], target_res, target_res, inputs.shape[3])
    inputs = jax.image.resize(inputs, shape=target_shape, method='nearest')

  prior_mean, prior_log_std, features = self._compute_prior_and_features(
      inputs, context_vectors)

  if encoder_outputs is not None:
    posterior_mean, posterior_log_std = self._compute_posterior(
        inputs, encoder_outputs, context_vectors)
  else:
    posterior_mean = prior_mean
    posterior_log_std = prior_log_std + jnp.log(temperature)

  posterior_distribution = distrax.Independent(
      distrax.Normal(posterior_mean, jnp.exp(posterior_log_std)),
      reinterpreted_batch_ndims=3)
  prior_distribution = distrax.Independent(
      distrax.Normal(prior_mean, jnp.exp(prior_log_std)),
      reinterpreted_batch_ndims=3)

  latent = posterior_distribution.sample(seed=sample_rng)
  kl = posterior_distribution.kl_divergence(prior_distribution)

  outputs = self._compute_outputs(inputs, features, latent)
  return DecoderBlockOutput(outputs=outputs, kl=kl)
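The kl returned by each decoder block is typically summed across all blocks and combined with the reconstruction log-likelihood to form the training loss. A minimal sketch of that aggregation follows; the function name, argument layout, and the assumption of per-example terms of shape [B] are illustrative, not part of the original code.

def vdvae_elbo_loss(log_likelihood: jnp.ndarray, block_outputs) -> jnp.ndarray:
  """Negative ELBO from per-block KLs (illustrative sketch).

  Args:
    log_likelihood: per-example log p(x|z) from the output likelihood, shape [B].
    block_outputs: sequence of DecoderBlockOutput, one per decoder block, each
      carrying a per-example `kl` of shape [B].
  """
  total_kl = sum(block.kl for block in block_outputs)  # shape [B]
  return jnp.mean(total_kl - log_likelihood)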