def loop_body(step, val):
  """One annealed-Langevin iteration: ascend the score, then inject noise.

  Runs inside a lax-style loop; `val` carries (rng, x, x_mean).
  `score_fn`, `t`, `target_snr`, `std`, and `alpha` come from the
  enclosing scope.
  """
  rng, x, x_mean = val
  g = score_fn(x, t)
  rng, step_rng = jax.random.split(rng)
  z = jax.random.normal(step_rng, x.shape)
  # Step size fixed per noise level via the target SNR and marginal std.
  eps = (target_snr * std) ** 2 * 2 * alpha
  x_mean = x + batch_mul(eps, g)
  x = x_mean + batch_mul(z, jnp.sqrt(eps * 2))
  return rng, x, x_mean
def vpsde_update_fn(self, rng, x, t):
  """Ancestral sampling step for a discretized VP SDE (DDPM-style update).

  Returns the noisy next sample `x` and the noise-free mean `x_mean`.
  """
  sde = self.sde
  # Map continuous time t in [0, T] onto a discrete index in [0, N-1].
  timestep = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
  beta = sde.discrete_betas[timestep]
  score = self.score_fn(x, t)
  # Posterior mean: (x + beta * score) / sqrt(1 - beta).
  x_mean = batch_mul(x + batch_mul(beta, score), 1. / jnp.sqrt(1. - beta))
  z = random.normal(rng, x.shape)
  x = x_mean + batch_mul(jnp.sqrt(beta), z)
  return x, x_mean
def vesde_update_fn(self, rng, x, t):
  """Ancestral sampling step for a discretized VE SDE (SMLD-style update).

  Returns the noisy next sample `x` and the noise-free mean `x_mean`.
  """
  sde = self.sde
  timestep = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
  sigma = sde.discrete_sigmas[timestep]
  # sigma at the previous timestep; defined to be 0 at the first step.
  prev_sigma = jnp.where(timestep == 0, jnp.zeros(t.shape),
                         sde.discrete_sigmas[timestep - 1])
  score = self.score_fn(x, t)
  # Variance added between the adjacent noise levels.
  var_gap = sigma**2 - prev_sigma**2
  x_mean = x + batch_mul(score, var_gap)
  std = jnp.sqrt((prev_sigma**2 * var_gap) / (sigma**2))
  z = random.normal(rng, x.shape)
  x = x_mean + batch_mul(std, z)
  return x, x_mean
def loop_body(step, val):
  """Langevin corrector step with the step size set by measured norms.

  Runs inside a lax-style loop; `val` carries (rng, x, x_mean).
  `score_fn`, `t`, `target_snr`, and `alpha` come from the enclosing
  scope; norms are averaged across devices via pmean on axis 'batch'.
  """
  rng, x, x_mean = val
  g = score_fn(x, t)
  rng, step_rng = jax.random.split(rng)
  z = jax.random.normal(step_rng, x.shape)
  # Per-sample L2 norms, averaged over the local batch then over devices.
  g_norm = jnp.linalg.norm(g.reshape((g.shape[0], -1)), axis=-1).mean()
  g_norm = jax.lax.pmean(g_norm, axis_name='batch')
  z_norm = jnp.linalg.norm(z.reshape((z.shape[0], -1)), axis=-1).mean()
  z_norm = jax.lax.pmean(z_norm, axis_name='batch')
  # Step size chosen so the update's SNR matches target_snr.
  eps = (target_snr * z_norm / g_norm) ** 2 * 2 * alpha
  x_mean = x + batch_mul(eps, g)
  x = x_mean + batch_mul(z, jnp.sqrt(eps * 2))
  return rng, x, x_mean
def sde(self, x, t):
  """Drift and diffusion coefficients of the forward SDE at time t.

  NOTE(review): the `discount` factor matches a sub-VP-style diffusion —
  confirm the enclosing class context.
  """
  beta_span = self.beta_1 - self.beta_0
  beta_t = self.beta_0 + t * beta_span
  drift = -0.5 * batch_mul(beta_t, x)
  # Discount shrinks the diffusion relative to the plain VP SDE.
  discount = 1. - jnp.exp(-2 * self.beta_0 * t - beta_span * t**2)
  diffusion = jnp.sqrt(beta_t * discount)
  return drift, diffusion
