def test_u_valid(self):
     """Builds EbmConv on a random batch and checks an energy ordering.

     The model is evaluated on a seeded random batch and on the same batch
     shifted by +1 in every element; every shifted sample must receive a
     strictly higher energy than its unshifted counterpart.
     """
     batch_size = 16
     inputs = tf.random.normal([batch_size, N_WH, N_WH, N_CH], seed=12)
     model = train_ebm.EbmConv(anchor_size=1)
     base_energy = model(inputs)
     # The shifted batch is expected to score higher because of the model's
     # quadratic prior term.
     shifted_inputs = inputs + tf.ones_like(inputs)
     shifted_energy = model(shifted_inputs)
     is_higher = base_energy < shifted_energy
     self.assertTrue(np.all(is_higher))
 def test_train_p(self):
     """Runs one train_p step and checks that every returned value is finite.

     Any exception raised by train_p fails the test. Each of the nine returned
     values is then checked separately (inside a subTest), so a failure names
     the specific output that contained NaN or Inf.
     """
     num_samples = 16
     q = train_ebm.MeanFieldGaussianQ()
     u = train_ebm.EbmConv(anchor_size=1)
     opt = tf.optimizers.Adam()
     x = tf.random.normal([num_samples, N_WH, N_WH, N_CH], seed=13)
     output_names = (
         'x_neg_q',
         'x_neg_p',
         'p_accept',
         'step_size',
         'pos_e',
         'pos_e_updated',
         'neg_e_q',
         'neg_e_p',
         'neg_e_p_updated',
     )
     outputs = tuple(train_ebm.train_p(q, u, x, 0.1, opt))
     # Keep the arity guarantee that tuple-unpacking the result used to give.
     self.assertEqual(len(outputs), len(output_names))
     for name, value in zip(output_names, outputs):
         with self.subTest(output=name):
             self.assertTrue(
                 np.all(np.isfinite(value)),
                 msg='%s contains non-finite values' % name)