def denoise_update_fn(model, x):
    """Remove residual noise from `x` with one reverse-diffusion step at time `eps`.

    Args:
      model: A score model.
      x: A mini-batch of samples carrying a small amount of residual noise.

    Returns:
      The denoised mini-batch.
    """
    score_fn = get_score_fn(sde, model, train=False, continuous=True)
    # A single step of the reverse-diffusion predictor performs the denoising.
    denoiser = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    timesteps = torch.ones(x.shape[0], device=x.device) * eps
    _, x = denoiser.update_fn(x, timesteps)
    return x
def loss_fn(model, batch):
    """Compute the loss function.

    Args:
      model: A score model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
    # Diffusion times are drawn uniformly from [eps, sde.T]; eps keeps t away from 0.
    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    noise = torch.randn_like(batch)
    mean, std = sde.marginal_prob(batch, t)
    # Broadcast the per-sample std over the remaining (image) dimensions once.
    sigmas = std[:, None, None, None]
    perturbed_data = mean + sigmas * noise
    score = score_fn(perturbed_data, t)
    if likelihood_weighting:
        # Weight each sample's loss by g(t)^2, the squared diffusion coefficient.
        g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
        sq_err = torch.square(score + noise / sigmas)
        losses = reduce_op(sq_err.reshape(sq_err.shape[0], -1), dim=-1) * g2
    else:
        sq_err = torch.square(score * sigmas + noise)
        losses = reduce_op(sq_err.reshape(sq_err.shape[0], -1), dim=-1)
    return torch.mean(losses)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
    """A wrapper that configures and returns the update function of correctors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    # `corrector is None` means a predictor-only sampler; NoneCorrector is a no-op.
    corrector_cls = NoneCorrector if corrector is None else corrector
    corrector_obj = corrector_cls(sde, score_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t)
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
    """A wrapper that configures and returns the update function of predictors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    # `predictor is None` means a corrector-only sampler; NonePredictor is a no-op.
    predictor_cls = NonePredictor if predictor is None else predictor
    predictor_obj = predictor_cls(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t)
def drift_fn(state, x, t):
    """Get the drift function of the reverse-time SDE."""
    score_fn = get_score_fn(sde, model, state.params_ema, state.model_state,
                            train=False, continuous=True)
    # The probability-flow ODE is obtained from the reverse SDE; only its drift is used.
    reverse_sde = sde.reverse(score_fn, probability_flow=True)
    drift, _ = reverse_sde.sde(x, t)
    return drift
def denoise_update_fn(rng, state, x):
    """Remove residual noise from `x` with one reverse-diffusion step at time `eps`.

    Args:
      rng: A JAX random state.
      state: A training state holding EMA parameters and mutable model state.
      x: A mini-batch of samples carrying a small amount of residual noise.

    Returns:
      The denoised mini-batch.
    """
    score_fn = get_score_fn(sde, model, state.params_ema, state.model_state,
                            train=False, continuous=True)
    # A single step of the reverse-diffusion predictor performs the denoising.
    denoiser = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    timesteps = jnp.ones((x.shape[0],)) * eps
    _, x = denoiser.update_fn(rng, x, timesteps)
    return x
def loss_fn(rng, params, states, batch):
    """Compute the loss function.

    Args:
      rng: A JAX random state.
      params: A dictionary that contains trainable parameters of the score-based model.
      states: A dictionary that contains mutable states of the score-based model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
      new_model_state: A dictionary that contains the mutated states of the
        score-based model.
    """
    score_fn = mutils.get_score_fn(sde, model, params, states, train=train,
                                   continuous=continuous, return_state=True)
    data = batch['image']
    # Diffusion times are drawn uniformly from [eps, sde.T]; eps keeps t away from 0.
    rng, t_rng = random.split(rng)
    t = random.uniform(t_rng, (data.shape[0],), minval=eps, maxval=sde.T)
    rng, noise_rng = random.split(rng)
    noise = random.normal(noise_rng, data.shape)
    mean, std = sde.marginal_prob(data, t)
    perturbed_data = mean + batch_mul(std, noise)
    rng, model_rng = random.split(rng)
    score, new_model_state = score_fn(perturbed_data, t, rng=model_rng)
    if likelihood_weighting:
        # Weight each sample's loss by g(t)^2, the squared diffusion coefficient.
        g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2
        sq_err = jnp.square(score + batch_mul(noise, 1. / std))
        losses = reduce_op(sq_err.reshape((sq_err.shape[0], -1)), axis=-1) * g2
    else:
        sq_err = jnp.square(batch_mul(score, std) + noise)
        losses = reduce_op(sq_err.reshape((sq_err.shape[0], -1)), axis=-1)
    return jnp.mean(losses), new_model_state
def conditional_corrector_update_fn(rng, state, x, t, labels):
    """The corrector update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, state.params_ema,
                                   state.model_state, train=False,
                                   continuous=continuous)

    def total_grad_fn(x, t):
        # Guide sampling by adding the classifier gradient (evaluated at the
        # current noise scale) to the unconditional score.
        ve_noise_scale = sde.marginal_prob(x, t)[1]
        return score_fn(x, t) + classifier_grad_fn(x, ve_noise_scale, labels)

    # `corrector is None` means no corrector step; NoneCorrector is a no-op.
    corrector_cls = NoneCorrector if corrector is None else corrector
    corrector_obj = corrector_cls(sde, total_grad_fn, snr, n_steps)
    return corrector_obj.update_fn(rng, x, t)
def conditional_predictor_update_fn(rng, state, x, t, labels):
    """The predictor update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, state.params_ema,
                                   state.model_state, train=False,
                                   continuous=continuous)

    def total_grad_fn(x, t):
        # Guide sampling by adding the classifier gradient (evaluated at the
        # current noise scale) to the unconditional score.
        ve_noise_scale = sde.marginal_prob(x, t)[1]
        return score_fn(x, t) + classifier_grad_fn(x, ve_noise_scale, labels)

    # `predictor is None` means no predictor step; NonePredictor is a no-op.
    predictor_cls = NonePredictor if predictor is None else predictor
    predictor_obj = predictor_cls(sde, total_grad_fn, probability_flow)
    return predictor_obj.update_fn(rng, x, t)
def drift_fn(state, x, t):
    """The drift function of the reverse-time SDE."""
    score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state,
                                   train=False, continuous=True)
    # The probability-flow ODE is a special case of the reverse SDE; keep only
    # the drift term it shares with the reverse SDE.
    reverse_sde = sde.reverse(score_fn, probability_flow=True)
    drift, _ = reverse_sde.sde(x, t)
    return drift