Exemplo n.º 1
0
 def test_train_q_rev_kl(self):
     """Run one train_q_rev_kl step and check its outputs are finite.

     The energy is a simple quadratic: the per-example sum of squares over
     axes 1, 2 and 3 of the input. The test passes if train_q_rev_kl returns
     a (loss, entropy) pair without raising and both values are finite.
     """
     q_model = train_ebm.MeanFieldGaussianQ()

     def energy(samples):
         # Sum of squares over every non-batch axis (axes 1-3).
         return tf.reduce_sum(tf.square(samples), axis=[1, 2, 3])

     optimizer = tf.optimizers.Adam()
     loss, entropy = train_ebm.train_q_rev_kl(q_model, energy, optimizer)
     for value in (loss, entropy):
         self.assertTrue(np.all(np.isfinite(value)))
Exemplo n.º 2
0
 def test_train_q_mle(self):
     """Verify that train_q_mle runs and returns a finite loss.

     Fits a mean-field Gaussian q to one random batch of shape
     [num_samples, N_WH, N_WH, N_CH] with a single train_q_mle call. The test
     passes if the call does not raise and every element of the returned
     loss is finite.
     """
     num_samples = 16
     q = train_ebm.MeanFieldGaussianQ()
     opt = tf.optimizers.Adam()
     x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)
     loss = train_ebm.train_q_mle(q, x, opt)
     self.assertTrue(np.all(np.isfinite(loss)), msg='loss')
Exemplo n.º 3
0
 def test_train_q_rev_kl_mle(self):
     """Verify that train_q_rev_kl_mle runs and returns only finite values.

     Builds a mean-field Gaussian q, a quadratic energy u (the per-example sum
     of squares over axes 1-3), and one random batch of shape
     [num_samples, N_WH, N_WH, N_CH]. It then runs a single
     train_q_rev_kl_mle call and asserts that all six returned quantities are
     finite. Each assertion is labelled with the output's name so a failure
     points at the offending value.
     """
     num_samples = 16
     q = train_ebm.MeanFieldGaussianQ()

     def u(samples):
         # Quadratic energy: sum of squares over every non-batch axis.
         return tf.reduce_sum(tf.square(samples), axis=[1, 2, 3])

     opt = tf.optimizers.Adam()
     x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=12)
     # NOTE(review): the meaning of the scalar 1. is defined inside
     # train_q_rev_kl_mle (not visible here); presumably a loss weight.
     (loss, entropy, neg_e_q, mle_loss, grads_ebm_norm,
      grads_mle_norm) = train_ebm.train_q_rev_kl_mle(q, u, x, 1., opt)
     outputs = (
         ('loss', loss),
         ('entropy', entropy),
         ('neg_e_q', neg_e_q),
         ('mle_loss', mle_loss),
         ('grads_ebm_norm', grads_ebm_norm),
         ('grads_mle_norm', grads_mle_norm),
     )
     for name, value in outputs:
         self.assertTrue(np.all(np.isfinite(value)), msg=name)
Exemplo n.º 4
0
 def test_train_p(self):
     """Verify that train_p runs and returns only finite values.

     Builds a mean-field Gaussian q, a convolutional energy model
     (EbmConv with anchor_size=1), and one random batch of shape
     [num_samples, N_WH, N_WH, N_CH]. It then runs a single train_p call and
     asserts that all nine returned quantities are finite. Each assertion is
     labelled with the output's name so a failure points at the offending
     value.
     """
     num_samples = 16
     q = train_ebm.MeanFieldGaussianQ()
     u = train_ebm.EbmConv(anchor_size=1)
     opt = tf.optimizers.Adam()
     x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)
     # NOTE(review): the meaning of the scalar 0.1 is defined inside train_p
     # (not visible here); confirm against its signature.
     (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated, neg_e_q,
      neg_e_p, neg_e_p_updated) = train_ebm.train_p(q, u, x, 0.1, opt)
     outputs = (
         ('x_neg_q', x_neg_q),
         ('x_neg_p', x_neg_p),
         ('p_accept', p_accept),
         ('step_size', step_size),
         ('pos_e', pos_e),
         ('pos_e_updated', pos_e_updated),
         ('neg_e_q', neg_e_q),
         ('neg_e_p', neg_e_p),
         ('neg_e_p_updated', neg_e_p_updated),
     )
     for name, value in outputs:
         self.assertTrue(np.all(np.isfinite(value)), msg=name)