コード例 #1
0
ファイル: util_test.py プロジェクト: yuanying-cc/vae-seq
 def test_calc_kl_analytical(self):
     hparams = hparams_mod.make_hparams(use_monte_carlo_kl=False)
     dist_a = tf.distributions.Bernoulli(probs=0.5)
     dist_b = tf.distributions.Bernoulli(probs=0.3)
     kl_div = util.calc_kl(hparams, dist_a.sample(), dist_a, dist_b)
     with self.test_session():
         self.assertAllClose(kl_div.eval(),
                             0.5 * (np.log(0.5 / 0.3) + np.log(0.5 / 0.7)))
コード例 #2
0
ファイル: vae_test.py プロジェクト: yuanying-cc/vae-seq
 def _test_vae(self, vae_type):
     """Make sure that all tensors and assertions evaluate without error."""
     hparams = hparams_mod.make_hparams(vae_type=vae_type)
     inputs, vae = _inputs_and_vae(hparams)
     tensors = _all_tensors(hparams, inputs, vae)
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         sess.run(tensors)
コード例 #3
0
ファイル: util_test.py プロジェクト: yuanying-cc/vae-seq
 def test_calc_kl_mc(self):
     tf.set_random_seed(0)
     hparams = hparams_mod.make_hparams(use_monte_carlo_kl=True)
     samples = 1000
     dist_a = tf.distributions.Bernoulli(probs=tf.fill([samples], 0.5))
     dist_b = tf.distributions.Bernoulli(probs=tf.fill([samples], 0.3))
     kl_div = tf.reduce_mean(util.calc_kl(hparams, dist_a.sample(), dist_a,
                                          dist_b),
                             axis=0)
     with self.test_session():
         self.assertAllClose(kl_div.eval(),
                             0.5 * (np.log(0.5 / 0.3) + np.log(0.5 / 0.7)),
                             atol=0.05)
コード例 #4
0
ファイル: model_test.py プロジェクト: yuanying-cc/vae-seq
 def _setup_model(self, session_params):
     self.train_dataset = "train"
     self.valid_dataset = "valid"
     self.hparams = hparams_mod.make_hparams()
     self.model = MockModel(self.hparams, session_params)
コード例 #5
0
def make_hparams(flag_value=None, **kwargs):
    """Initialize HParams with the defaults in this module."""
    init = dict(_DEFAULTS)
    init.update(kwargs)
    ret = hparams_mod.make_hparams(flag_value=flag_value, **init)
    return ret