Exemplo n.º 1
0
def neg_log_likelihood_loss(nn_out, images):
  """Return the negative log-likelihood of `images` in bits per pixel-channel.

  Args:
    nn_out: Raw network output, turned into per-pixel conditional
      distribution parameters by `pixelcnn.conditional_params_from_outputs`.
    images: The images being scored. The product of their last three
      dimensions is used as the per-image dimension count (presumably
      height * width * channels -- TODO confirm against callers).

  Returns:
    A scalar: the negated mean of the log-likelihoods, divided by
    ln(2) * (number of pixel-channels). This expresses the loss in bits
    per pixel-channel (bits/dim), assuming the log-likelihoods are in nats.
  """
  means, inv_scales, logit_weights = pixelcnn.conditional_params_from_outputs(
      nn_out, images)
  log_likelihoods = pixelcnn.logprob_from_conditional_params(
      images, means, inv_scales, logit_weights)
  # Presumably one log-likelihood per image, averaged over the batch.
  mean_nats = -jnp.mean(log_likelihoods)
  dims_per_image = np.prod(images.shape[-3:])
  # Dividing by ln(2) converts natural-log units to bits.
  return mean_nats / (jnp.log(2) * dims_per_image)
Exemplo n.º 2
0
def sample_iteration(config, rng, params, sample):
    """Run one step of PixelCNN++ sampling, viewed as a fixed-point iteration.

    The model built from `config` is applied to the current `sample`. Its
    output is turned into conditional distribution parameters, and a fresh
    sample is drawn from them.

    Args:
        config: Passed to `train.model` to construct the network.
        rng: PRNG key. It is split into a dropout key for the forward pass
            and a key for drawing the next sample.
        params: Model parameters, supplied to `apply` as `{'params': params}`.
        sample: The current sample. It is both the network input and the
            conditioning images.

    Returns:
        Whatever `conditional_params_to_sample` returns; presumably the next
        iterate of the sample (confirm against that helper).
    """
    sample_rng, dropout_rng = random.split(rng)
    net = train.model(config)
    nn_out = net.apply({'params': params}, sample,
                       rngs={'dropout': dropout_rng})
    conditional_params = pixelcnn.conditional_params_from_outputs(
        nn_out, sample)
    return conditional_params_to_sample(sample_rng, conditional_params)
Exemplo n.º 3
0
def sample_iteration(rng, model, sample):
    """Run one step of PixelCNN++ sampling, viewed as a fixed-point iteration.

    Args:
        rng: PRNG key forwarded unchanged to `conditional_params_to_sample`.
        model: Callable that maps the current sample to network output.
        sample: The current sample. It is both the model input and the
            conditioning images for `pixelcnn.conditional_params_from_outputs`.

    Returns:
        Whatever `conditional_params_to_sample` returns; presumably the next
        iterate of the sample (confirm against that helper).
    """
    nn_out = model(sample)
    conditional_params = pixelcnn.conditional_params_from_outputs(
        nn_out, sample)
    return conditional_params_to_sample(rng, conditional_params)