def test_train_q_rev_kl(self):
    """Smoke-test train_ebm.train_q_rev_kl.

    Calls train_q_rev_kl once with a quadratic energy function and an Adam
    optimizer, then checks that both returned values (loss and entropy)
    contain only finite numbers.
    """
    model = train_ebm.MeanFieldGaussianQ()

    def energy_fn(inputs):
        # Quadratic energy: sum of squares over axes 1, 2 and 3.
        return tf.reduce_sum(tf.square(inputs), axis=[1, 2, 3])

    optimizer = tf.optimizers.Adam()
    loss, entropy = train_ebm.train_q_rev_kl(model, energy_fn, optimizer)
    for result in (loss, entropy):
        self.assertTrue(np.all(np.isfinite(result)))
def test_train_q_mle(self):
    """Check that train_q_mle runs and returns a finite loss.

    Fits a MeanFieldGaussianQ to a seeded batch of standard-normal samples
    with an Adam optimizer and asserts that the returned loss contains no
    NaN/inf values.
    """
    num_samples = 16
    q = train_ebm.MeanFieldGaussianQ()
    opt = tf.optimizers.Adam()
    # Seeded so the input batch is reproducible. (A previous seed=12 draw
    # was immediately overwritten and has been removed as dead code.)
    x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)
    loss = train_ebm.train_q_mle(q, x, opt)
    self.assertTrue(np.all(np.isfinite(loss)), msg='loss is not finite')
def test_train_q_rev_kl_mle(self):
    """Check that train_q_rev_kl_mle runs and returns only finite values.

    Uses the quadratic energy u(x) = sum of squares over axes 1-3, a seeded
    batch of standard-normal samples and an Adam optimizer. Asserts that each
    of the six returned quantities contains no NaN/inf values, and names the
    offending quantity on failure.
    """
    num_samples = 16
    q = train_ebm.MeanFieldGaussianQ()

    def u(inputs):
        # Quadratic energy: sum of squares over axes 1, 2 and 3.
        return tf.reduce_sum(tf.square(inputs), axis=[1, 2, 3])

    opt = tf.optimizers.Adam()
    # A single seeded draw; the original drew this identical tensor twice.
    x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=12)
    # NOTE(review): the 1. argument is presumably a weight on the MLE term;
    # confirm against train_ebm.train_q_rev_kl_mle.
    (loss, entropy, neg_e_q, mle_loss, grads_ebm_norm,
     grads_mle_norm) = train_ebm.train_q_rev_kl_mle(q, u, x, 1., opt)
    outputs = (
        ('loss', loss),
        ('entropy', entropy),
        ('neg_e_q', neg_e_q),
        ('mle_loss', mle_loss),
        ('grads_ebm_norm', grads_ebm_norm),
        ('grads_mle_norm', grads_mle_norm),
    )
    for name, value in outputs:
        self.assertTrue(np.all(np.isfinite(value)),
                        msg='%s is not finite' % name)
def test_train_p(self):
    """Check that train_p runs and returns only finite values.

    Builds a MeanFieldGaussianQ, an EbmConv energy model (anchor_size=1) and
    an Adam optimizer, and runs train_p on a seeded batch of standard-normal
    samples. Asserts that each of the nine returned quantities contains no
    NaN/inf values, and names the offending quantity on failure.
    """
    num_samples = 16
    q = train_ebm.MeanFieldGaussianQ()
    u = train_ebm.EbmConv(anchor_size=1)
    opt = tf.optimizers.Adam()
    # Seeded so the input batch is reproducible. (A previous seed=12 draw
    # was immediately overwritten and has been removed as dead code.)
    x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)
    # NOTE(review): the meaning of the 0.1 argument is defined by
    # train_ebm.train_p (possibly an initial step size); confirm there.
    (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated, neg_e_q,
     neg_e_p, neg_e_p_updated) = train_ebm.train_p(q, u, x, 0.1, opt)
    outputs = (
        ('x_neg_q', x_neg_q),
        ('x_neg_p', x_neg_p),
        ('p_accept', p_accept),
        ('step_size', step_size),
        ('pos_e', pos_e),
        ('pos_e_updated', pos_e_updated),
        ('neg_e_q', neg_e_q),
        ('neg_e_p', neg_e_p),
        ('neg_e_p_updated', neg_e_p_updated),
    )
    for name, value in outputs:
        self.assertTrue(np.all(np.isfinite(value)),
                        msg='%s is not finite' % name)