Ejemplo n.º 1
0
  def ode_sampler(model, z=None):
    """Draw samples by integrating the probability flow ODE with a black-box solver.

    The ODE is solved with `integrate.solve_ivp` backwards in time, from
    `sde.T` down to `eps`.

    Args:
      model: A score model.
      z: Optional latent code. When given, it is used as the initial state
        instead of a draw from the SDE prior.
    Returns:
      A tuple `(samples, nfe)`: the samples after the optional denoising step
      and `inverse_scaler`, and the number of function evaluations reported
      by the ODE solver.
    """
    with torch.no_grad():
      # Start from the supplied latent, or draw one from the SDE prior.
      state = sde.prior_sampling(shape).to(device) if z is None else z

      def ode_func(t, x):
        # solve_ivp passes a flat NumPy vector; rebuild the batch tensor.
        batch = from_flattened_numpy(x, shape).to(device).type(torch.float32)
        time_vec = torch.ones(shape[0], device=batch.device) * t
        return to_flattened_numpy(drift_fn(model, batch, time_vec))

      # Integrate backwards from the terminal time sde.T to the cutoff eps.
      solution = integrate.solve_ivp(
          ode_func, (sde.T, eps), to_flattened_numpy(state),
          rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      final_flat = solution.y[:, -1]
      state = torch.tensor(final_flat).reshape(shape).to(device).type(torch.float32)

      # A single noise-free predictor step acts as the denoiser.
      if denoise:
        state = denoise_update_fn(model, state)

      return inverse_scaler(state), nfe
Ejemplo n.º 2
0
  def likelihood_fn(prng, pstate, data):
    """Compute an unbiased estimate of the log-likelihood in bits/dim.

    The probability flow ODE is integrated forwards from `eps` to `sde.T`.
    The state is augmented with one log-density accumulator per (device,
    example) pair, driven by a Hutchinson-style divergence estimate.

    Args:
      prng: An array of random states. The list dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      data: A JAX array of shape [#devices, batch size, ...].

    Returns:
      bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim.
      z: A JAX array of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.

    Raises:
      NotImplementedError: If `hutchinson_type` is neither 'Gaussian' nor 'Rademacher'.
    """
    _, step_rng = jax.random.split(flax.jax_utils.unreplicate(prng))
    shape = data.shape
    # Number of log-density accumulators appended after the flattened sample.
    n_logp = shape[0] * shape[1]
    if hutchinson_type == 'Gaussian':
      epsilon = jax.random.normal(step_rng, shape)
    elif hutchinson_type == 'Rademacher':
      epsilon = jax.random.randint(step_rng, shape,
                                   minval=0, maxval=2).astype(jnp.float32) * 2 - 1
    else:
      raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

    def ode_func(t, x):
      # x = [flattened sample, log-density accumulators]; only the sample
      # part is fed to the drift and divergence functions.
      sample = mutils.from_flattened_numpy(x[:-n_logp], shape)
      vec_t = jnp.ones((sample.shape[0], sample.shape[1])) * t
      drift = mutils.to_flattened_numpy(p_drift_fn(pstate, sample, vec_t))
      logp_grad = mutils.to_flattened_numpy(p_div_fn(pstate, sample, vec_t, epsilon))
      return np.concatenate([drift, logp_grad], axis=0)

    # Build the initial state on the host with NumPy (as the PyTorch variant
    # does): solve_ivp works on NumPy arrays and casts y0 to float64, so a jnp
    # array here would only add a device round trip without changing values.
    init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((n_logp,))], axis=0)
    solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
    nfe = solution.nfev
    zp = jnp.asarray(solution.y[:, -1])
    z = mutils.from_flattened_numpy(zp[:-n_logp], shape)
    delta_logp = zp[-n_logp:].reshape((shape[0], shape[1]))
    prior_logp = p_prior_logp_fn(z)
    # Convert nats to bits, then normalize by the per-example dimensionality.
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[2:])
    bpd = bpd / N
    # A hack to convert log-likelihoods to bits/dim
    # based on the gradient of the inverse data normalizer
    # (the +8 presumably corresponds to log2(256) for 8-bit data).
    offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8.
    bpd += offset
    return bpd, z, nfe
Ejemplo n.º 3
0
    def ode_sampler(prng, pstate, z=None):
        """Sample by integrating the probability flow ODE with a black-box solver.

        Args:
          prng: An array of random states. The leading dimension equals the
            number of devices.
          pstate: Replicated training state for running on multiple devices.
          z: Optional latent code. When given, it replaces the prior draw as
            the starting point of the integration.
        Returns:
          A tuple of the samples (after the optional denoising step and
          `inverse_scaler`) and the number of function evaluations reported
          by the ODE solver.
        """
        # Every state handled here carries a leading per-device axis.
        batch_shape = (jax.local_device_count(), ) + shape

        rng, step_rng = random.split(flax.jax_utils.unreplicate(prng))
        if z is not None:
            x = z
        else:
            # No latent supplied: draw the initial state from the SDE prior.
            x = sde.prior_sampling(step_rng, batch_shape)

        def ode_func(t, x):
            # solve_ivp passes a flat NumPy vector; restore the batch layout.
            state = from_flattened_numpy(x, batch_shape)
            time_vec = jnp.ones((state.shape[0], state.shape[1])) * t
            return to_flattened_numpy(drift_fn(pstate, state, time_vec))

        # Integrate backwards from the terminal time sde.T to the cutoff eps.
        solution = integrate.solve_ivp(ode_func, (sde.T, eps),
                                       to_flattened_numpy(x),
                                       rtol=rtol, atol=atol, method=method)
        nfe = solution.nfev
        x = jnp.asarray(solution.y[:, -1]).reshape(batch_shape)

        # A single noise-free predictor step serves as the denoiser; it needs
        # one PRNG key per local device.
        if denoise:
            rng, *device_rngs = random.split(rng, jax.local_device_count() + 1)
            x = denoise_update_fn(jnp.asarray(device_rngs), pstate, x)

        return inverse_scaler(x), nfe