def update_fn(self, rng, x, t):
  """One Euler–Maruyama step of the reverse-time SDE (dt is negative).

  Returns the noisy next sample `x` and the noise-free mean `x_mean`.
  """
  dt = -1. / self.rsde.N
  drift, diffusion = self.rsde.sde(x, t)
  noise = random.normal(rng, x.shape)
  x_mean = x + drift * dt
  x = x_mean + batch_mul(diffusion, jnp.sqrt(-dt) * noise)
  return x, x_mean
def loss_fn(rng, params, states, batch):
  """Denoising score-matching loss with discrete SMLD noise levels.

  Samples one noise level per example, perturbs the data, and regresses
  the model output onto -noise / sigma^2, weighting each example by
  sigma^2. Returns (scalar loss, new mutable model state).
  """
  model_fn = mutils.get_model_fn(model, params, states, train=train)
  data = batch['image']
  # Draw one noise-level index per example.
  rng, step_rng = random.split(rng)
  labels = random.choice(step_rng, vesde.N, shape=(data.shape[0], ))
  sigmas = smld_sigma_array[labels]
  rng, step_rng = random.split(rng)
  z = batch_mul(random.normal(step_rng, data.shape), sigmas)
  noisy = z + data
  rng, step_rng = random.split(rng)
  score, new_model_state = model_fn(noisy, labels, rng=step_rng)
  # DSM target for the perturbation kernel: -noise / sigma^2.
  target = -batch_mul(z, 1. / (sigmas**2))
  losses = jnp.square(score - target)
  losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) * sigmas**2
  return jnp.mean(losses), new_model_state
def discretize(self, x, t):
  """Create discretized iteration rules for the reverse diffusion sampler."""
  f, G = discretize_fn(x, t)
  # Probability-flow ODE keeps half the score term and has zero diffusion.
  scale = 0.5 if self.probability_flow else 1.
  rev_f = f - batch_mul(G**2, score_fn(x, t) * scale)
  rev_G = 0. if self.probability_flow else G
  return rev_f, rev_G
def loss_fn(rng, params, states, batch):
  """DDPM epsilon-prediction loss with discrete timesteps.

  Samples one timestep per example, perturbs the data with the closed-form
  forward kernel, and regresses the model output onto the injected noise.
  Returns (scalar loss, new mutable model state).
  """
  model_fn = mutils.get_model_fn(model, params, states, train=train)
  data = batch['image']
  rng, step_rng = random.split(rng)
  labels = random.choice(step_rng, vpsde.N, shape=(data.shape[0], ))
  sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod
  sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod
  rng, step_rng = random.split(rng)
  z = random.normal(step_rng, data.shape)
  # Closed-form marginal: sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * z.
  noisy = (batch_mul(sqrt_alphas_cumprod[labels], data) +
           batch_mul(sqrt_1m_alphas_cumprod[labels], z))
  rng, step_rng = random.split(rng)
  score, new_model_state = model_fn(noisy, labels, rng=step_rng)
  losses = jnp.square(score - z)
  losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
  return jnp.mean(losses), new_model_state
def loss_fn(rng, params, states, batch):
  """Compute the continuous-time score-matching loss.

  Args:
    rng: A JAX random state.
    params: A dictionary that contains trainable parameters of the
      score-based model.
    states: A dictionary that contains mutable states of the score-based
      model.
    batch: A mini-batch of training data.

  Returns:
    loss: A scalar that represents the average loss value across the
      mini-batch.
    new_model_state: A dictionary that contains the mutated states of the
      score-based model.
  """
  score_fn = mutils.get_score_fn(sde, model, params, states, train=train,
                                 continuous=continuous, return_state=True)
  data = batch['image']
  # Sample a continuous time per example, bounded away from 0 by eps.
  rng, step_rng = random.split(rng)
  t = random.uniform(step_rng, (data.shape[0], ), minval=eps, maxval=sde.T)
  rng, step_rng = random.split(rng)
  z = random.normal(step_rng, data.shape)
  mean, std = sde.marginal_prob(data, t)
  perturbed = mean + batch_mul(std, z)
  rng, step_rng = random.split(rng)
  score, new_model_state = score_fn(perturbed, t, rng=step_rng)
  if not likelihood_weighting:
    # Standard weighting: ||std * score + z||^2 per example.
    losses = jnp.square(batch_mul(score, std) + z)
    losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
  else:
    # Likelihood weighting: scale by the diffusion coefficient squared.
    g2 = sde.sde(jnp.zeros_like(data), t)[1]**2
    losses = jnp.square(score + batch_mul(z, 1. / std))
    losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) * g2
  return jnp.mean(losses), new_model_state
