Example #1
    def model(self, value):
        mixing_coeffs_logits, probs_logits = value
        self._model = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(
                logits=mixing_coeffs_logits),
            components_distribution=distrax.Independent(
                distrax.Bernoulli(logits=probs_logits),
                reinterpreted_batch_ndims=1))
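This setter builds a mixture of independent Bernoullis; MixtureSameFamily is presumably distrax.MixtureSameFamily, imported elsewhere in the project. A minimal standalone usage sketch, assuming hypothetical shapes [K] for mixing_coeffs_logits and [K, D] for probs_logits:

import distrax
import jax
import jax.numpy as jnp

K, D = 3, 5  # hypothetical number of components and binary dimensions
mixing_coeffs_logits = jnp.zeros(K)
probs_logits = jnp.zeros((K, D))

mixture = distrax.MixtureSameFamily(
    mixture_distribution=distrax.Categorical(logits=mixing_coeffs_logits),
    components_distribution=distrax.Independent(
        distrax.Bernoulli(logits=probs_logits),
        reinterpreted_batch_ndims=1))

x = mixture.sample(seed=jax.random.PRNGKey(0), sample_shape=(4,))  # [4, D] binary samples
log_p = mixture.log_prob(x)                                        # [4] per-sample log-densities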
Example #2
File: vae.py Project: stjordanis/distrax
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)

        # q(z|x) = N(mean(x), covariance(x))
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        variational_distrib = distrax.MultivariateNormalDiag(loc=mean,
                                                             scale_diag=stddev)
        z = variational_distrib.sample(seed=hk.next_rng_key())

        # p(x|z) = \Prod Bernoulli(logits(z))
        logits = Decoder(self._hidden_size, self._output_shape)(z)
        likelihood_distrib = distrax.Independent(
            distrax.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=len(
                self._output_shape))  # 3 non-batch dims

        # Generate images from the likelihood
        image = likelihood_distrib.sample(seed=hk.next_rng_key())

        return VAEOutput(variational_distrib, likelihood_distrib, image)
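Since the VAEOutput carries both distributions, the ELBO can be assembled outside the module. A minimal sketch, assuming a standard-normal prior over the latent_size-dimensional latents and the VAEOutput fields shown above; the hk.transform wiring, batching, and optimizer are omitted:

import distrax
import jax.numpy as jnp

def elbo(output, x, latent_size):
    # Reconstruction term: log p(x|z) under the Bernoulli likelihood.
    log_likelihood = output.likelihood_distrib.log_prob(x)
    # Analytic KL(q(z|x) || p(z)) against a standard-normal prior.
    prior = distrax.MultivariateNormalDiag(
        loc=jnp.zeros(latent_size), scale_diag=jnp.ones(latent_size))
    kl = output.variational_distrib.kl_divergence(prior)
    return log_likelihood - kl  # shape [B]; the training loss would be -mean(elbo)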
Example #3
def make_flow_model(event_shape: Sequence[int], num_layers: int,
                    hidden_sizes: Sequence[int],
                    num_bins: int) -> distrax.Transformed:
    """Creates the flow model."""
    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(event_shape)) % 2
    mask = jnp.reshape(mask, event_shape)
    mask = mask.astype(bool)

    def bijector_fn(params: Array):
        return distrax.RationalQuadraticSpline(params,
                                               range_min=0.,
                                               range_max=1.)

    # Number of parameters for the rational-quadratic spline:
    # - `num_bins` bin widths
    # - `num_bins` bin heights
    # - `num_bins + 1` knot slopes
    # for a total of `3 * num_bins + 1` parameters.
    num_bijector_params = 3 * num_bins + 1

    layers = []
    for _ in range(num_layers):
        layer = distrax.MaskedCoupling(mask=mask,
                                       bijector=bijector_fn,
                                       conditioner=make_conditioner(
                                           event_shape, hidden_sizes,
                                           num_bijector_params))
        layers.append(layer)
        # Flip the mask after each layer.
        mask = jnp.logical_not(mask)

    # We invert the flow so that the `forward` method is called with `log_prob`.
    flow = distrax.Inverse(distrax.Chain(layers))
    base_distribution = distrax.Independent(
        distrax.Uniform(low=jnp.zeros(event_shape),
                        high=jnp.ones(event_shape)),
        reinterpreted_batch_ndims=len(event_shape))

    return distrax.Transformed(base_distribution, flow)
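The flow above relies on a make_conditioner helper that isn't shown in this snippet. Below is a minimal sketch of what it could look like, assuming a Haiku MLP that maps the masked inputs to num_bijector_params spline parameters per event dimension (zero-initialized so each coupling layer starts as the identity), followed by a hypothetical wrapper showing how the resulting Transformed distribution might be used; the event_shape, layer counts, and sizes are made up:

import haiku as hk
import jax.numpy as jnp
import numpy as np

def make_conditioner(event_shape, hidden_sizes, num_bijector_params):
    """MLP that outputs num_bijector_params spline parameters per event dimension."""
    return hk.Sequential([
        hk.Flatten(preserve_dims=-len(event_shape)),
        hk.nets.MLP(hidden_sizes, activate_final=True),
        # Zero-init the last layer so the flow is initialized to the identity.
        hk.Linear(np.prod(event_shape) * num_bijector_params,
                  w_init=jnp.zeros, b_init=jnp.zeros),
        hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
    ])

# The conditioners create Haiku parameters, so the model is built inside a
# transformed function. Data is assumed to lie in [0, 1], matching the Uniform
# base distribution and the spline range.
@hk.without_apply_rng
@hk.transform
def flow_log_prob(x):
    model = make_flow_model(event_shape=(2,), num_layers=4,
                            hidden_sizes=(64, 64), num_bins=8)
    return model.log_prob(x)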
Example #4
    def model(self, value):
        mixing_coeffs, probs = value
        self._model = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
            components_distribution=distrax.Independent(
                distrax.Bernoulli(probs=probs)))
Example #5
    def __call__(
        self,
        inputs,
        sample_rng,
        context_vectors=None,
        encoder_outputs=None,
        temperature=1.,
    ):
        """Evaluates the DecoderBlock.

        Args:
          inputs: a batch of input images of shape [B, H, W, C], where H=W is the
            resolution, and C matches the number of channels of the DecoderBlock.
          sample_rng: random key for sampling.
          context_vectors: optional batch of shape [B, D]. These are typically used
            to condition the VDVAE.
          encoder_outputs: a mapping from resolution to encoded images corresponding
            to the output of an Encoder. This mapping should contain the resolution
            of `inputs`. For each resolution R in encoder_outputs, the corresponding
            value has shape [B, R, R, C].
          temperature: when encoder outputs are not provided, the decoder block
            samples a latent unconditionally using the mean of the prior
            distribution, and its log_std + log(temperature).

        Returns:
          A DecoderBlockOutput object holding the outputs of the decoder block,
          which have the same shape as the inputs, as well as the KL divergence
          between the prior and posterior.

        Raises:
          ValueError: if the inputs are not square images, or they have a number
            of channels incompatible with the settings of the DecoderBlock.
        """
        chex.assert_rank(inputs, 4)
        if inputs.shape[1] != inputs.shape[2]:
            raise ValueError(
                'VDVAE only works with square images, but got '
                f'rectangular images of shape {inputs.shape[1:3]}.')
        if inputs.shape[3] != self.num_channels:
            raise ValueError('inputs have incompatible number of channels: '
                             f'got {inputs.shape[3]} channels but expected '
                             f'{self.num_channels}.')

        if self.upsampling_rate > 1:
            current_res = inputs.shape[1]
            target_res = current_res * self.upsampling_rate
            target_shape = (inputs.shape[0], target_res, target_res,
                            inputs.shape[3])
            inputs = jax.image.resize(inputs,
                                      shape=target_shape,
                                      method='nearest')

        prior_mean, prior_log_std, features = self._compute_prior_and_features(
            inputs, context_vectors)
        if encoder_outputs is not None:
            posterior_mean, posterior_log_std = self._compute_posterior(
                inputs, encoder_outputs, context_vectors)
        else:
            posterior_mean = prior_mean
            posterior_log_std = prior_log_std + jnp.log(temperature)

        posterior_distribution = distrax.Independent(
            distrax.Normal(posterior_mean, jnp.exp(posterior_log_std)),
            reinterpreted_batch_ndims=3)
        prior_distribution = distrax.Independent(
            distrax.Normal(prior_mean, jnp.exp(prior_log_std)),
            reinterpreted_batch_ndims=3)
        latent = posterior_distribution.sample(seed=sample_rng)
        kl = posterior_distribution.kl_divergence(prior_distribution)

        outputs = self._compute_outputs(inputs, features, latent)
        return DecoderBlockOutput(outputs=outputs, kl=kl)
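The KL term above is analytic because both distributions are diagonal Gaussians wrapped in distrax.Independent. A standalone sketch of that pattern with hypothetical [B, H, W, C] shapes, in case it is useful outside the block:

import distrax
import jax
import jax.numpy as jnp

B, H, W, C = 2, 8, 8, 16  # hypothetical batch size and latent-map shape
post_mean = jax.random.normal(jax.random.PRNGKey(0), (B, H, W, C))
post_log_std = jnp.zeros((B, H, W, C))
prior_mean = jnp.zeros((B, H, W, C))
prior_log_std = jnp.zeros((B, H, W, C))

posterior = distrax.Independent(
    distrax.Normal(post_mean, jnp.exp(post_log_std)),
    reinterpreted_batch_ndims=3)
prior = distrax.Independent(
    distrax.Normal(prior_mean, jnp.exp(prior_log_std)),
    reinterpreted_batch_ndims=3)

kl = posterior.kl_divergence(prior)  # shape [B]: one KL value per example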