def denoise_update_fn(model, x):
   """Run a single reverse-diffusion denoising step on `x` at the final time `eps`.

   Args:
     model: A score model.
     x: A batch of (nearly denoised) samples.

   Returns:
     The batch after one denoising update.
   """
   # Score function in evaluation mode; `sde` and `eps` come from the enclosing scope.
   score = get_score_fn(sde, model, train=False, continuous=True)
   denoiser = ReverseDiffusionPredictor(sde, score, probability_flow=False)
   # One time value per batch element, all set to the terminal time eps.
   t = torch.ones(x.shape[0], device=x.device) * eps
   _, denoised = denoiser.update_fn(x, t)
   return denoised
# --- Beispiel #2 (separator from the scraped example listing; stray "0" artifact removed) ---
    def loss_fn(model, batch):
        """Compute the loss function.

    Args:
      model: A score model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
        score_fn = mutils.get_score_fn(sde, model, train=train,
                                       continuous=continuous)
        # Sample one diffusion time per example, uniformly in [eps, sde.T).
        t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        noise = torch.randn_like(batch)
        mean, std = sde.marginal_prob(batch, t)
        # Broadcast the per-example std over the (C, H, W) dimensions.
        sigma = std[:, None, None, None]
        perturbed = mean + sigma * noise
        score = score_fn(perturbed, t)

        if likelihood_weighting:
            # Likelihood weighting: scale each example's loss by g(t)^2,
            # the squared diffusion coefficient of the SDE.
            g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
            residual = torch.square(score + noise / sigma)
            per_example = reduce_op(residual.reshape(residual.shape[0], -1),
                                    dim=-1) * g2
        else:
            residual = torch.square(score * sigma + noise)
            per_example = reduce_op(residual.reshape(residual.shape[0], -1),
                                    dim=-1)

        return torch.mean(per_example)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
  """A wrapper that configures and returns the update function of correctors.

  Args:
    x: The current batch of samples.
    t: The current time step(s).
    sde: The forward SDE.
    model: A score model.
    corrector: A corrector class, or None for predictor-only sampling.
    continuous: Whether the model was trained with continuous time.
    snr: Signal-to-noise ratio used by the corrector.
    n_steps: Number of corrector steps per update.

  Returns:
    The result of the corrector's `update_fn` applied to (x, t).
  """
  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
  # Fall back to a no-op corrector when none is requested (predictor-only sampler).
  corrector_cls = NoneCorrector if corrector is None else corrector
  return corrector_cls(sde, score_fn, snr, n_steps).update_fn(x, t)
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
  """A wrapper that configures and returns the update function of predictors.

  Args:
    x: The current batch of samples.
    t: The current time step(s).
    sde: The forward SDE.
    model: A score model.
    predictor: A predictor class, or None for corrector-only sampling.
    probability_flow: Whether to use the probability-flow ODE formulation.
    continuous: Whether the model was trained with continuous time.

  Returns:
    The result of the predictor's `update_fn` applied to (x, t).
  """
  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
  # Fall back to a no-op predictor when none is requested (corrector-only sampler).
  predictor_cls = NonePredictor if predictor is None else predictor
  return predictor_cls(sde, score_fn, probability_flow).update_fn(x, t)
# --- Beispiel #5 (separator from the scraped example listing; stray "0" artifact removed) ---
 def drift_fn(state, x, t):
     """Evaluate the drift of the reverse-time SDE at (x, t).

     Uses the probability-flow setting, so the returned drift is that of the
     deterministic ODE sharing the SDE's marginals.
     """
     score = get_score_fn(sde, model, state.params_ema, state.model_state,
                          train=False, continuous=True)
     reverse_sde = sde.reverse(score, probability_flow=True)
     # rsde.sde returns (drift, diffusion); only the drift is needed here.
     drift, _ = reverse_sde.sde(x, t)
     return drift
# --- Beispiel #6 (separator from the scraped example listing; stray "0" artifact removed) ---
 def denoise_update_fn(rng, state, x):
     """Apply one reverse-diffusion predictor step at time `eps` to denoise `x`.

     Args:
       rng: A JAX random key.
       state: Training state holding `params_ema` and `model_state`.
       x: A batch of samples to denoise.

     Returns:
       The denoised batch.
     """
     score = get_score_fn(sde, model, state.params_ema, state.model_state,
                          train=False, continuous=True)
     denoiser = ReverseDiffusionPredictor(sde, score, probability_flow=False)
     # One time value per batch element, all equal to the terminal time eps.
     t = jnp.ones((x.shape[0], )) * eps
     _, denoised = denoiser.update_fn(rng, x, t)
     return denoised
# --- Beispiel #7 (separator from the scraped example listing; stray "0" artifact removed) ---
    def loss_fn(rng, params, states, batch):
        """Compute the loss function.

    Args:
      rng: A JAX random state.
      params: A dictionary that contains trainable parameters of the score-based model.
      states: A dictionary that contains mutable states of the score-based model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
      new_model_state: A dictionary that contains the mutated states of the score-based model.
    """
        score_fn = mutils.get_score_fn(sde, model, params, states,
                                       train=train, continuous=continuous,
                                       return_state=True)
        data = batch['image']

        # Per-example diffusion times in [eps, sde.T).
        rng, t_rng = random.split(rng)
        t = random.uniform(t_rng, (data.shape[0], ), minval=eps, maxval=sde.T)
        # Standard Gaussian noise of the data's shape.
        rng, z_rng = random.split(rng)
        z = random.normal(z_rng, data.shape)

        mean, std = sde.marginal_prob(data, t)
        perturbed = mean + batch_mul(std, z)
        rng, score_rng = random.split(rng)
        score, new_model_state = score_fn(perturbed, t, rng=score_rng)

        if likelihood_weighting:
            # Likelihood weighting: scale each example's loss by g(t)^2,
            # the squared diffusion coefficient of the SDE.
            g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2
            residual = jnp.square(score + batch_mul(z, 1. / std))
            per_example = reduce_op(residual.reshape((residual.shape[0], -1)),
                                    axis=-1) * g2
        else:
            residual = jnp.square(batch_mul(score, std) + z)
            per_example = reduce_op(residual.reshape((residual.shape[0], -1)),
                                    axis=-1)

        return jnp.mean(per_example), new_model_state
    def conditional_corrector_update_fn(rng, state, x, t, labels):
        """The corrector update function for class-conditional sampling."""
        score_fn = mutils.get_score_fn(sde, score_model, state.params_ema,
                                       state.model_state, train=False,
                                       continuous=continuous)

        def total_grad_fn(x, t):
            # Classifier guidance: add the classifier's gradient (evaluated at
            # the current noise scale) to the unconditional score.
            ve_noise_scale = sde.marginal_prob(x, t)[1]
            return score_fn(x, t) + classifier_grad_fn(x, ve_noise_scale,
                                                       labels)

        # A missing corrector degenerates to the no-op corrector.
        corrector_cls = NoneCorrector if corrector is None else corrector
        corrector_obj = corrector_cls(sde, total_grad_fn, snr, n_steps)
        return corrector_obj.update_fn(rng, x, t)
    def conditional_predictor_update_fn(rng, state, x, t, labels):
        """The predictor update function for class-conditional sampling."""
        score_fn = mutils.get_score_fn(sde, score_model, state.params_ema,
                                       state.model_state, train=False,
                                       continuous=continuous)

        def total_grad_fn(x, t):
            # Classifier guidance: add the classifier's gradient (evaluated at
            # the current noise scale) to the unconditional score.
            ve_noise_scale = sde.marginal_prob(x, t)[1]
            return score_fn(x, t) + classifier_grad_fn(x, ve_noise_scale,
                                                       labels)

        # A missing predictor degenerates to the no-op predictor.
        predictor_cls = NonePredictor if predictor is None else predictor
        predictor_obj = predictor_cls(sde, total_grad_fn, probability_flow)
        return predictor_obj.update_fn(rng, x, t)
# --- Beispiel #10 (separator from the scraped example listing; stray "0" artifact removed) ---
 def drift_fn(state, x, t):
   """The drift function of the reverse-time SDE (probability-flow variant)."""
   score = mutils.get_score_fn(sde, model, state.params_ema, state.model_state,
                               train=False, continuous=True)
   # The probability-flow ODE is a special case of the reverse SDE.
   reverse_sde = sde.reverse(score, probability_flow=True)
   drift, _ = reverse_sde.sde(x, t)
   return drift