def test_sgvb(self):
    """Check `sgvb` shapes and values for several `latent_axis` settings."""
    with self.get_session():
        # Three observed Normal stochastic tensors, each observed as [1, 3].
        x = StochasticTensor(Normal(0., 1.),
                             observed=np.asarray([[0., 1., 2.]]))
        y = StochasticTensor(Normal(1., 2.),
                             observed=np.asarray([[1., 2., 3.]]))
        z = StochasticTensor(Normal(2., 3.),
                             observed=np.asarray([[2., 3., 4.]]))

        # Without `latent_axis` the bound keeps the [1, 3] shape and equals
        # the first list's log-probs minus the second list's.
        bound = sgvb([x, y], [z])
        self.assertEqual(bound.get_shape().as_list(), [1, 3])
        np.testing.assert_almost_equal(
            bound.eval(),
            (x.log_prob() + y.log_prob() - z.log_prob()).eval())

        # Reducing over axis 0 (which has size 1) just drops it: shape [3].
        bound = sgvb([x], [y, z], latent_axis=0)
        self.assertEqual(bound.get_shape().as_list(), [3])
        diff = (x.log_prob() - y.log_prob() - z.log_prob()).eval()
        np.testing.assert_almost_equal(bound.eval(), diff.reshape([3]))

        # Reducing over both axes yields a scalar equal to the mean.
        bound = sgvb([x], [y, z], latent_axis=[0, 1])
        self.assertEqual(bound.get_shape().as_list(), [])
        diff = (x.log_prob() - y.log_prob() - z.log_prob()).eval()
        np.testing.assert_almost_equal(bound.eval(), np.mean(diff))
def test_construction_error(self):
    """Invalid constructor arguments must raise descriptive errors."""
    # Neither `samples` nor `observed` supplied.
    missing_msg = ('One and only one of `samples`, `observed` '
                   'should be specified.')
    with self.assertRaisesRegex(ValueError, missing_msg):
        _ = StochasticTensor(self.distrib)

    # A non-Distribution object given as the distribution.
    type_msg = ('`distribution` is expected to be a Distribution '
                'but got .*')
    with self.assertRaisesRegex(TypeError, type_msg):
        _ = StochasticTensor('', tf.constant(1.))
def test_disallowed_op(self):
    """Iteration and truth-testing of a StochasticTensor must be rejected."""
    bool_msg = ('Using a `StochasticTensor` as a Python `bool` '
                'is not allowed.')

    with self.assertRaisesRegex(
            TypeError, '`StochasticTensor` object is not iterable.'):
        _ = iter(StochasticTensor(self.distrib, tf.constant(1)))

    # Both an explicit `not` and an implicit `if` truth test must fail.
    with self.assertRaisesRegex(TypeError, bool_msg):
        _ = not StochasticTensor(self.distrib, tf.constant(1))
    with self.assertRaisesRegex(TypeError, bool_msg):
        if StochasticTensor(self.distrib, tf.constant(1)):
            pass
def test_observation_dtype_cast(self):
    """An int32 observation should yield a tensor with the float32 dtype."""
    int_observation = tf.placeholder(tf.int32)
    t = StochasticTensor(self.distrib, observed=int_observation)
    self.assertEqual(self.distrib.dtype, tf.float32)
    self.assertEqual(t.dtype, tf.float32)
