def sample_p(self, targets_length, temp, check_invertibility=False,
             targets_mask=None, **kwargs):
  """Samples z_p from the prior at temperature `temp`.

  Returns:
    z_p: a sample from the prior.
    p_dist: the diagonal-normal prior distribution, or None for the other
      prior types.
  """
  hparams = self._hparams
  if targets_mask is None:
    targets_mask = ops.sequence_mask(targets_length, hparams)
  decoder_self_attention_bias = (
      common_attention.attention_bias_ignore_padding(1.0 - targets_mask))

  batch_size, targets_max_length = (
      common_layers.shape_list(targets_mask)[:2])
  prior_shape = [batch_size, targets_max_length, hparams.latent_size]
  noise = tf.random.normal(prior_shape, stddev=temp)

  p_dist = None
  if hparams.prior_type == "standard_normal":
    z_p = noise
  elif hparams.prior_type == "diagonal_normal":
    diag_prior_params = ops.cond_prior(
        "diag_prior", hparams, tf.zeros(prior_shape), targets_mask,
        hparams.latent_size * 2, decoder_self_attention_bias, **kwargs)
    p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
    z_p = p_dist.loc + p_dist.scale * noise
  elif hparams.prior_type in ["affine", "additive", "rq"]:
    # Flow priors act on a sequence shortened by hparams.factor per extra
    # level, so the base noise is drawn in that downsampled shape.
    n_levels = len(hparams.depths.split("/"))
    divi = max(1, hparams.factor**(n_levels - 1))
    flow_prior_shape = [
        batch_size, targets_max_length // divi, hparams.latent_size]
    noise = tf.random.normal(flow_prior_shape, stddev=temp)
    z_p, _, _, _ = glow.glow(
        "glow", noise, targets_mask, decoder_self_attention_bias,
        inverse=True, init=False, hparams=self._fparams,
        disable_dropout=True, temp=temp, **kwargs)
    if self.is_evaluating and check_invertibility:
      # Map the sample back through the flow and report the reconstruction
      # error of the original noise.
      noise_inv, _, _, _ = glow.glow(
          "glow", z_p, targets_mask, decoder_self_attention_bias,
          inverse=False, init=False, hparams=self._fparams,
          disable_dropout=True, **kwargs)
      z_diff = noise - noise_inv
      tf.summary.scalar("flow_recon_inverse", tf.reduce_max(tf.abs(z_diff)))
  return z_p, p_dist
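# A standalone numpy sketch (not part of this class) of the reparameterized,
# temperature-scaled sampling used by the "diagonal_normal" branch above:
# drawing noise with stddev=temp and mapping it through loc + scale * noise
# is the same as sampling from N(loc, (temp * scale)^2). All names and shapes
# below are illustrative assumptions, not part of the model code.
import numpy as np


def sample_diagonal_normal(loc, scale, temp, rng):
  """Draws one temperature-scaled sample from a diagonal Gaussian."""
  noise = rng.normal(loc=0.0, scale=temp, size=loc.shape)  # noise ~ N(0, temp^2)
  return loc + scale * noise  # sample ~ N(loc, (temp * scale)^2)


_rng = np.random.default_rng(0)
_loc = np.zeros([2, 5, 8])         # [batch, length, latent_size]
_scale = np.full([2, 5, 8], 0.5)
print(sample_diagonal_normal(_loc, _scale, temp=0.7, rng=_rng).shape)  # (2, 5, 8)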
def compute_prior_log_prob(self, z_q, targets_mask,
                           decoder_self_attention_bias,
                           check_invertibility=False, **kwargs):
  """Computes log p(z_q) under the prior as a base density plus a log-det.

  Returns:
    log_p_z_base: [batch] log-density of (the flow image of) z_q under the
      base distribution.
    log_abs_det: [batch] log |det df/dz_q| of the flow; zeros for the
      non-flow priors.
  """
  hparams = self._hparams
  batch_size, targets_max_length = (
      common_layers.shape_list(targets_mask)[:2])
  prior_shape = [batch_size, targets_max_length, hparams.latent_size]
  log_abs_det = tf.zeros([batch_size])

  if hparams.prior_type == "standard_normal":
    log_p_z_base = gops.standard_normal_density(z_q, targets_mask)
  elif hparams.prior_type == "diagonal_normal":
    diag_prior_params = ops.cond_prior(
        "diag_prior", hparams, tf.zeros(prior_shape), targets_mask,
        hparams.latent_size * 2, decoder_self_attention_bias, **kwargs)
    p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
    log_p_z_base = p_dist.log_prob(z_q)  # [B, L, C]
    log_p_z_base = gops.reduce_sum_over_lc(log_p_z_base, targets_mask)  # [B]
  elif hparams.prior_type in ["affine", "additive", "rq"]:
    if self.is_evaluating:
      disable_dropout = True
      init = False
    elif self.is_training:
      disable_dropout = False
      # Data-dependent init is triggered on the step where the global step
      # equals hparams.kl_startup_steps.
      init = tf.equal(hparams.kl_startup_steps,
                      tf.cast(tf.train.get_global_step(), tf.int32))
    else:
      raise ValueError(
          "compute_prior_log_prob should not be used in decoding.")
    z_inv, log_abs_det, log_p_z_base, zs = glow.glow(
        "glow", z_q, targets_mask, decoder_self_attention_bias,
        inverse=False, init=init, hparams=self._fparams,
        disable_dropout=disable_dropout, **kwargs)
    if self.is_evaluating and check_invertibility:
      # Push the flow image back through the inverse pass and report the
      # reconstruction error of z_q.
      z_inv_inv, _, _, _ = glow.glow(
          "glow", z_inv, targets_mask, decoder_self_attention_bias,
          inverse=True, split_zs=zs, init=False, hparams=self._fparams,
          disable_dropout=True, **kwargs)
      z_diff = z_q - z_inv_inv
      tf.summary.scalar("flow_recon_forward", tf.reduce_max(tf.abs(z_diff)))
  return log_p_z_base, log_abs_det
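# A toy numpy sketch (illustrative only, not the model's flow) of the
# change-of-variables identity that compute_prior_log_prob returns in two
# pieces: log p(z_q) = log p_base(f(z_q)) + log|det df/dz_q|. The affine map
# f(z) = a * z + b, the constants a and b, and the helper
# standard_normal_logpdf are assumptions made for the example, with a
# standard-normal base distribution.
import numpy as np


def standard_normal_logpdf(x):
  return -0.5 * (x**2 + np.log(2.0 * np.pi))


z_q = np.array([0.3, -1.2, 2.0])
a, b = 0.5, 1.0                        # toy flow f(z) = a * z + b
z_base = a * z_q + b                   # forward pass into the base space
log_abs_det = z_q.size * np.log(np.abs(a))

# log p(z_q) assembled the way compute_prior_log_prob returns it.
log_p_flow = standard_normal_logpdf(z_base).sum() + log_abs_det

# The same density written directly: if z_base is standard normal, then
# z_q = (z_base - b) / a is distributed as N(-b/a, 1/a^2).
mu, sigma = -b / a, 1.0 / np.abs(a)
log_p_direct = np.sum(
    -0.5 * ((z_q - mu) / sigma)**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi))

print(np.allclose(log_p_flow, log_p_direct))  # True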
def test_aaa_glow_training(self, depths, split_plans, prior_type):
  """Initializes and saves the flow, then checks invertibility on restore."""
  # Build and initialize the flow, then save a checkpoint.
  with tf.Graph().as_default():
    _, x_mask, _ = self.get_data()
    x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                         mean=10.0, stddev=3.0, dtype=DTYPE)
    bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
    hparams = self.get_hparams()
    hparams.prior_type = prior_type
    hparams.depths = depths
    hparams.split_plans = split_plans
    n_levels = len(hparams.depths.split("/"))
    kwargs = self.get_kwargs(x_mask, hparams)
    _ = kwargs.pop("decoder_self_attention_bias")
    x_inv, _, _, _ = glow.glow(
        "glow", x, x_mask, bias, inverse=False, init=True,
        disable_dropout=True, **kwargs)
    curr_dir = tempfile.mkdtemp()
    model_path = os.path.join(curr_dir, "model")
    with tf.Session() as session:
      saver = tf.train.Saver()
      session.run(tf.global_variables_initializer())
      session.run(x_inv)
      saver.save(session, model_path)

  # Rebuild the graph, restore the checkpoint, and run the flow forward and
  # backward to check reconstruction and log-det consistency.
  with tf.Graph().as_default():
    _, x_mask, _ = self.get_data()
    x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                         mean=10.0, stddev=3.0, dtype=DTYPE)
    bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
    hparams = self.get_hparams()
    # Use the same architecture hparams as the saved checkpoint.
    hparams.prior_type = prior_type
    hparams.depths = depths
    hparams.split_plans = split_plans
    kwargs = self.get_kwargs(x_mask, hparams)
    _ = kwargs.pop("decoder_self_attention_bias")

    log_q_z = gops.standard_normal_density(x, x_mask)
    log_q_z = tf.reduce_sum(log_q_z) / tf.reduce_sum(x_mask)
    x_inv, logabsdets, log_ps, zs = glow.glow(
        "glow", x, x_mask, bias, inverse=False, init=False,
        disable_dropout=True, **kwargs)
    x_inv_inv, logabsdets_inv, log_ps_inv, _ = glow.glow(
        "glow", x_inv, x_mask, bias, inverse=True, split_zs=zs, init=False,
        disable_dropout=True, **kwargs)
    logabsdets = tf.reduce_sum(logabsdets, axis=0) / tf.reduce_sum(x_mask)
    logabsdets_inv = tf.reduce_sum(
        logabsdets_inv, axis=0) / tf.reduce_sum(x_mask)
    log_ps = tf.reduce_sum(log_ps, axis=0) / tf.reduce_sum(x_mask)
    log_ps_inv = tf.reduce_sum(log_ps_inv, axis=0) / tf.reduce_sum(x_mask)

    with tf.Session() as session:
      saver = tf.train.Saver()
      saver.restore(session, model_path)
      (x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
       logabsdets_inv, log_ps_inv) = session.run(
           [x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
            logabsdets_inv, log_ps_inv])
      diff = x - x_inv_inv
      log_ps_diff = log_ps - log_ps_inv
      logabsdets_sum = logabsdets + logabsdets_inv
      self.assertEqual(
          x_inv.shape,
          (BATCH_SIZE, TARGET_LENGTH // (2**(n_levels - 1)), N_CHANNELS))
      print(np.max(np.abs(diff)))
      print(np.max(np.abs(log_ps_diff)))
      print(np.max(np.abs(logabsdets_sum)))
      self.assertTrue(np.allclose(diff, 0.0, atol=1e-4),
                      msg=np.max(np.abs(diff)))
      self.assertTrue(np.allclose(log_ps_diff, 0.0, atol=1e-4),
                      msg=np.max(np.abs(log_ps_diff)))
      self.assertTrue(np.allclose(logabsdets_sum, 0.0, atol=1e-4),
                      msg=np.max(np.abs(logabsdets_sum)))
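# A tiny numpy sketch (illustrative only) of two of the properties the test
# above asserts for the real glow stack: reconstruction x == f_inv(f(x)) and
# cancellation of the forward and inverse log-determinants, mirroring the
# `diff` and `logabsdets_sum` checks. A single elementwise affine layer with
# made-up parameters a, b stands in for the flow.
import numpy as np

a, b = 1.7, -0.3                       # toy flow f(x) = a * x + b


def forward(x):
  return a * x + b, x.size * np.log(np.abs(a))       # (f(x), log|det df/dx|)


def inverse(y):
  return (y - b) / a, -y.size * np.log(np.abs(a))    # (f_inv(y), log|det df_inv/dy|)


x = np.random.default_rng(0).normal(size=(4, 8))
y, logdet_fwd = forward(x)
x_rec, logdet_inv = inverse(y)

print(np.allclose(x, x_rec))                     # True: invertibility
print(np.allclose(logdet_fwd + logdet_inv, 0))   # True: log-dets cancel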