def corrupt(self, x, stage):
  """Corrupts x to a lower depth version for each stage."""
  # Note: we take the _generative process_ order. This is the reverse of the
  # diffusion process. However, the transition matrices are defined in the
  # order of the diffusion process. So here the future is 's' and the past
  # is 's + 1'.
  stage_reverse = self.num_stages - 1 - stage

  if self.direct_parametrization:
    t = jnp.expand_dims(stage_reverse, np.arange(1, len(x.shape)))
    # The branch_factor determines the speed at which the integers are
    # reduced. For example: suppose we have a 16-bit problem and we want to
    # model it in two stages, then we could pick a branch factor of 256.
    # Then, 65535 // 256 = 255 for the first corruption and 255 // 256 = 0
    # for the second corruption.
    x_past = (x // self.branch_factor**(t + 1)) * self.branch_factor**(t + 1)
    x_future = (x // self.branch_factor**t) * self.branch_factor**t
  else:
    transition_future = self.cum_matmul_transition_matrices[stage_reverse]
    transition_past = self.cum_matmul_transition_matrices[stage_reverse + 1]

    # Compute the past and future x representations.
    x_onehot = util_fns.onehot(x, self.num_input_classes)
    x_past_onehot = self.transform_with_matrix(x_onehot, transition_past)
    x_past = jnp.argmax(x_past_onehot, axis=-1)

    x_future_onehot = self.transform_with_matrix(x_onehot, transition_future)
    x_future = jnp.argmax(x_future_onehot, axis=-1)

  return x_future, x_past
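# A minimal standalone sketch of the direct-parametrization branch above,
# using the hypothetical 16-bit example from the comment. The name
# `_demo_corrupt_direct` is illustrative and not part of the model API; the
# module-level jnp import is assumed.
def _demo_corrupt_direct():
  branch_factor = 256
  num_stages = 2
  x = jnp.array([65535])  # Maximal 16-bit value, modeled in two stages.
  pairs = []
  for stage in range(num_stages):
    t = num_stages - 1 - stage
    x_past = (x // branch_factor**(t + 1)) * branch_factor**(t + 1)
    x_future = (x // branch_factor**t) * branch_factor**t
    pairs.append((x_future, x_past))
  # stage 0 (t=1): x_future = [65280], x_past = [0]     (top byte revealed).
  # stage 1 (t=0): x_future = [65535], x_past = [65280] (low byte revealed).
  return pairs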
def log_prob_for_x_future_given_past(self, x_future, x_past, net_out, stage):
  """Computes the log probability of x_future given x_past and net_out.

  Args:
    x_future: The variable to compute the log probability for.
    x_past: The variable to condition on.
    net_out: Network output containing the logits.
    stage: The stage of the specific datapoint.

  Returns:
    The elementwise log probabilities of x_future.
  """
  batch_size = x_future.shape[0]

  if self.direct_parametrization:
    logits_per_stage = net_out.reshape(*x_future.shape, self.num_stages,
                                       self.branch_factor)
    # Retrieve the logits for this specific stage.
    logits = logits_per_stage[jnp.arange(batch_size), Ellipsis, stage, :]
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    t = self.num_stages - 1 - stage
    t = jnp.expand_dims(t, np.arange(1, len(x_future.shape)))
    # Extract the base-branch_factor digit that this stage predicts.
    x_target = (x_future // self.branch_factor**t) % self.branch_factor
    x_target_onehot = util_fns.onehot(x_target, self.branch_factor)
    log_prob_future_given_past = jnp.sum(log_probs * x_target_onehot, axis=-1)
  else:
    logits = net_out.reshape(*x_future.shape, self.num_output_classes)
    log_prob_future_given_past = self.log_prob_vector_future_given_past(
        logits, x_past, stage)
    x_future_onehot = util_fns.onehot(x_future, self.num_input_classes)
    log_prob_future_given_past = jnp.sum(
        log_prob_future_given_past * x_future_onehot, axis=-1)

  return log_prob_future_given_past
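# The direct parametrization scores one base-`branch_factor` digit per
# stage. A small illustrative sketch of the target extraction above, with
# hypothetical values (the helper name is not part of the codebase):
def _demo_stage_target():
  branch_factor = 256
  x_future = jnp.array([4660])  # 0x1234.
  high_digit = (x_future // branch_factor**1) % branch_factor  # [18] (0x12).
  low_digit = (x_future // branch_factor**0) % branch_factor   # [52] (0x34).
  return high_digit, low_digit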
def sample_categorical_and_log_prob(key, logits):
  """Samples a category and returns its elementwise log probability."""
  # If the logits are already normalized, log_softmax leaves them unchanged,
  # which is the desired behavior.
  log_p = jax.nn.log_softmax(logits, axis=-1)
  category = sample_categorical(key, logits)
  category_onehot = util_fns.onehot(category, num_classes=logits.shape[-1])
  log_pcategory = (log_p * category_onehot).sum(axis=-1)
  return category, log_pcategory
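# `sample_categorical` is defined elsewhere in this codebase. A minimal
# sketch consistent with how it is used here (unnormalized logits are
# accepted) could simply wrap jax.random.categorical; this is an
# assumption, not the actual implementation:
def _sample_categorical_sketch(key, logits):
  return jax.random.categorical(key, logits, axis=-1)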
def prepare_additional_input(self, stage, already_predicted):
  """Concatenates a broadcast one-hot stage encoding to already_predicted."""
  new_axes = tuple(range(1, len(already_predicted.shape) - 1))
  stage = jnp.expand_dims(stage, axis=new_axes)
  stage_onehot = util_fns.onehot(stage, num_classes=self.num_stages)
  stage_onehot = jnp.broadcast_to(
      stage_onehot,
      shape=already_predicted.shape[:-1] + (self.num_stages,))
  add_info = jnp.concatenate([already_predicted, stage_onehot], axis=-1)
  return add_info
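# A standalone shape sketch of the broadcasting above, with hypothetical
# sizes (a batch of two 8x8 maps, four stages) and jax.nn.one_hot standing
# in for util_fns.onehot:
def _demo_prepare_additional_input():
  num_stages = 4
  stage = jnp.array([0, 3])                     # (batch,)
  already_predicted = jnp.zeros((2, 8, 8, 1))
  s = jnp.expand_dims(stage, axis=(1, 2))       # (2, 1, 1)
  stage_onehot = jax.nn.one_hot(s, num_stages)  # (2, 1, 1, 4)
  stage_onehot = jnp.broadcast_to(
      stage_onehot, already_predicted.shape[:-1] + (num_stages,))
  add_info = jnp.concatenate([already_predicted, stage_onehot], axis=-1)
  assert add_info.shape == (2, 8, 8, 1 + num_stages)
  return add_info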
def sample_from_discretized_mix_logistic_rgb(rng, params, nr_mix):
  """Sample from discretized mix logistic distribution."""
  xshape = params.shape[:-1] + (3,)
  batchsize, height, width, _ = xshape

  # Unpack parameters.
  pi_logits = params[:, :, :, :nr_mix]
  remaining_params = params[:, :, :, nr_mix:].reshape(*xshape, nr_mix * 3)

  # Sample the mixture indicator from the softmax.
  rng1, rng2 = jax.random.split(rng)
  mixture_idcs = sample_categorical(rng1, pi_logits)

  onehot_values = util_fns.onehot(mixture_idcs, nr_mix)
  assert onehot_values.shape == (batchsize, height, width, nr_mix)
  selection = onehot_values.reshape(xshape[:-1] + (1, nr_mix))

  # Select logistic parameters.
  means = jnp.sum(remaining_params[:, :, :, :, :nr_mix] * selection, axis=4)
  pre_act_scales = jnp.sum(
      remaining_params[:, :, :, :, nr_mix:2 * nr_mix] * selection, axis=4)
  coeffs = jnp.sum(
      jax.nn.tanh(remaining_params[:, :, :, :, 2 * nr_mix:3 * nr_mix]) *
      selection, axis=4)

  # Sample from the logistic via the inverse CDF: for u ~ Uniform(0, 1),
  # log(u) - log(1 - u) is standard logistic.
  u = jax.random.uniform(rng2, means.shape, minval=1e-5, maxval=1. - 1e-5)
  standard_logistic = jnp.log(u) - jnp.log(1. - u)
  scale = 1. / jax.nn.softplus(pre_act_scales)
  x = means + scale * standard_logistic

  x0 = jnp.clip(x[:, :, :, 0], a_min=-1., a_max=1.)
  # TODO(emielh): Although this is typically how it is implemented,
  # technically one should first round x0 to the grid before using it. It
  # does not matter too much since it is only used linearly.
  x1 = jnp.clip(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, a_min=-1., a_max=1.)
  x2 = jnp.clip(
      x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1,
      a_min=-1., a_max=1.)

  sample_x = jnp.concatenate(
      [x0[:, :, :, None], x1[:, :, :, None], x2[:, :, :, None]], axis=3)
  return sample_x
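# A hypothetical usage sketch for the RGB sampler: per pixel, params packs
# nr_mix mixture logits plus 3 * nr_mix parameters for each of the three
# sub-tensors (means, pre-activation scales, coefficients), i.e.
# 10 * nr_mix channels in total. Random values; the shapes are the point.
def _demo_sample_rgb():
  nr_mix = 5
  rng_params, rng_sample = jax.random.split(jax.random.PRNGKey(0))
  params = jax.random.normal(rng_params, (2, 8, 8, 10 * nr_mix))
  sample = sample_from_discretized_mix_logistic_rgb(
      rng_sample, params, nr_mix)
  assert sample.shape == (2, 8, 8, 3)
  return sample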
def log_prob(self, x, dist_params):
  """Computes the elementwise log probability of x."""
  assert x.dtype == jnp.int32
  assert x.shape[:-1] == dist_params.shape[:-1], (
      f'Shape mismatch for {x.shape} and {dist_params.shape}')
  assert dist_params.shape[-1] == self.n_classes * self.n_channels
  logits = dist_params.reshape(*x.shape, self.n_classes)

  x_onehot = util_fns.onehot(x, num_classes=self.n_classes)
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  log_probs = jnp.sum(x_onehot * log_probs, axis=-1)
  return log_probs
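# A hypothetical shape sketch for log_prob: with n_channels = 3 and
# n_classes = 256, dist_params stacks 256 logits per channel, so
#
#   x:           (batch, H, W, 3), int32
#   dist_params: (batch, H, W, 3 * 256)
#   log_probs:   (batch, H, W, 3), one log probability per channel.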
def log_prob_vector_future_given_past(self, logits, x_past, stage):
  """Computes log probability vectors over futures, conditioned on the past."""
  probs = jax.nn.softmax(logits, axis=-1)

  if self.num_input_classes > self.num_output_classes:
    # We have to pad probs with empty probabilities for the augmented
    # classes. Zero-padding is correct because the vector is in probability
    # space: the augmented classes simply get probability zero.
    pad_size = self.num_input_classes - self.num_output_classes
    # Only pad at the end of the last axis.
    padding = ((0, 0),) * (len(logits.shape) - 1) + ((0, pad_size),)
    probs = jnp.pad(
        probs, pad_width=padding, mode='constant', constant_values=0)

  stage_reversed = self.num_stages - 1 - stage
  transition_future = self.cum_matmul_transition_matrices[stage_reversed]
  transition_single_step = self.transition_matrices[stage_reversed]

  probs_future = self.transform_with_matrix(probs, transition_future)
  x_past_onehot = util_fns.onehot(x_past, self.num_input_classes)

  # We only select possible futures:
  num_data_axes = len(x_past.shape) - 1  # Minus the batch axis.
  new_axes = tuple(range(1, num_data_axes + 1))
  possible_futures = jnp.sum(
      jnp.expand_dims(transition_single_step, axis=new_axes) *
      x_past_onehot[Ellipsis, None],
      axis=-2)

  # Collect logits given the past, masking out impossible futures with a
  # large negative value.
  infty = 1e20
  logits_given_past = jnp.log(
      probs_future * possible_futures + 1e-10) - (1 - possible_futures) * infty

  # Normalize for stability. This corresponds to subtracting
  # jnp.log(prob_past[..., None]), but is numerically more stable.
  log_probs = logits_given_past - jax.nn.logsumexp(
      logits_given_past, axis=-1, keepdims=True)
  return log_probs
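# The masking trick above in isolation: impossible classes get a large
# negative logit and the logsumexp normalization renormalizes over the
# remaining ones. A minimal sketch with hypothetical numbers:
def _demo_masked_renormalization():
  probs = jnp.array([0.2, 0.3, 0.5])
  possible = jnp.array([1., 1., 0.])  # The last class is impossible.
  infty = 1e20
  logits = jnp.log(probs * possible + 1e-10) - (1 - possible) * infty
  log_p = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)
  # jnp.exp(log_p) is approximately [0.4, 0.6, 0.0]: the mass is
  # renormalized over the possible classes only.
  return log_p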
def sample_from_discretized_mix_logistic(rng, params, nr_mix):
  """Sample from discretized mix logistic distribution, channel-independent."""
  channels = (params.shape[-1] // nr_mix - 1) // 2
  xshape = params.shape[:-1] + (channels,)
  batchsize = xshape[0]
  spatial_dims = xshape[1:-1]

  # Unpack parameters.
  pi_logits = params[Ellipsis, :nr_mix]
  remaining_params = params[Ellipsis, nr_mix:].reshape(*xshape, nr_mix * 2)

  # Offset the means so that, at zero network output, the mixture components
  # are spread evenly over [-1, 1].
  means = remaining_params[Ellipsis, :nr_mix]
  means = means + jnp.reshape(
      jnp.linspace(-1. + 1 / nr_mix, 1 - 1 / nr_mix, nr_mix),
      (1,) * len(xshape) + (nr_mix,))
  log_scales = remaining_params[Ellipsis, nr_mix:2 * nr_mix]
  log_scales = log_scales - jnp.log(nr_mix)

  # Sample the mixture indicator from the softmax.
  rng1, rng2 = jax.random.split(rng)
  mixture_idcs = sample_categorical(rng1, pi_logits)

  onehot_values = util_fns.onehot(mixture_idcs, nr_mix)
  assert onehot_values.shape == (batchsize,) + spatial_dims + (nr_mix,)
  selection = onehot_values.reshape(xshape[:-1] + (1, nr_mix))

  # Select logistic parameters.
  means = jnp.sum(means * selection, axis=-1)
  log_scales = jnp.sum(log_scales * selection, axis=-1)
  log_scales = jnp.clip(log_scales, a_min=-7.)

  # Sample from the logistic via the inverse CDF.
  u = jax.random.uniform(rng2, means.shape, minval=1e-5, maxval=1. - 1e-5)
  standard_logistic = jnp.log(u) - jnp.log(1. - u)

  sample_x = means + jnp.exp(log_scales) * standard_logistic
  sample_x = jnp.clip(sample_x, a_min=-1., a_max=1.)
  return sample_x
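# A hypothetical usage sketch for the channel-independent sampler: with
# `channels` image channels, params needs nr_mix * (1 + 2 * channels)
# output channels per location (mixture logits, means, log scales).
# Random values; the shapes are the point.
def _demo_sample_channel_independent():
  nr_mix, channels = 5, 1
  rng_params, rng_sample = jax.random.split(jax.random.PRNGKey(0))
  params = jax.random.normal(
      rng_params, (2, 8, 8, nr_mix * (1 + 2 * channels)))
  sample = sample_from_discretized_mix_logistic(rng_sample, params, nr_mix)
  assert sample.shape == (2, 8, 8, channels)
  return sample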