def test_set_attributes(self):
    """Check where attribute assignments on a StochasticTensor end up.

    Some attributes stay on the wrapper, others are forwarded to the
    wrapped tensor (`t.__wrapped__`); this test pins down each case.
    """
    t = StochasticTensor(self.distrib, tf.constant([1., 2., 3.]))

    # `_self_is_observed` exists on the wrapper only, and assigning it
    # does not create it on the wrapped tensor.
    self.assertTrue(hasattr(t, '_self_is_observed'))
    self.assertFalse(hasattr(t.__wrapped__, '_self_is_observed'))
    t._self_is_observed = 123
    self.assertEqual(t._self_is_observed, 123)
    self.assertFalse(hasattr(t.__wrapped__, '_self_is_observed'))

    # `log_prob` is a wrapper member; overriding it on the wrapper must not
    # leak through to the wrapped tensor.
    self.assertTrue(hasattr(t, 'log_prob'))
    self.assertFalse(hasattr(t.__wrapped__, 'log_prob'))
    t.log_prob = 456
    self.assertEqual(t.log_prob, 456)
    self.assertTrue(hasattr(t, 'log_prob'))
    self.assertFalse(hasattr(t.__wrapped__, 'log_prob'))

    # `get_shape` exists on both; assigning it via the wrapper is forwarded
    # to the wrapped tensor.
    self.assertTrue(hasattr(t, 'get_shape'))
    self.assertTrue(hasattr(t.__wrapped__, 'get_shape'))
    t.get_shape = 789
    self.assertEqual(t.get_shape, 789)
    self.assertEqual(t.__wrapped__.get_shape, 789)
    self.assertTrue(hasattr(t, 'get_shape'))
    self.assertTrue(hasattr(t.__wrapped__, 'get_shape'))

    # A brand-new attribute set on the wrapper is forwarded as well.
    t.abc = 1001
    self.assertEqual(t.abc, 1001)
    self.assertEqual(t.__wrapped__.abc, 1001)
    self.assertTrue(hasattr(t, 'abc'))
    self.assertTrue(hasattr(t.__wrapped__, 'abc'))

    # An attribute set directly on the wrapped tensor is visible through
    # the wrapper.
    t.__wrapped__.xyz = 2002
    self.assertEqual(t.xyz, 2002)
    self.assertEqual(t.__wrapped__.xyz, 2002)
    self.assertTrue(hasattr(t, 'xyz'))
    self.assertTrue(hasattr(t.__wrapped__, 'xyz'))
def test_dynamic_dimension_replaced_by_observed_shape(self):
    """The tensor's static shape follows the observation's shape."""
    # The distribution is built from a (None, 3, 4) placeholder, while the
    # observation is a (2, 3, None) placeholder.
    params = tf.placeholder(tf.float32, (None, 3, 4))
    distrib = _MyDistribution(params)
    observation = tf.placeholder(tf.float32, (2, 3, None))
    t = StochasticTensor(distrib, observed=observation)
    self.assertEqual(t.get_shape().as_list(), [2, 3, None])
def test_attributes_from_distribution(self):
    """Selected attributes on the tensor must mirror the distribution's."""
    with self.get_session():
        distrib = Normal(0., 1.)
        t = StochasticTensor(distrib, tf.constant(0.))
        for attr in ('is_continuous', 'is_reparameterized'):
            self.assertEqual(
                getattr(distrib, attr),
                getattr(t, attr),
                msg='attribute %r mismatch.' % attr,
            )
def test_error_convert_to_tensor(self):
    """Converting a float StochasticTensor to int32 must fail."""
    pattern = ('Incompatible type conversion requested to '
               'type .* for tensor of type .*')
    with self.assertRaisesRegex(ValueError, pattern):
        stochastic = StochasticTensor(self.distrib, 1.)
        _ = tf.convert_to_tensor(stochastic, dtype=tf.int32)
def test_convert_to_tensor_if_dynamic(self):
    """Dynamic tensor-likes are converted; other values pass through."""
    dynamic_values = [
        tf.placeholder(tf.int32, ()),
        tf.get_variable('v', shape=(), dtype=tf.int32),
        StochasticTensor(Normal(0., 1.), 1.),
    ]
    for value in dynamic_values:
        self.assertIsInstance(convert_to_tensor_if_dynamic(value), tf.Tensor)

    # Non-dynamic values must be returned as the very same object.
    static_values = [1, 1.0, object(), (), [], {}, np.array([1, 2, 3])]
    for value in static_values:
        self.assertIs(convert_to_tensor_if_dynamic(value), value)
def test_observed_tensor(self):
    """A tensor built from `observed` exposes the observation's value."""
    with self.get_session():
        expected = 12345678
        observed = tf.constant(expected, dtype=tf.float32)
        t = StochasticTensor(self.distrib, observed=observed)

        self.assertIs(t.distribution, self.distrib)
        self.assertEqual(t.dtype, tf.float32)
        self.assertTrue(t.is_observed)
        self.assertEqual(t.eval(), expected)

        # The wrapped object is a plain tf.Tensor with the same value.
        wrapped = t.__wrapped__
        self.assertIsInstance(wrapped, tf.Tensor)
        self.assertEqual(wrapped.eval(), expected)
