def test_u_valid(self):
    """Check that EbmConv builds, runs, and ranks a shifted batch as higher energy.

    A seeded random batch is scored by a fresh ``EbmConv(anchor_size=1)``.
    The same batch shifted by +1 in every element is then scored. Every
    element of the first result must be strictly smaller than the
    corresponding element of the second.
    """
    batch_shape = [16, N_WH, N_WH, N_CH]
    inputs = tf.random.normal(batch_shape, seed=12)
    energy_model = train_ebm.EbmConv(anchor_size=1)

    near_energy = energy_model(inputs)
    # Expected to be higher because of EbmConv's quadratic prior (per the
    # original test's intent).
    shifted_inputs = inputs + tf.ones_like(inputs)
    far_energy = energy_model(shifted_inputs)

    energy_increased = np.all(near_energy < far_energy)
    self.assertTrue(energy_increased)
def test_train_p(self):
    """Verify that train_p runs without raising and returns only finite values.

    Builds a ``MeanFieldGaussianQ`` and an ``EbmConv(anchor_size=1)`` and
    makes a single ``train_p`` call on a seeded random batch with an Adam
    optimizer. The test requires exactly nine return values and checks that
    each one is finite. Each output is checked in its own subTest, so a
    failure names the offending output.
    """
    num_samples = 16
    q = train_ebm.MeanFieldGaussianQ()
    u = train_ebm.EbmConv(anchor_size=1)
    opt = tf.optimizers.Adam()
    x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)

    # NOTE(review): 0.1 is passed positionally; presumably a step size or
    # learning-rate-like hyperparameter of train_p. Confirm against train_ebm.
    outputs = train_ebm.train_p(q, u, x, 0.1, opt)

    # Names of train_p's return values, in return order.
    output_names = (
        "x_neg_q",
        "x_neg_p",
        "p_accept",
        "step_size",
        "pos_e",
        "pos_e_updated",
        "neg_e_q",
        "neg_e_p",
        "neg_e_p_updated",
    )
    # Preserve the original's implicit requirement of exactly nine outputs.
    self.assertEqual(len(outputs), len(output_names))

    for name, value in zip(output_names, outputs):
        with self.subTest(output=name):
            self.assertTrue(
                np.all(np.isfinite(value)),
                msg=f"train_p output {name!r} contains non-finite values",
            )