  def sample_p(self,
               targets_length,
               temp,
               check_invertibility=False,
               targets_mask=None,
               **kwargs):
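    """Sample a latent z_p from the prior, with `temp` as the noise stddev.

    Returns:
      z_p: prior sample of shape [B, L, latent_size].
      p_dist: the prior distribution if it is parametric, else None.
    """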
    hparams = self._hparams
    if targets_mask is None:
      targets_mask = ops.sequence_mask(targets_length, hparams)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))
    batch_size, targets_max_length = (
        common_layers.shape_list(targets_mask)[:2])
    prior_shape = [batch_size, targets_max_length, hparams.latent_size]
    noise = tf.random.normal(prior_shape, stddev=temp)
    p_dist = None

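    # Three prior families: a fixed standard normal, a conditional diagonal
    # normal, and normalizing-flow priors ("affine", "additive", "rq").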
    if hparams.prior_type == "standard_normal":
      z_p = noise
    elif hparams.prior_type == "diagonal_normal":
      diag_prior_params = ops.cond_prior("diag_prior", hparams,
                                         tf.zeros(prior_shape), targets_mask,
                                         hparams.latent_size * 2,
                                         decoder_self_attention_bias, **kwargs)
      p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
      z_p = p_dist.loc + p_dist.scale * noise
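    # Flow priors sample base noise at a coarser resolution: the length is
    # reduced by factor**(n_levels - 1) across the multi-scale levels.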
    elif hparams.prior_type in ["affine", "additive", "rq"]:
      n_levels = len(hparams.depths.split("/"))
      divi = max(1, hparams.factor**(n_levels - 1))
      flow_prior_shape = [
          batch_size, targets_max_length // divi, hparams.latent_size
      ]
      noise = tf.random.normal(flow_prior_shape, stddev=temp)
      z_p, _, _, _ = glow.glow("glow",
                               noise,
                               targets_mask,
                               decoder_self_attention_bias,
                               inverse=True,
                               init=False,
                               hparams=self._fparams,
                               disable_dropout=True,
                               temp=temp,
                               **kwargs)
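      # Round-trip check: pushing z_p forward through the flow should
      # recover the base noise; log the worst-case reconstruction error.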
      if self.is_evaluating and check_invertibility:
        noise_inv, _, _, _ = glow.glow("glow",
                                       z_p,
                                       targets_mask,
                                       decoder_self_attention_bias,
                                       inverse=False,
                                       init=False,
                                       hparams=self._fparams,
                                       disable_dropout=True,
                                       **kwargs)
        z_diff = noise - noise_inv
        tf.summary.scalar("flow_recon_inverse",
                          tf.reduce_max(tf.abs(z_diff)))
    else:
      raise ValueError("Unknown prior_type: %s" % hparams.prior_type)
    return z_p, p_dist

  def delta_posterior(self, z, targets_mask, decoder_self_attention_bias,
                      n_gibbs_steps, **kwargs):
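    """Refine z with n_gibbs_steps rounds of decode-and-reencode.

    Each round decodes z to an argmax target embedding, then replaces z with
    the mean of the posterior computed from that embedding.
    """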
    hparams = self._hparams
    for _ in range(n_gibbs_steps):
      _, targets_emb = self.argmax_decode(z, decoder_self_attention_bias,
                                          **kwargs)
      q_params = ops.posterior("posterior", hparams, targets_emb,
                               targets_mask, decoder_self_attention_bias,
                               **kwargs)
      q_dist = gops.diagonal_normal(q_params, "posterior")
      z = q_dist.loc  # [B, L, C]
    return z

  def compute_prior_log_prob(
      self, z_q, targets_mask, decoder_self_attention_bias,
      check_invertibility=False, **kwargs):
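    """Evaluate the prior log-density of z_q.

    Returns:
      log_p_z_base: log-density of z_q (or its flow image) under the base
        distribution, shape [B].
      log_abs_det: log |det| of the flow Jacobian, shape [B]; zeros for
        non-flow priors.
    """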
    hparams = self._hparams
    batch_size, targets_max_length = (
        common_layers.shape_list(targets_mask)[:2])
    prior_shape = [batch_size, targets_max_length, hparams.latent_size]
    log_abs_det = tf.zeros([batch_size])

    if hparams.prior_type == "standard_normal":
      log_p_z_base = gops.standard_normal_density(z_q, targets_mask)
    elif hparams.prior_type == "diagonal_normal":
      diag_prior_params = ops.cond_prior(
          "diag_prior", hparams, tf.zeros(prior_shape), targets_mask,
          hparams.latent_size * 2, decoder_self_attention_bias, **kwargs)
      p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
      log_p_z_base = p_dist.log_prob(z_q)  # [B, L, C]
      log_p_z_base = gops.reduce_sum_over_lc(log_p_z_base, targets_mask)  # [B]
    elif hparams.prior_type in ["affine", "additive", "rq"]:
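      # `init` is True only on the single training step where the global step
      # equals kl_startup_steps; flow dropout is active only during training.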
      if self.is_evaluating:
        disable_dropout = True
        init = False
      elif self.is_training:
        disable_dropout = False
        init = tf.equal(hparams.kl_startup_steps,
                        tf.cast(tf.train.get_global_step(), tf.int32))
      else:
        raise ValueError(
            "compute_prior_log_prob should not be called in decode mode.")

      z_inv, log_abs_det, log_p_z_base, zs = glow.glow(
          "glow", z_q, targets_mask, decoder_self_attention_bias,
          inverse=False, init=init, hparams=self._fparams,
          disable_dropout=disable_dropout, **kwargs)
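      # Round-trip check: inverting the flow on z_inv should reproduce z_q;
      # log the worst-case reconstruction error.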
      if self.is_evaluating and check_invertibility:
        z_inv_inv, _, _, _ = glow.glow(
            "glow", z_inv, targets_mask, decoder_self_attention_bias,
            inverse=True, split_zs=zs, init=False, hparams=self._fparams,
            disable_dropout=True, **kwargs)
        z_diff = z_q - z_inv_inv
        tf.summary.scalar("flow_recon_forward", tf.reduce_max(tf.abs(z_diff)))
    else:
      raise ValueError("Unknown prior_type: %s" % hparams.prior_type)
    return log_p_z_base, log_abs_det

  def sample_q(
      self, targets, targets_mask, decoder_self_attention_bias, n_samples,
      temp, **kwargs):
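    """Draw n_samples per example from the posterior, with `temp` as stddev.

    Returns:
      z_q: posterior samples, shape [K*B, L, latent_size] with K = n_samples
        (the leading sample axis is folded into the batch for K > 1).
      log_q_z: per-dimension log-density of z_q, same shape as z_q.
      q_dist: the diagonal-normal posterior distribution.
    """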
    hparams = self._hparams
    batch_size, targets_max_length = common_layers.shape_list(targets_mask)[:2]
    q_params = ops.posterior("posterior", hparams, targets, targets_mask,
                             decoder_self_attention_bias, **kwargs)
    q_dist = gops.diagonal_normal(q_params, "posterior")
    loc, scale = q_dist.loc, q_dist.scale
    z_shape = [batch_size, targets_max_length, hparams.latent_size]
    iw_z_shape = [n_samples * batch_size, targets_max_length,
                  hparams.latent_size]
    if n_samples == 1:
      noise = tf.random.normal(z_shape, stddev=temp)
      z_q = loc + scale * noise
      log_q_z = q_dist.log_prob(z_q)  # [B, L, C]
    else:
      # Multiple (e.g. importance-weighted) samples: broadcast loc/scale over
      # a leading sample axis, then fold that axis into the batch dimension.
      noise = tf.random.normal([n_samples] + z_shape, stddev=temp)
      z_q = loc[tf.newaxis, ...] + scale[tf.newaxis, ...] * noise
      log_q_z = q_dist.log_prob(z_q)  # [K, B, L, C]
      z_q = tf.reshape(z_q, iw_z_shape)
      log_q_z = tf.reshape(log_q_z, iw_z_shape)
    return z_q, log_q_z, q_dist