def test_prob_and_log_prob(self):
    """`t.log_prob()` / `t.prob()` must match the distribution on the observation."""
    with self.get_session():
        distrib = Normal(np.asarray(0., dtype=np.float32),
                         np.asarray([1.0, 2.0, 3.0], dtype=np.float32))
        observed = np.arange(24, dtype=np.float32).reshape([4, 2, 3])
        t = StochasticTensor(distrib, observed=observed)
        for method in ('log_prob', 'prob'):
            np.testing.assert_almost_equal(
                getattr(t, method)().eval(),
                getattr(distrib, method)(observed).eval(),
            )
def test_get_attributes(self):
    """Key members must appear in `dir(t)` and be reachable via `hasattr`."""
    t = StochasticTensor(self.distrib, tf.constant([1., 2., 3.]))
    members = dir(t)
    for member in ['dtype', 'log_prob', '__wrapped__']:
        # Report the missing member itself, not the whole `dir()` listing.
        self.assertIn(
            member, members,
            msg='%r should be in dir(t), but it is not.' % (member,)
        )
        self.assertTrue(
            hasattr(t, member),
            msg='StochasticTensor should have member %r, but it does not.'
                % (member,)
        )
def test_session_run(self):
    """A StochasticTensor can be fetched by, and fed into, `Session.run`."""
    with self.get_session() as sess:
        t = StochasticTensor(self.distrib, tf.constant([1., 2., 3.]))

        # Fetch the wrapper directly.
        fetched = sess.run(t)
        np.testing.assert_almost_equal(fetched, [1., 2., 3.])

        # Use the wrapper as a `feed_dict` key.
        replacement = np.asarray([4., 5., 6.])
        fed = sess.run(tf.identity(t), feed_dict={t: replacement})
        np.testing.assert_almost_equal(fed, np.asarray([4., 5., 6.]))
def test_prob_with_group_events_ndims(self):
    """`t.prob(group_event_ndims=k)` must agree with `distrib.prob(t, ...)`."""
    with self.get_session():
        distrib = Normal(np.asarray(0., dtype=np.float32),
                         np.asarray([1.0, 2.0, 3.0], dtype=np.float32),
                         group_event_ndims=1)
        observed = np.asarray([[-1., 1., 2.], [0., 0., 0.]])
        t = StochasticTensor(distrib, observed=observed)
        for ndims in (0, 1):
            np.testing.assert_allclose(
                t.prob(group_event_ndims=ndims).eval(),
                distrib.prob(t, group_event_ndims=ndims).eval(),
            )
def test_specialized_prob_method(self):
    """`t.prob()` must match `distrib.prob` when a specialized prob exists."""
    class MyNormal(Normal):
        @property
        def has_specialized_prob_method(self):
            return True

    with self.get_session():
        distrib = MyNormal(np.asarray(0., dtype=np.float32),
                           np.asarray([1.0, 2.0, 3.0], dtype=np.float32))
        observed = np.arange(24, dtype=np.float32).reshape([4, 2, 3])
        t = StochasticTensor(distrib, observed=observed)
        # NOTE(review): MyNormal's property already returns True; this
        # private assignment looks redundant — confirm whether it is needed.
        t.distribution._has_specialized_prob_method = True
        actual = t.prob()
        expected = distrib.prob(observed)
        np.testing.assert_almost_equal(actual.eval(), expected.eval())
def test_is_dynamic_tensor_like(self):
    """Only tensors, variables and StochasticTensors count as dynamic."""
    dynamic_candidates = [
        tf.placeholder(tf.int32, ()),
        tf.get_variable('v', shape=(), dtype=tf.int32),
        StochasticTensor(Normal(0., 1.), 1.),
    ]
    for candidate in dynamic_candidates:
        self.assertTrue(
            is_dynamic_tensor_like(candidate),
            msg='%r should be interpreted as a dynamic tensor.'
                % (candidate, ))

    static_candidates = [1, 1.0, object(), (), [], {}, np.array([1, 2, 3])]
    for candidate in static_candidates:
        self.assertFalse(
            is_dynamic_tensor_like(candidate),
            msg='%r should not be interpreted as a dynamic tensor.'
                % (candidate, ))