Ejemplo n.º 4
0
  def likelihood_fn(model, data):
    """Compute an unbiased estimate of the log-likelihood in bits/dim.

    The probability flow ODE is integrated forwards from `eps` to `sde.T`.
    The state is augmented with one log-density accumulator per example,
    driven by a Hutchinson-style divergence estimate.

    Args:
      model: A score model.
      data: A PyTorch tensor.

    Returns:
      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.

    Raises:
      NotImplementedError: If `hutchinson_type` is neither 'Gaussian' nor 'Rademacher'.
    """
    with torch.no_grad():
      shape = data.shape
      if hutchinson_type == 'Gaussian':
        epsilon = torch.randn_like(data)
      elif hutchinson_type == 'Rademacher':
        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
      else:
        raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

      def ode_func(t, x):
        # x = [flattened sample, shape[0] log-density accumulators]; only the
        # sample part drives the dynamics.
        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
        vec_t = torch.ones(sample.shape[0], device=sample.device) * t
        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
        return np.concatenate([drift, logp_grad], axis=0)

      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      zp = solution.y[:, -1]
      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
      prior_logp = sde.prior_logp(z)
      # Convert nats to bits, then normalize by the per-example dimensionality.
      bpd = -(prior_logp + delta_logp) / np.log(2)
      N = np.prod(shape[1:])
      bpd = bpd / N
      # Convert to bits/dim in data space: shift by 8 + log2(slope of the
      # inverse data scaler), the same formula the JAX variant derives from
      # the scaler's gradient. The slope is measured between 0 and 1, which
      # is exact for affine scalers: it gives 7.0 for (x + 1) / 2 and 8.0 for
      # the identity, matching the former `7. - inverse_scaler(-1.)`, which
      # was only correct for those two scalers.
      scaler_slope = inverse_scaler(1.) - inverse_scaler(0.)
      offset = 8. + float(np.log2(scaler_slope))
      bpd = bpd + offset
      return bpd, z, nfe
Ejemplo n.º 5
0
 def ode_func(t, x):
   """Right-hand side of the joint (sample, log-density) ODE for solve_ivp.

   `x` is a flat NumPy vector: the flattened sample followed by
   `shape[0] * shape[1]` log-density accumulators. Returns the flattened
   drift followed by the flattened divergence estimate, in the same layout.
   """
   n_logp = shape[0] * shape[1]
   sample = mutils.from_flattened_numpy(x[:-n_logp], shape)
   time_vec = jnp.ones((sample.shape[0], sample.shape[1])) * t
   flat_drift = mutils.to_flattened_numpy(p_drift_fn(pstate, sample, time_vec))
   flat_div = mutils.to_flattened_numpy(p_div_fn(pstate, sample, time_vec, epsilon))
   return np.concatenate([flat_drift, flat_div], axis=0)
Ejemplo n.º 6
0
 def ode_func(t, x):
     """Probability flow ODE drift for solve_ivp, on a flat NumPy state vector."""
     batch = from_flattened_numpy(x, (jax.local_device_count(), ) + shape)
     n_dev, per_dev = batch.shape[0], batch.shape[1]
     time_vec = jnp.ones((n_dev, per_dev)) * t
     return to_flattened_numpy(drift_fn(pstate, batch, time_vec))
Ejemplo n.º 7
0
 def ode_func(t, x):
   """Probability flow ODE drift for solve_ivp, on a flat NumPy state vector."""
   batch = from_flattened_numpy(x, shape)
   batch = batch.to(device).type(torch.float32)
   time_vec = torch.ones(shape[0], device=batch.device) * t
   return to_flattened_numpy(drift_fn(model, batch, time_vec))
Ejemplo n.º 8
0
 def ode_func(t, x):
   """Right-hand side of the joint (sample, log-density) ODE for solve_ivp.

   `x` is a flat NumPy vector: the flattened sample followed by `shape[0]`
   log-density accumulators. Returns the flattened drift followed by the
   flattened divergence estimate.
   """
   n_examples = shape[0]
   sample = (mutils.from_flattened_numpy(x[:-n_examples], shape)
             .to(data.device)
             .type(torch.float32))
   time_vec = torch.ones(sample.shape[0], device=sample.device) * t
   flat_drift = mutils.to_flattened_numpy(drift_fn(model, sample, time_vec))
   flat_div = mutils.to_flattened_numpy(div_fn(model, sample, time_vec, epsilon))
   return np.concatenate([flat_drift, flat_div], axis=0)