def discretize(self, x, t):
  """DDPM discretization: per-step drift f and diffusion G."""
  # Continuous time t in [0, T] -> discrete index in [0, N-1].
  timestep = (t * (self.N - 1) / self.T).astype(jnp.int32)
  beta = self.discrete_betas[timestep]
  alpha = self.alphas[timestep]
  f = batch_mul(jnp.sqrt(alpha), x) - x
  G = jnp.sqrt(beta)
  return f, G
def sde(self, x, t):
  """Create the drift and diffusion functions for the reverse SDE/ODE."""
  drift, diffusion = sde_fn(x, t)
  # Probability-flow ODE keeps half the score correction; the SDE keeps all.
  scale = 0.5 if self.probability_flow else 1.
  drift = drift - batch_mul(diffusion**2, score_fn(x, t) * scale)
  # ODE sampling is deterministic: zero diffusion.
  diffusion = 0. if self.probability_flow else diffusion
  return drift, diffusion
def inpaint_update_fn(rng, state, data, mask, x, t):
  """One sampler step, then re-impose known pixels from `data` under `mask`.

  The known region is diffused to the current noise level before being
  pasted back so it stays consistent with the sampler's noise scale.
  """
  rng, step_rng = jax.random.split(rng)
  vec_t = jnp.ones(data.shape[0]) * t
  x, x_mean = update_fn(step_rng, state, x, vec_t)
  masked_data_mean, std = sde.marginal_prob(data, vec_t)
  masked_data = masked_data_mean + batch_mul(jax.random.normal(rng, x.shape),
                                             std)
  keep = 1. - mask
  x = x * keep + masked_data * mask
  # Intentionally uses the already-masked x for the unknown region.
  x_mean = x * keep + masked_data_mean * mask
  return x, x_mean
def colorization_update_fn(rng, state, gray_scale_img, x, t):
  """One sampler step, then restore the known channel of a gray-scale image.

  Works in the decoupled channel space given by `decouple`/`couple`; the
  known channel is diffused to the current noise level before pasting.
  """
  mask = get_mask(x)
  rng, step_rng = jax.random.split(rng)
  vec_t = jnp.ones(x.shape[0]) * t
  x, x_mean = update_fn(step_rng, state, x, vec_t)
  masked_data_mean, std = sde.marginal_prob(decouple(gray_scale_img), vec_t)
  masked_data = masked_data_mean + batch_mul(jax.random.normal(rng, x.shape),
                                             std)
  keep = 1. - mask
  x = couple(decouple(x) * keep + masked_data * mask)
  # Intentionally uses the already-masked x for the unknown region.
  x_mean = couple(decouple(x) * keep + masked_data_mean * mask)
  return x, x_mean
def score_fn(x, t, rng=None):
  """Wrap the raw model output into a score: rescale by std and flip sign."""
  # For VP-trained models, t=0 corresponds to the lowest noise level.
  labels = t * (sde.N - 1)
  model, state = model_fn(x, labels, rng)
  if continuous or isinstance(sde, sde_lib.subVPSDE):
    # Continuous case: std from the marginal distribution at time t.
    std = sde.marginal_prob(jnp.zeros_like(x), t)[1]
  else:
    # Discrete case: std from the precomputed alphas-cumprod table.
    std = sde.sqrt_1m_alphas_cumprod[labels.astype(jnp.int32)]
  score = batch_mul(-model, 1. / std)
  return (score, state) if return_state else score
def marginal_prob(self, x, t):
  """Mean and std of the perturbation kernel p_t(x(t) | x(0)).

  NOTE(review): std is 1 - exp(2 * log_mean_coeff) with no square root,
  which matches a sub-VP-style kernel — confirm the enclosing class.
  """
  log_coeff = -0.5 * t * self.beta_0 - 0.25 * t**2 * (self.beta_1 - self.beta_0)
  mean = batch_mul(jnp.exp(log_coeff), x)
  std = 1 - jnp.exp(2. * log_coeff)
  return mean, std
def sde(self, x, t):
  """Drift and diffusion of the forward VP SDE with a linear beta schedule."""
  beta = self.beta_0 + t * (self.beta_1 - self.beta_0)
  return -0.5 * batch_mul(beta, x), jnp.sqrt(beta)
def update_fn(self, rng, x, t):
  """Reverse-diffusion predictor step using the discretized reverse SDE.

  Returns the noisy next sample `x` and the noise-free mean `x_mean`.
  """
  f, G = self.rsde.discretize(x, t)
  noise = random.normal(rng, x.shape)
  x_mean = x - f
  x = x_mean + batch_mul(G, noise)
  return x, x_mean