def test_get_score(self):
    """Compare ``Donut.get_score`` with a hand-built log-likelihood.

    Each case evaluates ``donut.get_score(...)`` and ``r_prob(...)`` in a
    single ``sess.run`` call and asserts the results are close. ``r_prob``
    reuses the ``z`` tensor recorded by ``Capture`` from the most recent call
    to the (patched) ``donut.vae.variational``. ``donut.vae.reconstruct`` is
    replaced by a ``Mock`` so each case can also assert how many times it
    was called.
    """
    class Capture(object):
        # Replaces ``vae.variational`` with ``self._variational`` so that
        # every call stores the returned variational net in ``self.q_net``.
        def __init__(self, vae):
            self._vae = vae
            # keep the original bound method so the wrapper can delegate
            self._vae_variational = vae.variational
            vae.variational = self._variational
            self.q_net = None

        def _variational(self, x, z=None, n_z=None):
            # delegate to the original method and remember its result
            self.q_net = self._vae_variational(x, z=z, n_z=n_z)
            return self.q_net

    tf.set_random_seed(1234)
    donut = Donut(
        h_for_p_x=lambda x: x,
        h_for_q_z=lambda x: x,
        x_dims=5,
        z_dims=3,
    )
    capture = Capture(donut.vae)
    # Deterministic stand-in for the reconstruction; wrapping it in a Mock
    # lets each case assert ``call_count``.
    donut.vae.reconstruct = Mock(
        wraps=lambda x: x + tf.reduce_sum(x)  # only called by MCMC
    )

    # 4 x 5 input; ``y`` flags exactly one point in each of the first 4 rows
    x = tf.reshape(tf.range(20, dtype=tf.float32), [4, 5])
    y = tf.constant([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0],
                     [0, 0, 0, 1, 0]], dtype=tf.int32)
    _ = donut.get_score(x, y)  # ensure variables created

    def r_prob(x, z, n_z=None, x_in=None):
        # Reference score: build the variational net on ``x_in`` (defaults to
        # ``x``) with the given ``z``, then the model net on ``x``, and
        # return the element-wise ``log_prob`` of ``x`` (averaged over axis 0
        # when ``n_z`` is given).
        # NOTE: ``donut.vae.variational`` is the Capture-patched method here,
        # so calling this also overwrites ``capture.q_net``.
        if x_in is None:
            x_in = x
        q_net = donut.vae.variational(x_in, z=z, n_z=n_z)
        p_net = donut.vae.model(z=q_net['z'], x=x, n_z=n_z)
        p = p_net['x'].log_prob(group_ndims=0)
        if n_z is not None:
            p = tf.reduce_mean(p, axis=0)
        return p

    with self.test_session() as sess:
        ensure_variables_initialized()

        # In each case below, the list items are built left to right.
        # ``donut.get_score(...)`` is constructed first (presumably updating
        # ``capture.q_net`` through the patched variational), and only then
        # is ``capture.q_net['z']`` read to build ``r_prob``. Do not reorder.

        # test y is None
        # (reconstruct must not be called even though mcmc_iteration=1)
        donut.vae.reconstruct.reset_mock()
        np.testing.assert_allclose(*sess.run([
            donut.get_score(
                x, y=None, mcmc_iteration=1, last_point_only=False),
            r_prob(x, z=capture.q_net['z'])
        ]))
        self.assertEqual(donut.vae.reconstruct.call_count, 0)

        # test mcmc_iteration is None
        # (reconstruct must not be called even though y is given)
        donut.vae.reconstruct.reset_mock()
        np.testing.assert_allclose(*sess.run([
            donut.get_score(
                x, y=y, mcmc_iteration=None, last_point_only=False),
            r_prob(x, z=capture.q_net['z'])
        ]))
        self.assertEqual(donut.vae.reconstruct.call_count, 0)

        # test mcmc once
        # ``x2`` is the expected input after one MCMC step: ``x`` where
        # y == 1, and the mock reconstruction ``x + sum(x)`` elsewhere.
        # NOTE(review): if get_score's MCMC step replaces the y == 1 (masked)
        # points with the reconstruction, the second and third ``tf.where``
        # operands look swapped. Also, because ``r_prob`` passes ``z``
        # explicitly, ``x_in`` may not affect the compared value at all, in
        # which case this case would not detect the swap -- confirm.
        x2 = tf.where(
            tf.cast(y, dtype=tf.bool),
            x,
            x + tf.reduce_sum(x),
        )
        donut.vae.reconstruct.reset_mock()
        np.testing.assert_allclose(*sess.run([
            donut.get_score(
                x, y=y, mcmc_iteration=1, last_point_only=False),
            r_prob(x, z=capture.q_net['z'], x_in=x2)
        ]))
        self.assertEqual(donut.vae.reconstruct.call_count, 1)

        # test mcmc with n_z > 1
        donut.vae.reconstruct.reset_mock()
        np.testing.assert_allclose(*sess.run([
            donut.get_score(
                x, y=y, n_z=7, mcmc_iteration=1, last_point_only=False),
            r_prob(x, z=capture.q_net['z'], x_in=x2, n_z=7)
        ]))
        # At this point ``capture.q_net`` was last set inside ``r_prob``, so
        # this checks the ``z`` of that net: n_z=7, batch of 4, z_dims=3.
        self.assertEqual(capture.q_net['z'].get_shape(),
                         tf.TensorShape([7, 4, 3]))
        self.assertEqual(donut.vae.reconstruct.call_count, 1)
