Example #1
 def loop_body(step, val):
     rng, x, x_mean = val
     grad = score_fn(x, t)
     rng, step_rng = jax.random.split(rng)
     noise = jax.random.normal(step_rng, x.shape)
     step_size = (target_snr * std)**2 * 2 * alpha
     x_mean = x + batch_mul(step_size, grad)
     x = x_mean + batch_mul(noise, jnp.sqrt(step_size * 2))
     return rng, x, x_mean
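All of these snippets multiply a per-sample scalar into a batch of tensors via `batch_mul`, and `loop_body` has the `(step, val) -> val` signature expected by `jax.lax.fori_loop`. A minimal sketch of both, assuming the helper matches its usage here (`n_steps` is an illustrative name, not from the source):

 import jax
 import jax.numpy as jnp

 def batch_mul(a, b):
     # Multiply one scalar per batch entry of `a` into the matching
     # slice of `b` by mapping over the leading (batch) axis.
     return jax.vmap(lambda a, b: a * b)(a, b)

 # Run n_steps corrector updates as a single compiled loop:
 rng, x, x_mean = jax.lax.fori_loop(0, n_steps, loop_body, (rng, x, x_mean))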
Example #2
 def vpsde_update_fn(self, rng, x, t):
     sde = self.sde
     timestep = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
     beta = sde.discrete_betas[timestep]
     score = self.score_fn(x, t)
     x_mean = batch_mul((x + batch_mul(beta, score)),
                        1. / jnp.sqrt(1. - beta))
     noise = random.normal(rng, x.shape)
     x = x_mean + batch_mul(jnp.sqrt(beta), noise)
     return x, x_mean
Example #3
 def vesde_update_fn(self, rng, x, t):
     sde = self.sde
     timestep = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
     sigma = sde.discrete_sigmas[timestep]
     # The first step has no previous noise level; treat it as zero.
     adjacent_sigma = jnp.where(timestep == 0, jnp.zeros(t.shape),
                                sde.discrete_sigmas[timestep - 1])
     score = self.score_fn(x, t)
     x_mean = x + batch_mul(score, sigma**2 - adjacent_sigma**2)
     std = jnp.sqrt(
         (adjacent_sigma**2 * (sigma**2 - adjacent_sigma**2)) / (sigma**2))
     noise = random.normal(rng, x.shape)
     x = x_mean + batch_mul(std, noise)
     return x, x_mean
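Examples #2 and #3 read like the two branches of one ancestral-sampling predictor. A plausible dispatch on the SDE class, sketched under the assumption that `sde_lib` exposes `VESDE` and `VPSDE` classes (only `sde_lib.subVPSDE` appears explicitly in these snippets):

 def update_fn(self, rng, x, t):
     # Route to the discretization that matches the training SDE.
     if isinstance(self.sde, sde_lib.VESDE):
         return self.vesde_update_fn(rng, x, t)
     elif isinstance(self.sde, sde_lib.VPSDE):
         return self.vpsde_update_fn(rng, x, t)
     else:
         raise NotImplementedError(f'{type(self.sde)} is not supported.')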
Example #4
 def loop_body(step, val):
     rng, x, x_mean = val
     grad = score_fn(x, t)
     rng, step_rng = jax.random.split(rng)
     noise = jax.random.normal(step_rng, x.shape)
     grad_norm = jnp.linalg.norm(grad.reshape((grad.shape[0], -1)),
                                 axis=-1).mean()
     grad_norm = jax.lax.pmean(grad_norm, axis_name='batch')
     noise_norm = jnp.linalg.norm(noise.reshape((noise.shape[0], -1)),
                                  axis=-1).mean()
     noise_norm = jax.lax.pmean(noise_norm, axis_name='batch')
     step_size = (target_snr * noise_norm / grad_norm)**2 * 2 * alpha
     x_mean = x + batch_mul(step_size, grad)
     x = x_mean + batch_mul(noise, jnp.sqrt(step_size * 2))
     return rng, x, x_mean
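`jax.lax.pmean(..., axis_name='batch')` only resolves inside a `jax.pmap` that binds that axis name, so this corrector loop must run under a pmapped function. A one-line sketch, where `sampler` is a hypothetical wrapper containing the loop:

 # pmean averages grad_norm / noise_norm across devices, so the
 # enclosing computation binds the matching axis name:
 p_sampler = jax.pmap(sampler, axis_name='batch')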
Example #5
 def sde(self, x, t):
     beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
     drift = -0.5 * batch_mul(beta_t, x)
     discount = 1. - jnp.exp(-2 * self.beta_0 * t -
                             (self.beta_1 - self.beta_0) * t**2)
     diffusion = jnp.sqrt(beta_t * discount)
     return drift, diffusion
Example #6
 def update_fn(self, rng, x, t):
     dt = -1. / self.rsde.N  # negative step: integrate backwards from T to 0
     z = random.normal(rng, x.shape)
     drift, diffusion = self.rsde.sde(x, t)
     x_mean = x + drift * dt
     x = x_mean + batch_mul(diffusion, jnp.sqrt(-dt) * z)
     return x, x_mean
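A sketch of driving this Euler-Maruyama predictor from t = T down to a small cutoff, assuming an `sde` object with `prior_sampling` and `from jax import random` as in the other snippets (the loop scaffolding and `eps` are illustrative):

 def euler_maruyama_sampler(rng, predictor, shape, eps=1e-3):
     # Start from the SDE's prior at t = T and integrate the reverse SDE.
     rng, step_rng = random.split(rng)
     x = sde.prior_sampling(step_rng, shape)
     timesteps = jnp.linspace(sde.T, eps, sde.N)
     for i in range(sde.N):
         rng, step_rng = random.split(rng)
         vec_t = jnp.ones(shape[0]) * timesteps[i]
         x, x_mean = predictor.update_fn(step_rng, x, vec_t)
     return x_mean  # return the noise-free mean of the final step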
Example #7
 def loss_fn(rng, params, states, batch):
     model_fn = mutils.get_model_fn(model, params, states, train=train)
     data = batch['image']
     rng, step_rng = random.split(rng)
     labels = random.choice(step_rng, vesde.N, shape=(data.shape[0], ))
     sigmas = smld_sigma_array[labels]
     rng, step_rng = random.split(rng)
     noise = batch_mul(random.normal(step_rng, data.shape), sigmas)
     perturbed_data = noise + data
     rng, step_rng = random.split(rng)
     score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
     target = -batch_mul(noise, 1. / (sigmas**2))
     losses = jnp.square(score - target)
     losses = reduce_op(losses.reshape(
         (losses.shape[0], -1)), axis=-1) * sigmas**2
     loss = jnp.mean(losses)
     return loss, new_model_state
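The regression target `-noise / sigma**2` in Example #7 is the score of the Gaussian perturbation kernel N(x_t; x_0, sigma^2 I). A self-contained check with `jax.grad` (all concrete values below are made up for illustration):

 import jax
 import jax.numpy as jnp
 from jax.scipy.stats import norm

 def log_kernel(x_t, x_0, sigma):
     # log N(x_t; x_0, sigma**2 I), summed over all coordinates.
     return jnp.sum(norm.logpdf(x_t, loc=x_0, scale=sigma))

 x_0 = jnp.ones((4,))
 noise = 0.3 * jnp.ones((4,))
 sigma = 2.0
 score = jax.grad(log_kernel)(x_0 + noise, x_0, sigma)
 # score == -(x_t - x_0) / sigma**2 == -noise / sigma**2 == -0.075 everywhere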
Example #8
 def discretize(self, x, t):
     """Create discretized iteration rules for the reverse diffusion sampler."""
     f, G = discretize_fn(x, t)
     rev_f = f - batch_mul(
         G**2,
         score_fn(x, t) * (0.5 if self.probability_flow else 1.))
     rev_G = jnp.zeros_like(G) if self.probability_flow else G
     return rev_f, rev_G
Example #9
 def loss_fn(rng, params, states, batch):
     model_fn = mutils.get_model_fn(model, params, states, train=train)
     data = batch['image']
     rng, step_rng = random.split(rng)
     labels = random.choice(step_rng, vpsde.N, shape=(data.shape[0], ))
     sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod
     sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod
     rng, step_rng = random.split(rng)
     noise = random.normal(step_rng, data.shape)
     perturbed_data = batch_mul(sqrt_alphas_cumprod[labels], data) + \
                      batch_mul(sqrt_1m_alphas_cumprod[labels], noise)
     rng, step_rng = random.split(rng)
     score, new_model_state = model_fn(perturbed_data, labels, rng=step_rng)
     losses = jnp.square(score - noise)
     losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
     loss = jnp.mean(losses)
     return loss, new_model_state
Example #10
 def loss_fn(rng, params, states, batch):
     """Compute the loss function.

     Args:
       rng: A JAX random state.
       params: A dictionary that contains trainable parameters of the score-based model.
       states: A dictionary that contains mutable states of the score-based model.
       batch: A mini-batch of training data.

     Returns:
       loss: A scalar that represents the average loss value across the mini-batch.
       new_model_state: A dictionary that contains the mutated states of the score-based model.
     """
     score_fn = mutils.get_score_fn(sde,
                                    model,
                                    params,
                                    states,
                                    train=train,
                                    continuous=continuous,
                                    return_state=True)
     data = batch['image']

     rng, step_rng = random.split(rng)
     t = random.uniform(step_rng, (data.shape[0],), minval=eps, maxval=sde.T)
     rng, step_rng = random.split(rng)
     z = random.normal(step_rng, data.shape)
     mean, std = sde.marginal_prob(data, t)
     perturbed_data = mean + batch_mul(std, z)
     rng, step_rng = random.split(rng)
     score, new_model_state = score_fn(perturbed_data, t, rng=step_rng)

     if not likelihood_weighting:
         losses = jnp.square(batch_mul(score, std) + z)
         losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1)
     else:
         g2 = sde.sde(jnp.zeros_like(data), t)[1]**2
         losses = jnp.square(score + batch_mul(z, 1. / std))
         losses = reduce_op(losses.reshape(
             (losses.shape[0], -1)), axis=-1) * g2

     loss = jnp.mean(losses)
     return loss, new_model_state
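A sketch of wiring this loss into a jitted update step with `jax.value_and_grad`; the optimizer plumbing (optax, the learning rate, the `train_step` wrapper) is an assumption, not from the source:

 import jax
 import optax

 tx = optax.adam(2e-4)  # hypothetical optimizer choice
 grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)

 @jax.jit
 def train_step(rng, params, states, opt_state, batch):
     # loss_fn returns (loss, new_model_state), hence has_aux=True.
     (loss, new_states), grads = grad_fn(rng, params, states, batch)
     updates, opt_state = tx.update(grads, opt_state, params)
     new_params = optax.apply_updates(params, updates)
     return loss, new_params, new_states, opt_state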
Example #11
 def discretize(self, x, t):
     """DDPM discretization."""
     timestep = (t * (self.N - 1) / self.T).astype(jnp.int32)
     beta = self.discrete_betas[timestep]
     alpha = self.alphas[timestep]
     sqrt_beta = jnp.sqrt(beta)
     f = batch_mul(jnp.sqrt(alpha), x) - x
     G = sqrt_beta
     return f, G
Example #12
 def sde(self, x, t):
     """Create the drift and diffusion functions for the reverse SDE/ODE."""
     drift, diffusion = sde_fn(x, t)
     score = score_fn(x, t)
     drift = drift - batch_mul(
         diffusion**2, score *
         (0.5 if self.probability_flow else 1.))
     # Set the diffusion function to zero for ODEs.
     diffusion = jnp.zeros_like(diffusion) if self.probability_flow else diffusion
     return drift, diffusion
Example #13
 def inpaint_update_fn(rng, state, data, mask, x, t):
     rng, step_rng = jax.random.split(rng)
     vec_t = jnp.ones(data.shape[0]) * t
     x, x_mean = update_fn(step_rng, state, x, vec_t)
     masked_data_mean, std = sde.marginal_prob(data, vec_t)
     masked_data = masked_data_mean + batch_mul(
         jax.random.normal(rng, x.shape), std)
     x = x * (1. - mask) + masked_data * mask
     x_mean = x * (1. - mask) + masked_data_mean * mask
     return x, x_mean
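At every reverse step this update denoises the unknown region while re-noising the known pixels (mask == 1) to the current time t, so the two regions stay statistically consistent. A hypothetical driver loop (the timestep schedule and `eps` are assumptions):

 timesteps = jnp.linspace(sde.T, eps, sde.N)
 for i in range(sde.N):
     rng, step_rng = jax.random.split(rng)
     x, x_mean = inpaint_update_fn(step_rng, state, data, mask, x,
                                   timesteps[i])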
Example #14
 def colorization_update_fn(rng, state, gray_scale_img, x, t):
     mask = get_mask(x)
     rng, step_rng = jax.random.split(rng)
     vec_t = jnp.ones(x.shape[0]) * t
     x, x_mean = update_fn(step_rng, state, x, vec_t)
     masked_data_mean, std = sde.marginal_prob(decouple(gray_scale_img),
                                               vec_t)
     masked_data = masked_data_mean + batch_mul(
         jax.random.normal(rng, x.shape), std)
     x = couple(decouple(x) * (1. - mask) + masked_data * mask)
     x_mean = couple(
         decouple(x) * (1. - mask) + masked_data_mean * mask)
     return x, x_mean
Example #15
 def score_fn(x, t, rng=None):
     # For VP-trained models, t=0 corresponds to the lowest noise level.
     labels = t * (sde.N - 1)
     model, state = model_fn(x, labels, rng)
     if continuous or isinstance(sde, sde_lib.subVPSDE):
         std = sde.marginal_prob(jnp.zeros_like(x), t)[1]
     else:
         std = sde.sqrt_1m_alphas_cumprod[labels.astype(jnp.int32)]

     # Scale the network output by the standard deviation and flip its sign.
     score = batch_mul(-model, 1. / std)
     if return_state:
         return score, state
     else:
         return score
Example #16
 def marginal_prob(self, x, t):
     log_mean_coeff = -0.25 * t**2 * (self.beta_1 -
                                      self.beta_0) - 0.5 * t * self.beta_0
     mean = batch_mul(jnp.exp(log_mean_coeff), x)
     std = jnp.sqrt(1 - jnp.exp(2. * log_mean_coeff))
     return mean, std
Example #17
 def sde(self, x, t):
     beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
     drift = -0.5 * batch_mul(beta_t, x)
     diffusion = jnp.sqrt(beta_t)
     return drift, diffusion
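The closed-form `log_mean_coeff` in Example #16 is `-0.5` times the integral of this linear `beta_t` over `[0, t]`. A quick numeric check (the concrete `beta_0`, `beta_1`, and `t` values are illustrative only):

 import jax.numpy as jnp

 beta_0, beta_1, t = 0.1, 20.0, 0.7  # illustrative values
 closed_form = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0

 # Midpoint rule is exact for a linear beta(s), so this matches the
 # closed form up to floating-point error:
 n = 1000
 ds = t / n
 s = (jnp.arange(n) + 0.5) * ds
 numeric = -0.5 * jnp.sum(beta_0 + s * (beta_1 - beta_0)) * ds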
Example #18
 def update_fn(self, rng, x, t):
     f, G = self.rsde.discretize(x, t)
     z = random.normal(rng, x.shape)
     x_mean = x - f
     x = x_mean + batch_mul(G, z)
     return x, x_mean
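The predictors (Examples #2, #3, #6, #18) and correctors (Examples #1, #4) compose into predictor-corrector sampling. A sketch of one time step; the `pc_step` wrapper and the `predictor`/`corrector` objects are assumptions, with `update_fn` signatures taken from the snippets above:

 def pc_step(rng, x, vec_t):
     # One Langevin corrector pass followed by one predictor step.
     rng, step_rng = random.split(rng)
     x, x_mean = corrector.update_fn(step_rng, x, vec_t)
     rng, step_rng = random.split(rng)
     x, x_mean = predictor.update_fn(step_rng, x, vec_t)
     return rng, x, x_mean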