Example no. 1
    def corrupt(self, x, stage):
        """Corrupts x to a lower depth version for each stage."""
        # Note: we take the _generative process_ order, which is the reverse of
        # the diffusion process. However, the transition matrices are defined in
        # the order of the diffusion process, so here the future is stage 's'
        # and the past is stage 's+1'.
        stage_reverse = self.num_stages - 1 - stage
        if self.direct_parametrization:
            t = jnp.expand_dims(stage_reverse, np.arange(1, len(x.shape)))
            # The branch_factor determines at which speed the integers are reduced.
            # For example: suppose we have a 16-bit problem and we want to
            # model it in two stages, we could pick a branch factor of 256.
            # Then, 65535 // 256 = 255 for the first corruption and
            # 255 // 256 = 0 for the second corruption.
            x_past = (x // self.branch_factor**
                      (t + 1)) * self.branch_factor**(t + 1)
            x_future = (x // self.branch_factor**t) * self.branch_factor**t

        else:
            transition_future = self.cum_matmul_transition_matrices[
                stage_reverse]
            transition_past = self.cum_matmul_transition_matrices[stage_reverse
                                                                  + 1]

            # Here we compute the past and future x representations.
            x_onehot = util_fns.onehot(x, self.num_input_classes)
            x_past_onehot = self.transform_with_matrix(x_onehot,
                                                       transition_past)
            x_past = jnp.argmax(x_past_onehot, axis=-1)
            x_future_onehot = self.transform_with_matrix(
                x_onehot, transition_future)
            x_future = jnp.argmax(x_future_onehot, axis=-1)

        return x_future, x_past
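
A minimal standalone sketch of the direct-parametrization corruption above: it is plain integer truncation by powers of the branch factor. The helper name `truncate_to_stage` is made up here for illustration and is not part of the class.

import jax.numpy as jnp

def truncate_to_stage(x, t, branch_factor=256):
    # Drop the t least-significant base-`branch_factor` digits of x.
    return (x // branch_factor**t) * branch_factor**t

x = jnp.array([65535, 4097])
print(truncate_to_stage(x, 1))  # [65280  4096] -- last base-256 digit removed
print(truncate_to_stage(x, 2))  # [0 0]         -- both digits removed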
Example no. 2
    def log_prob_for_x_future_given_past(self, x_future, x_past, net_out,
                                         stage):
        """Computes the log probability of x_future given x_past and net_out.

        Args:
          x_future: The variable to compute the log probability for.
          x_past: The variable to condition on.
          net_out: Network output containing the logits.
          stage: The stage of the specific datapoint.

        Returns:
          The elementwise log probabilities of x_future.
        """
        batch_size = x_future.shape[0]
        if self.direct_parametrization:
            logits_per_stage = net_out.reshape(*x_future.shape,
                                               self.num_stages,
                                               self.branch_factor)

            # Retrieve the logits for this specific stage.
            logits = logits_per_stage[jnp.arange(batch_size), Ellipsis,
                                      stage, :]
            log_probs = jax.nn.log_softmax(logits, axis=-1)

            t = self.num_stages - 1 - stage
            t = jnp.expand_dims(t, np.arange(1, len(x_future.shape)))
            x_target = (x_future // self.branch_factor**t) % self.branch_factor

            x_target_onehot = util_fns.onehot(x_target, self.branch_factor)

            log_prob_future_given_past = jnp.sum(log_probs * x_target_onehot,
                                                 axis=-1)

        else:
            logits = net_out.reshape(*x_future.shape, self.num_output_classes)
            log_prob_future_given_past = self.log_prob_vector_future_given_past(
                logits, x_past, stage)
            x_future_onehot = util_fns.onehot(x_future, self.num_input_classes)
            log_prob_future_given_past = jnp.sum(log_prob_future_given_past *
                                                 x_future_onehot,
                                                 axis=-1)

        return log_prob_future_given_past
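
A standalone numeric check of the target construction used in the direct-parametrization branch above: at reversed stage t the model predicts the base-`branch_factor` digit of `x_future` at position t. The helper name `stage_target` is hypothetical.

import jax.numpy as jnp

def stage_target(x_future, t, branch_factor=256):
    # Select the base-`branch_factor` digit at position t.
    return (x_future // branch_factor**t) % branch_factor

x_future = jnp.array([65535, 4097])
print(stage_target(x_future, 0))  # [255   1] -- least-significant digit
print(stage_target(x_future, 1))  # [255  16] -- next digit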
Example no. 3
def sample_categorical_and_log_prob(key, logits):
    # If the logits are already normalized, this operation does nothing, as desired.
    log_p = jax.nn.log_softmax(logits, axis=-1)

    category = sample_categorical(key, logits)

    category_onehot = util_fns.onehot(category, num_classes=logits.shape[-1])

    log_pcategory = (log_p * category_onehot).sum(axis=-1)
    return category, log_pcategory
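
Usage sketch, under the assumption that `sample_categorical` wraps jax.random.categorical; this self-contained equivalent shows the shapes involved (function and variable names here are hypothetical).

import jax
import jax.numpy as jnp

def sample_and_log_prob(key, logits):
    log_p = jax.nn.log_softmax(logits, axis=-1)
    category = jax.random.categorical(key, logits, axis=-1)
    # Gather the log-probability of the sampled class.
    log_p_category = jnp.take_along_axis(
        log_p, category[..., None], axis=-1).squeeze(-1)
    return category, log_p_category

key = jax.random.PRNGKey(0)
logits = jnp.zeros((4, 10))  # batch of 4, 10 classes
category, log_p_category = sample_and_log_prob(key, logits)
print(category.shape, log_p_category.shape)  # (4,) (4,)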
Example no. 4
    def prepare_additional_input(self, stage, already_predicted):
        new_axes = tuple(range(1, len(already_predicted.shape) - 1))
        stage = jnp.expand_dims(stage, axis=new_axes)
        stage_onehot = util_fns.onehot(stage, num_classes=self.num_stages)
        stage_onehot = jnp.broadcast_to(stage_onehot,
                                        shape=already_predicted.shape[:-1] +
                                        (self.num_stages, ))

        add_info = jnp.concatenate([already_predicted, stage_onehot], axis=-1)

        return add_info
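
Standalone shape sketch (all values hypothetical) for an image-shaped `already_predicted` mask of shape (B, H, W, C): the per-example stage index is broadcast to a spatial one-hot map and concatenated along the channel axis, as in the method above.

import jax
import jax.numpy as jnp

batch, height, width, channels, num_stages = 2, 4, 4, 1, 3
already_predicted = jnp.zeros((batch, height, width, channels))
stage = jnp.array([0, 2])                             # one stage per example

stage = jnp.expand_dims(stage, axis=(1, 2))           # (B, 1, 1)
stage_onehot = jax.nn.one_hot(stage, num_stages)      # (B, 1, 1, S)
stage_onehot = jnp.broadcast_to(
    stage_onehot, (batch, height, width, num_stages))
add_info = jnp.concatenate([already_predicted, stage_onehot], axis=-1)
print(add_info.shape)  # (2, 4, 4, 4): channels + num_stages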
Example no. 5
def sample_from_discretized_mix_logistic_rgb(rng, params, nr_mix):
    """Sample from discretized mix logistic distribution."""
    xshape = params.shape[:-1] + (3, )
    batchsize, height, width, _ = xshape

    # unpack parameters
    pi_logits = params[:, :, :, :nr_mix]
    remaining_params = params[:, :, :, nr_mix:].reshape(*xshape, nr_mix * 3)

    # sample mixture indicator from softmax
    rng1, rng2 = jax.random.split(rng)
    mixture_idcs = sample_categorical(rng1, pi_logits)

    onehot_values = util_fns.onehot(mixture_idcs, nr_mix)

    assert onehot_values.shape == (batchsize, height, width, nr_mix)

    selection = onehot_values.reshape(xshape[:-1] + (1, nr_mix))

    # select logistic parameters
    means = jnp.sum(remaining_params[:, :, :, :, :nr_mix] * selection, axis=4)
    pre_act_scales = jnp.sum(remaining_params[:, :, :, :, nr_mix:2 * nr_mix] *
                             selection,
                             axis=4)

    coeffs = jnp.sum(
        jax.nn.tanh(remaining_params[:, :, :, :, 2 * nr_mix:3 * nr_mix]) *
        selection,
        axis=4)

    u = jax.random.uniform(rng2, means.shape, minval=1e-5, maxval=1. - 1e-5)

    standard_logistic = jnp.log(u) - jnp.log(1. - u)
    scale = 1. / jax.nn.softplus(pre_act_scales)
    x = means + scale * standard_logistic

    x0 = jnp.clip(x[:, :, :, 0], a_min=-1., a_max=1.)
    # TODO(emielh) although this is typically how it is implemented, technically
    # one should first round x0 to the grid before using it. It does not matter
    # too much since it is only used linearly.
    x1 = jnp.clip(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, a_min=-1., a_max=1.)
    x2 = jnp.clip(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 +
                  coeffs[:, :, :, 2] * x1,
                  a_min=-1.,
                  a_max=1.)

    sample_x = jnp.concatenate(
        [x0[:, :, :, None], x1[:, :, :, None], x2[:, :, :, None]], axis=3)

    return sample_x
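
The sampling step above is inverse-CDF sampling of a logistic distribution: if u ~ Uniform(0, 1), then mean + scale * (log u - log(1 - u)) is logistic with that mean and scale. A minimal standalone sketch (helper name is made up here):

import jax
import jax.numpy as jnp

def sample_logistic(rng, mean, scale, shape):
    # Clip u away from 0 and 1 to avoid infinities in the logit transform.
    u = jax.random.uniform(rng, shape, minval=1e-5, maxval=1. - 1e-5)
    return mean + scale * (jnp.log(u) - jnp.log(1. - u))

rng = jax.random.PRNGKey(0)
samples = sample_logistic(rng, mean=0.0, scale=0.5, shape=(4, 3))
print(samples.shape)  # (4, 3)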
Example no. 6
  def log_prob(self, x, dist_params):
    """Computes log prob."""
    assert x.dtype == jnp.int32
    assert x.shape[:-1] == dist_params.shape[:-1], (
        f' for {x.shape} and {dist_params.shape}')
    assert dist_params.shape[-1] == self.n_classes * self.n_channels

    logits = dist_params.reshape(*x.shape, self.n_classes)

    x_onehot = util_fns.onehot(x, num_classes=self.n_classes)

    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_probs = jnp.sum(x_onehot * log_probs, axis=-1)

    return log_probs
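
The one-hot multiply-and-sum above is just a gather over the class axis; a standalone equivalence check against jnp.take_along_axis (shapes and values here are hypothetical):

import jax
import jax.numpy as jnp

x = jnp.array([[0, 2], [1, 1]])                      # (batch, n_channels)
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 2, 3))
log_probs_all = jax.nn.log_softmax(logits, axis=-1)

via_onehot = jnp.sum(jax.nn.one_hot(x, 3) * log_probs_all, axis=-1)
via_gather = jnp.take_along_axis(
    log_probs_all, x[..., None], axis=-1).squeeze(-1)
print(jnp.allclose(via_onehot, via_gather))  # True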
Example no. 7
    def log_prob_vector_future_given_past(self, logits, x_past, stage):
        probs = jax.nn.softmax(logits, axis=-1)
        if self.num_input_classes > self.num_output_classes:
            # We have to pad probs with empty probabilities for augmented classes.
            pad_size = self.num_input_classes - self.num_output_classes

            # Only padding at the end of the last axis.
            padding = ((0, 0), ) * (len(logits.shape) - 1) + ((0, pad_size), )

            # Padding with zeros is correct here because `probs` holds
            # probabilities (not log-probabilities): the augmented classes get
            # probability zero.
            probs = jnp.pad(probs,
                            pad_width=padding,
                            mode='constant',
                            constant_values=0)

        stage_reversed = self.num_stages - 1 - stage
        transition_future = self.cum_matmul_transition_matrices[stage_reversed]

        transition_single_step = self.transition_matrices[stage_reversed]

        probs_future = self.transform_with_matrix(probs, transition_future)

        x_past_onehot = util_fns.onehot(x_past, self.num_input_classes)

        # We only select possible futures:
        num_data_axes = len(x_past.shape) - 1  # minus the batch axis.
        new_axes = tuple(range(1, num_data_axes + 1))
        possible_futures = jnp.sum(
            jnp.expand_dims(transition_single_step, axis=new_axes) *
            x_past_onehot[Ellipsis, None],
            axis=-2)

        # Collect logits given past, using only possible futures.
        infty = 1e20
        logits_given_past = jnp.log(probs_future * possible_futures +
                                    1e-10) - (1 - possible_futures) * infty

        # Normalizes for stability. Corresponds to subtracting
        # jnp.log(prob_past[..., None]), but this is numerically more stable.
        log_probs = logits_given_past - jax.nn.logsumexp(
            logits_given_past, axis=-1, keepdims=True)

        return log_probs
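
Standalone sketch of the masking-and-renormalization trick used above: impossible futures are driven to a very negative logit, and the remaining probability mass is renormalized with logsumexp (the values here are hypothetical).

import jax
import jax.numpy as jnp

probs_future = jnp.array([[0.2, 0.3, 0.5]])
possible_futures = jnp.array([[1., 1., 0.]])  # mask: last class impossible

infty = 1e20
logits_given_past = (jnp.log(probs_future * possible_futures + 1e-10)
                     - (1 - possible_futures) * infty)
log_probs = logits_given_past - jax.nn.logsumexp(
    logits_given_past, axis=-1, keepdims=True)
print(jnp.exp(log_probs))  # approximately [[0.4 0.6 0. ]]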
Example no. 8
def sample_from_discretized_mix_logistic(rng, params, nr_mix):
    """Sample from discretized mix logistic distribution, channel-independent."""
    channels = (params.shape[-1] // nr_mix - 1) // 2
    xshape = params.shape[:-1] + (channels, )
    batchsize = xshape[0]
    spatial_dims = xshape[1:-1]

    # Unpack parameters.
    pi_logits = params[Ellipsis, :nr_mix]
    remaining_params = params[Ellipsis, nr_mix:].reshape(*xshape, nr_mix * 2)

    means = remaining_params[Ellipsis, :nr_mix]
    means = means + jnp.reshape(
        jnp.linspace(-1. + 1 / nr_mix, 1 - 1 / nr_mix, nr_mix),
        (1, ) * len(xshape) + (nr_mix, ))

    log_scales = remaining_params[Ellipsis, nr_mix:2 * nr_mix]
    log_scales = log_scales - jnp.log(nr_mix)

    # Sample mixture indicator from softmax.
    rng1, rng2 = jax.random.split(rng)
    mixture_idcs = sample_categorical(rng1, pi_logits)

    onehot_values = util_fns.onehot(mixture_idcs, nr_mix)

    assert onehot_values.shape == (batchsize, ) + spatial_dims + (nr_mix, )

    selection = onehot_values.reshape(xshape[:-1] + (1, nr_mix))

    # Select logistic parameters.
    means = jnp.sum(means * selection, axis=-1)
    log_scales = jnp.sum(log_scales * selection, axis=-1)
    log_scales = jnp.clip(log_scales, a_min=-7.)

    u = jax.random.uniform(rng2, means.shape, minval=1e-5, maxval=1. - 1e-5)

    standard_logistic = jnp.log(u) - jnp.log(1. - u)

    sample_x = means + jnp.exp(log_scales) * standard_logistic
    sample_x = jnp.clip(sample_x, a_min=-1., a_max=1.)

    return sample_x
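
Standalone shape sketch of the mixture-component selection shared by both sampling functions above: the sampled indicator (one per pixel) is turned into a one-hot over the mixture axis, broadcast over channels, and summed out (the shapes here are hypothetical).

import jax
import jax.numpy as jnp

nr_mix, channels = 5, 3
rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
means = jax.random.normal(rng1, (2, 8, 8, channels, nr_mix))
mixture_idcs = jax.random.randint(rng2, (2, 8, 8), 0, nr_mix)

# One mixture component per pixel, shared across channels via broadcasting.
selection = jax.nn.one_hot(mixture_idcs, nr_mix)[:, :, :, None, :]
selected_means = jnp.sum(means * selection, axis=-1)
print(selected_means.shape)  # (2, 8, 8, 3)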