def test_prediction(self):
    """Check ``DonutPredictor.get_score`` output shapes and input validation.

    With ``x_dims=5`` the expected number of scores is
    ``len(values) - x_dims + 1`` (5 -> 1, 8 -> 4, 10 -> 6). With
    ``last_point_only=False``, each of those rows carries ``x_dims`` scores
    instead of one. The predictor is also exercised with an extra
    ``feed_dict``, which must not change the output shape.
    """
    np.random.seed(1234)
    tf.set_random_seed(1234)

    # test last_point_only == True
    donut = Donut(h_for_p_x=lambda x: x, h_for_q_z=lambda x: x,
                  x_dims=5, z_dims=3)
    # Build the scoring graph once so the model variables exist before
    # ensure_variables_initialized() runs.
    _ = donut.get_score(tf.zeros([4, 5], dtype=tf.float32),
                        tf.zeros([4, 5], dtype=tf.int32))
    pred = DonutPredictor(donut, n_z=2, batch_size=4)
    self.assertIs(pred.model, donut)

    with self.test_session():
        ensure_variables_initialized()

        # test without missing
        for length, expected_shape in ((5, (1,)), (8, (4,)), (10, (6,))):
            res = pred.get_score(values=np.arange(length, dtype=np.float32))
            self.assertEqual(res.shape, expected_shape)

        # test with missing (random 0/1 indicators)
        res = pred.get_score(values=np.arange(10, dtype=np.float32),
                             missing=np.random.binomial(1, .5, size=10))
        self.assertEqual(res.shape, (6,))

    # test errors
    with self.test_session():
        with pytest.raises(ValueError,
                           match='`values` must be a 1-D array'):
            _ = pred.get_score(
                np.arange(10, dtype=np.float32).reshape([-1, 1]))
        with pytest.raises(ValueError,
                           match='The shape of `missing` does not agree '
                                 'with the shape of `values`'):
            _ = pred.get_score(np.arange(10, dtype=np.float32),
                               np.arange(9, dtype=np.int32))

    # test last_point_only == False
    pred = DonutPredictor(donut, n_z=2, batch_size=4, last_point_only=False)
    with self.test_session():
        ensure_variables_initialized()

        # test without missing
        res = pred.get_score(values=np.arange(10, dtype=np.float32))
        self.assertEqual(res.shape, (6, 5))

        # test with missing
        res = pred.get_score(values=np.arange(10, dtype=np.float32),
                             missing=np.random.binomial(1, .5, size=10))
        self.assertEqual(res.shape, (6, 5))

    # test set feed_dict
    is_training = tf.placeholder(tf.bool, shape=())
    pred = DonutPredictor(donut, n_z=2, batch_size=4,
                          feed_dict={is_training: False})
    with self.test_session():
        ensure_variables_initialized()
        # The extra feed_dict entry must not change the result's shape;
        # this matches the identical last_point_only=True call above.
        res = pred.get_score(values=np.arange(10, dtype=np.float32))
        self.assertEqual(res.shape, (6,))