def sample(self, n_samples=None, group_ndims=0, is_reparameterized=None,
           compute_density=None, name=None):
    """Sample from the wrapped distribution and re-wrap the result.

    The wrapped distribution (`self._distribution`) is sampled with
    ``group_ndims + self._ndims`` grouped dimensions. The returned
    StochasticTensor has ``self`` as its distribution and records only the
    caller's ``group_ndims``.

    Args:
        n_samples: Forwarded unchanged to the inner ``sample`` call and to
            the returned StochasticTensor.
        group_ndims: Extra dimensions to group; coerced with ``int()``.
        is_reparameterized: Forwarded to the inner ``sample`` call.
        compute_density: Forwarded to the inner ``sample`` call.
        name: Forwarded to the inner ``sample`` call.

    Returns:
        StochasticTensor: Wraps the inner samples' tensor. It reuses the
        inner ``is_reparameterized``, ``_self_log_prob`` and ``_self_prob``.
    """
    from tfsnippet.bayes import StochasticTensor

    group_ndims = int(group_ndims)
    inner = self._distribution.sample(
        n_samples=n_samples,
        group_ndims=group_ndims + self._ndims,
        is_reparameterized=is_reparameterized,
        compute_density=compute_density,
        name=name,
    )
    # Re-wrap under `self`, carrying over the inner tensor's cached
    # log-density (presumably None when not computed — confirm).
    wrapped = StochasticTensor(
        distribution=self,
        tensor=inner.tensor,
        n_samples=n_samples,
        group_ndims=group_ndims,
        is_reparameterized=inner.is_reparameterized,
        log_prob=inner._self_log_prob,
    )
    wrapped._self_prob = inner._self_prob
    return wrapped
def test_del_attributes(self):
    """Check where attribute deletions on a StochasticTensor take effect."""
    t = StochasticTensor(self.distrib, tf.constant([1., 2., 3.]))

    # Deleting a wrapper-only attribute removes it from both views.
    del t._self_is_observed
    self.assertFalse(hasattr(t, '_self_is_observed'))
    self.assertFalse(hasattr(t.__wrapped__, '_self_is_observed'))

    # An attribute forwarded on assignment is also removed on deletion.
    t.abc = 1001
    del t.abc
    self.assertFalse(hasattr(t, 'abc'))
    self.assertFalse(hasattr(t.__wrapped__, 'abc'))

    # Deleting through the wrapper removes an attribute that was set
    # directly on the wrapped tensor.
    t.__wrapped__.xyz = 2002
    del t.xyz
    self.assertFalse(hasattr(t, 'xyz'))
    self.assertFalse(hasattr(t.__wrapped__, 'xyz'))

    # Deleting an overridden `log_prob` drops the override: the value 123
    # is gone (presumably falling back to the class-level member — confirm).
    t.log_prob = 123
    del t.log_prob
    self.assertFalse(hasattr(t.__wrapped__, 'log_prob'))
    self.assertNotEqual(t.log_prob, 123)
def test_non_observed_tensor(self):
    """A tensor built from `samples` is not flagged as observed."""
    stochastic = StochasticTensor(self.distrib, samples=1.)
    self.assertIs(stochastic.distribution, self.distrib)
    self.assertEqual(stochastic.dtype, tf.float32)
    self.assertFalse(stochastic.is_observed)
    # The Python float is wrapped as a tf.Tensor.
    self.assertIsInstance(stochastic.__wrapped__, tf.Tensor)
def test_initialize_from_tensor_wrapper(self):
    """A TensorWrapper given as `samples` is unwrapped to the raw tensor."""
    raw = tf.constant(1.)
    wrapper = TensorWrapper(raw)
    t = StochasticTensor(self.distrib, samples=wrapper)
    self.assertIs(t.__wrapped__, raw)
def test_equality(self):
    """Equality/hash are reflexive, but distinct wrappers are not equal."""
    tensor = tf.constant(0.)
    t = StochasticTensor(self.distrib, samples=tensor)
    self.assertEqual(t, t)
    self.assertEqual(hash(t), hash(t))

    # A second wrapper over the very same tensor object is still unequal.
    other = StochasticTensor(self.distrib, tensor)
    self.assertNotEqual(other, t)
def test_convert_to_tensor(self):
    """`tf.convert_to_tensor` yields a plain tf.Tensor, not the wrapper."""
    with self.get_session():
        stochastic = StochasticTensor(self.distrib, 1.)
        self.assertIsInstance(tf.convert_to_tensor(stochastic), tf.Tensor)
        self.assertNotIsInstance(
            tf.convert_to_tensor(stochastic), StochasticTensor)