  def sample_p(self,
               targets_length,
               temp,
               check_invertibility=False,
               targets_mask=None,
               **kwargs):
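    """Samples the latent z_p from the prior p(z) at temperature `temp`.

    For "standard_normal" the sample is raw Gaussian noise; for
    "diagonal_normal" the noise is shifted and scaled by a conditional
    diagonal Gaussian; for the flow priors ("affine", "additive", "rq") the
    noise is pushed through the inverse glow flow.

    Returns:
      z_p: latent sample of shape [batch_size, targets_max_length,
        latent_size].
      p_dist: the diagonal-normal prior distribution, or None when the prior
        is not "diagonal_normal".
    """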
    hparams = self._hparams
    if targets_mask is None:
      targets_mask = ops.sequence_mask(targets_length, hparams)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))
    batch_size, targets_max_length = (
        common_layers.shape_list(targets_mask)[:2])
    prior_shape = [batch_size, targets_max_length, hparams.latent_size]
    noise = tf.random_normal(prior_shape, stddev=temp)
    p_dist = None

    if hparams.prior_type == "standard_normal":
      z_p = noise
    elif hparams.prior_type == "diagonal_normal":
      diag_prior_params = ops.cond_prior(
          "diag_prior", hparams, tf.zeros(prior_shape), targets_mask,
          hparams.latent_size * 2, decoder_self_attention_bias, **kwargs)
      p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
      z_p = p_dist.loc + p_dist.scale * noise
    elif hparams.prior_type in ["affine", "additive", "rq"]:
      # The multi-scale flow operates on the squeezed sequence, so the base
      # noise is sampled at length targets_max_length // factor**(n_levels-1).
      n_levels = len(hparams.depths.split("/"))
      divi = max(1, hparams.factor**(n_levels - 1))
      flow_prior_shape = [
          batch_size, targets_max_length // divi, hparams.latent_size]
      noise = tf.random_normal(flow_prior_shape, stddev=temp)
      z_p, _, _, _ = glow.glow(
          "glow", noise, targets_mask, decoder_self_attention_bias,
          inverse=True, init=False, hparams=self._fparams,
          disable_dropout=True, temp=temp, **kwargs)
      if self.is_evaluating and check_invertibility:
        # Push the sample back through the forward flow and log the
        # reconstruction error of the base noise.
        noise_inv, _, _, _ = glow.glow(
            "glow", z_p, targets_mask, decoder_self_attention_bias,
            inverse=False, init=False, hparams=self._fparams,
            disable_dropout=True, **kwargs)
        z_diff = noise - noise_inv
        tf.summary.scalar("flow_recon_inverse",
                          tf.reduce_max(tf.abs(z_diff)))
    return z_p, p_dist
  def compute_prior_log_prob(
      self, z_q, targets_mask, decoder_self_attention_bias,
      check_invertibility=False, **kwargs):
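    """Evaluates the prior log-density of z_q.

    Returns:
      log_p_z_base: log-density of z_q (or, for the flow priors, of its
        pre-image under the flow) in the base distribution, shape
        [batch_size].
      log_abs_det: log |det| of the flow Jacobian, zeros for the non-flow
        priors, shape [batch_size].
    """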
    hparams = self._hparams
    batch_size, targets_max_length = (
        common_layers.shape_list(targets_mask)[:2])
    prior_shape = [batch_size, targets_max_length, hparams.latent_size]
    log_abs_det = tf.zeros([batch_size])

    if hparams.prior_type == "standard_normal":
      log_p_z_base = gops.standard_normal_density(z_q, targets_mask)
    elif hparams.prior_type == "diagonal_normal":
      diag_prior_params = ops.cond_prior(
          "diag_prior", hparams, tf.zeros(prior_shape), targets_mask,
          hparams.latent_size*2, decoder_self_attention_bias, **kwargs)
      p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
      log_p_z_base = p_dist.log_prob(z_q)  # [B, L, C]
      log_p_z_base = gops.reduce_sum_over_lc(log_p_z_base, targets_mask)  # [B]
    elif hparams.prior_type in ["affine", "additive", "rq"]:
      if self.is_evaluating:
        disable_dropout = True
        init = False
      elif self.is_training:
        disable_dropout = False
        init = tf.equal(hparams.kl_startup_steps,
                        tf.cast(tf.train.get_global_step(), tf.int32))
      else:
        raise ValueError(
            "compute_prior_log_prob should not be used in decoding.")

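      # Forward pass: map z_q to the base distribution and accumulate the
      # log-determinant of the flow Jacobian.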
      z_inv, log_abs_det, log_p_z_base, zs = glow.glow(
          "glow", z_q, targets_mask, decoder_self_attention_bias,
          inverse=False, init=init, hparams=self._fparams,
          disable_dropout=disable_dropout, **kwargs)
      if self.is_evaluating and check_invertibility:
        z_inv_inv, _, _, _ = glow.glow(
            "glow", z_inv, targets_mask, decoder_self_attention_bias,
            inverse=True, split_zs=zs, init=False, hparams=self._fparams,
            disable_dropout=True, **kwargs)
        z_diff = z_q - z_inv_inv
        tf.summary.scalar("flow_recon_forward", tf.reduce_max(tf.abs(z_diff)))
    return log_p_z_base, log_abs_det
  def test_aaa_glow_training(self, depths, split_plans, prior_type):
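    """Checks glow invertibility and log-density consistency across graphs.

    The first graph builds the flow with init=True, runs it once, and saves
    a checkpoint; the second graph restores it, runs the forward and inverse
    passes, and asserts that x is recovered, the log-densities of both
    passes agree, and the forward and inverse log-determinants cancel.
    """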
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      hparams.prior_type = prior_type
      hparams.depths = depths
      hparams.split_plans = split_plans
      n_levels = len(hparams.depths.split("/"))
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")

      x_inv, _, _, _ = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=True,
          disable_dropout=True, **kwargs)
      curr_dir = tempfile.mkdtemp()
      model_path = os.path.join(curr_dir, "model")

      with tf.Session() as session:
        saver = tf.train.Saver()
        session.run(tf.global_variables_initializer())
        session.run(x_inv)
        saver.save(session, model_path)

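    # Rebuild the graph with init=False, restore the saved variables, and
    # check the forward/inverse round trip.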
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      hparams.depths = depths
      hparams.split_plans = split_plans
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")
      log_q_z = gops.standard_normal_density(x, x_mask)
      log_q_z = tf.reduce_sum(log_q_z) / tf.reduce_sum(x_mask)

      x_inv, logabsdets, log_ps, zs = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=False,
          disable_dropout=True, **kwargs)
      x_inv_inv, logabsdets_inv, log_ps_inv, _ = glow.glow(
          "glow", x_inv, x_mask, bias, inverse=True, split_zs=zs, init=False,
          disable_dropout=True, **kwargs)
      logabsdets = tf.reduce_sum(
          logabsdets, axis=0) / tf.reduce_sum(x_mask)
      logabsdets_inv = tf.reduce_sum(
          logabsdets_inv, axis=0) / tf.reduce_sum(x_mask)
      log_ps = tf.reduce_sum(log_ps, axis=0) / tf.reduce_sum(x_mask)
      log_ps_inv = tf.reduce_sum(log_ps_inv, axis=0) / tf.reduce_sum(x_mask)

      with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(session, model_path)
        (x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
         logabsdets_inv, log_ps_inv) = session.run([
             x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
             logabsdets_inv, log_ps_inv])
        diff = x - x_inv_inv
        log_ps_diff = log_ps - log_ps_inv
        logabsdets_sum = logabsdets + logabsdets_inv
        self.assertEqual(
            x_inv.shape,
            (BATCH_SIZE, TARGET_LENGTH//(2**(n_levels-1)), N_CHANNELS))
        print(np.max(np.abs(diff)))
        print(np.max(np.abs(log_ps_diff)))
        print(np.max(np.abs(logabsdets_sum)))
        self.assertTrue(np.allclose(diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(diff)))
        self.assertTrue(np.allclose(log_ps_diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(log_ps_diff)))
        self.assertTrue(np.allclose(logabsdets_sum, 0.0, atol=1e-4),
                        msg=np.max(np.abs(logabsdets_sum)))