def test_transform_on_transformed(self):
    with self.test_session() as sess:
        normal = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
        self.assertEqual(normal.value_ndims, 0)
        self.assertEqual(normal.get_batch_shape().as_list(), [3, 4, 5])
        self.assertEqual(list(sess.run(normal.batch_shape)), [3, 4, 5])

        # `batch_ndims_to_value(0)` is a no-op and returns the distribution itself
        distrib = normal.batch_ndims_to_value(0)
        self.assertIs(distrib, normal)

        # moving one batch dim into the value shape
        distrib = normal.batch_ndims_to_value(1)
        self.assertIsInstance(distrib, BatchToValueDistribution)
        self.assertEqual(distrib.value_ndims, 1)
        self.assertEqual(distrib.get_batch_shape().as_list(), [3, 4])
        self.assertEqual(list(sess.run(distrib.batch_shape)), [3, 4])
        self.assertIs(distrib.base_distribution, normal)

        # expanding again moves one more batch dim into the value shape
        distrib2 = distrib.expand_value_ndims(1)
        self.assertIsInstance(distrib2, BatchToValueDistribution)
        self.assertEqual(distrib2.value_ndims, 2)
        self.assertEqual(distrib2.get_batch_shape().as_list(), [3])
        self.assertEqual(list(sess.run(distrib2.batch_shape)), [3])
        self.assertIs(distrib.base_distribution, normal)

        # `expand_value_ndims(0)` is also a no-op
        distrib2 = distrib.expand_value_ndims(0)
        self.assertIs(distrib2, distrib)
        self.assertEqual(distrib2.value_ndims, 1)
        self.assertEqual(distrib.value_ndims, 1)
        self.assertEqual(distrib2.get_batch_shape().as_list(), [3, 4])
        self.assertEqual(list(sess.run(distrib2.batch_shape)), [3, 4])
        self.assertIs(distrib.base_distribution, normal)
def test_invert_flow(self):
    with self.test_session() as sess:
        # test inverting a normal flow
        flow = QuadraticFlow(2., 5.)
        inv_flow = flow.invert()

        self.assertIsInstance(inv_flow, InvertFlow)
        self.assertEqual(inv_flow.x_value_ndims, 0)
        self.assertEqual(inv_flow.y_value_ndims, 0)
        self.assertFalse(inv_flow.require_batch_dims)

        test_x = np.arange(12, dtype=np.float32) + 1.
        test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.)

        self.assertFalse(flow._has_built)
        y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x))
        self.assertTrue(flow._has_built)

        np.testing.assert_allclose(sess.run(y), test_y)
        np.testing.assert_allclose(sess.run(log_det_y), test_log_det)
        invertible_flow_standard_check(self, inv_flow, sess, test_y)

        # test inverting an InvertFlow
        inv_inv_flow = inv_flow.invert()
        self.assertIs(inv_inv_flow, flow)

        # test using an inverted flow with FlowDistribution
        normal = Normal(mean=1., std=2.)
        inv_flow = QuadraticFlow(2., 5.).invert()
        distrib = FlowDistribution(normal, inv_flow)
        distrib_log_det = distrib.log_prob(test_x)
        np.testing.assert_allclose(
            *sess.run([distrib_log_det, normal.log_prob(test_y) + test_log_det]))
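# A minimal sketch (assumed semantics, not tfsnippet's actual source) of the
# behaviour the test above relies on: `invert()` returns a wrapper that swaps
# `transform` and `inverse_transform`, so `inverse_transform` on the inverted
# flow reproduces the forward transform, and inverting twice simply returns
# the original flow object.
class _InvertFlowSketch(object):
    def __init__(self, flow):
        self._flow = flow

    def transform(self, x):
        # forward pass of the inverted flow == inverse pass of the wrapped flow
        return self._flow.inverse_transform(x)

    def inverse_transform(self, y):
        # inverse pass of the inverted flow == forward pass of the wrapped flow
        return self._flow.transform(y)

    def invert(self):
        # double inversion short-circuits to the wrapped flow itself
        return self._flow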
def test_ndims_exceed_limit(self):
    normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)
    with pytest.raises(ValueError,
                       match='`distribution.batch_shape.ndims` '
                             'is less then `ndims`'):
        _ = normal.expand_value_ndims(3)
def test_with_normal(self):
    mean = np.random.normal(size=[4, 5]).astype(np.float64)
    logstd = np.random.normal(size=mean.shape).astype(np.float64)
    x = np.random.normal(size=[3, 4, 5])

    with self.test_session() as sess:
        normal = Normal(mean=mean, logstd=logstd)
        distrib = normal.batch_ndims_to_value(1)
        self.assertIsInstance(distrib, BatchToValueDistribution)
        self.assertEqual(distrib.value_ndims, 1)
        self.assertEqual(distrib.get_batch_shape().as_list(), [4])
        self.assertEqual(list(sess.run(distrib.batch_shape)), [4])
        self.assertEqual(distrib.dtype, tf.float64)
        self.assertTrue(distrib.is_continuous)
        self.assertTrue(distrib.is_reparameterized)
        self.assertIs(distrib.base_distribution, normal)

        # log_prob: `group_ndims=k` on the converted distribution matches
        # `group_ndims=k + 1` on the base distribution
        log_prob = distrib.log_prob(x)
        log_prob2 = distrib.log_prob(x, group_ndims=1)
        self.assertEqual(get_static_shape(log_prob), (3, 4))
        self.assertEqual(get_static_shape(log_prob2), (3,))
        np.testing.assert_allclose(
            *sess.run([log_prob, normal.log_prob(x, group_ndims=1)]))
        np.testing.assert_allclose(
            *sess.run([log_prob2, normal.log_prob(x, group_ndims=2)]))

        # prob: same correspondence as log_prob
        prob = distrib.prob(x)
        prob2 = distrib.prob(x, group_ndims=1)
        self.assertEqual(get_static_shape(prob), (3, 4))
        self.assertEqual(get_static_shape(prob2), (3,))
        np.testing.assert_allclose(
            *sess.run([prob, normal.prob(x, group_ndims=1)]))
        np.testing.assert_allclose(
            *sess.run([prob2, normal.prob(x, group_ndims=2)]))

        # sample: the densities attached to samples follow the same rule
        sample = distrib.sample(3, compute_density=False)
        sample2 = distrib.sample(3, compute_density=True, group_ndims=1)
        log_prob = sample.log_prob()
        log_prob2 = sample2.log_prob()
        self.assertEqual(get_static_shape(log_prob), (3, 4))
        self.assertEqual(get_static_shape(log_prob2), (3,))
        np.testing.assert_allclose(
            *sess.run([log_prob, normal.log_prob(sample, group_ndims=1)]))
        np.testing.assert_allclose(
            *sess.run([log_prob2, normal.log_prob(sample2, group_ndims=2)]))
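# A hedged NumPy sketch (illustrative names only) of the identity checked
# above: converting one batch dim into a value dim makes `group_ndims=k`
# behave like `group_ndims=k + 1` on the base distribution, i.e. one extra
# trailing axis of the per-element log-probs is always summed out.
import numpy as np

elementwise = np.random.normal(size=[3, 4, 5])  # hypothetical per-element log-probs
converted = elementwise.sum(axis=-1)            # distrib.log_prob(x)                ~ base group_ndims=1
converted2 = elementwise.sum(axis=(-2, -1))     # distrib.log_prob(x, group_ndims=1) ~ base group_ndims=2
assert converted.shape == (3, 4)
assert converted2.shape == (3,)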
def test_value_ndims_0(self):
    self.do_check_mixture(
        lambda: Normal(
            mean=np.random.normal(size=[4, 5]).astype(np.float64),
            logstd=np.random.normal(size=[4, 5]).astype(np.float64)
        ),
        value_ndims=0,
        batch_shape=[4, 5],
        is_continuous=True,
        dtype=tf.float64,
        logits_dtype=np.float64,
        is_reparameterized=True
    )
def test_ndims_equals_zero_and_negative(self):
    normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)
    self.assertIs(normal.batch_ndims_to_value(0), normal)
    self.assertIs(normal.expand_value_ndims(0), normal)

    with pytest.raises(ValueError,
                       match='`ndims` must be non-negative integers'):
        _ = normal.batch_ndims_to_value(-1)
    with pytest.raises(ValueError,
                       match='`ndims` must be non-negative integers'):
        _ = normal.expand_value_ndims(-1)
class RecurrentDistribution(Distribution):
    """
    A multi-variable distribution integrated with recurrent structure.
    """

    def __init__(self, input_q, mean_q_mlp, std_q_mlp, z_dim,
                 window_length=100, is_reparameterized=True,
                 check_numerics=True):
        self.normal = Normal(mean=tf.zeros([window_length, z_dim]),
                             std=tf.ones([window_length, z_dim]))
        super(RecurrentDistribution, self).__init__(
            dtype=self.normal.dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            batch_shape=self.normal.batch_shape,
            batch_static_shape=self.normal.get_batch_shape(),
            value_ndims=self.normal.value_ndims,
        )
        self.std_q_mlp = std_q_mlp
        self.mean_q_mlp = mean_q_mlp
        self._check_numerics = check_numerics
        # store `input_q` in time-major layout: (window_length, batch_size, feature_dim)
        self.input_q = tf.transpose(input_q, [1, 0, 2])
        self._dtype = input_q.dtype
        self._is_reparameterized = is_reparameterized
        self._is_continuous = True
        self.z_dim = z_dim
        self.window_length = window_length
        self.time_first_shape = tf.convert_to_tensor(
            [self.window_length, tf.shape(input_q)[0], self.z_dim])

    @property
    def dtype(self):
        return self._dtype

    @property
    def is_continuous(self):
        return self._is_continuous

    @property
    def is_reparameterized(self):
        return self._is_reparameterized

    @property
    def value_shape(self):
        return self.normal.value_shape

    def get_value_shape(self):
        return self.normal.get_value_shape()

    @property
    def batch_shape(self):
        return self.normal.batch_shape

    def get_batch_shape(self):
        return self.normal.get_batch_shape()

    def sample_step(self, a, t):
        z_previous, mu_q_previous, std_q_previous = a
        noise_n, input_q_n = t
        input_q_n = tf.broadcast_to(input_q_n, [
            tf.shape(z_previous)[0],
            tf.shape(input_q_n)[0],
            input_q_n.shape[1]
        ])
        input_q = tf.concat([input_q_n, z_previous], axis=-1)
        mu_q = self.mean_q_mlp(input_q, reuse=tf.AUTO_REUSE)  # n_sample * batch_size * z_dim
        std_q = self.std_q_mlp(input_q)  # n_sample * batch_size * z_dim
        temp = tf.einsum('ik,ijk->ijk', noise_n, std_q)  # n_sample * batch_size * z_dim
        mu_q = tf.broadcast_to(mu_q, tf.shape(temp))
        std_q = tf.broadcast_to(std_q, tf.shape(temp))
        z_n = temp + mu_q
        return z_n, mu_q, std_q

    # @global_reuse
    def log_prob_step(self, _, t):
        given_n, input_q_n = t
        if len(given_n.shape) > 2:
            input_q_n = tf.broadcast_to(input_q_n, [
                tf.shape(given_n)[0],
                tf.shape(input_q_n)[0],
                input_q_n.shape[1]
            ])
        input_q = tf.concat([given_n, input_q_n], axis=-1)
        mu_q = self.mean_q_mlp(input_q, reuse=tf.AUTO_REUSE)
        std_q = self.std_q_mlp(input_q)
        logstd_q = tf.log(std_q)
        precision = tf.exp(-2 * logstd_q)
        if self._check_numerics:
            precision = tf.check_numerics(precision, 'precision')
        # -0.9189385332046727 == -0.5 * log(2 * pi), the Gaussian normalizer
        log_prob_n = (-0.9189385332046727 - logstd_q -
                      0.5 * precision *
                      tf.square(tf.minimum(tf.abs(given_n - mu_q), 1e8)))
        return log_prob_n

    def sample(self, n_samples=1024, is_reparameterized=None, group_ndims=0,
               compute_density=False, name=None):
        from tfsnippet.stochastic import StochasticTensor

        if n_samples is None:
            n_samples = 1
            n_samples_is_none = True
        else:
            n_samples_is_none = False
        with tf.name_scope(name=name, default_name='sample'):
            noise = self.normal.sample(n_samples=n_samples)
            noise = tf.transpose(noise, [1, 0, 2])  # window_length * n_samples * z_dim
            # replace with truncated-normal noise of the same shape; the Normal
            # sample above is only used to derive that shape
            noise = tf.truncated_normal(tf.shape(noise))
            time_indices_shape = tf.convert_to_tensor(
                [n_samples, tf.shape(self.input_q)[1], self.z_dim])

            samples = tf.scan(
                fn=self.sample_step,
                elems=(noise, self.input_q),
                initializer=(tf.zeros(time_indices_shape),
                             tf.zeros(time_indices_shape),
                             tf.ones(time_indices_shape)),
                back_prop=False
            )[0]  # time_step * n_samples * batch_size * z_dim

            # n_samples * batch_size * time_step * z_dim
            samples = tf.transpose(samples, [1, 2, 0, 3])

            if n_samples_is_none:
                t = StochasticTensor(
                    distribution=self,
                    tensor=tf.reduce_mean(samples, axis=0),
                    n_samples=1,
                    group_ndims=group_ndims,
                    is_reparameterized=self.is_reparameterized)
            else:
                t = StochasticTensor(
                    distribution=self,
                    tensor=samples,
                    n_samples=n_samples,
                    group_ndims=group_ndims,
                    is_reparameterized=self.is_reparameterized)

            if compute_density:
                with tf.name_scope('compute_prob_and_log_prob'):
                    log_p = t.log_prob()
                    t._self_prob = tf.exp(log_p)
            return t

    def log_prob(self, given, group_ndims=0, name=None):
        with tf.name_scope(name=name, default_name='log_prob'):
            # move the time axis first, so `tf.scan` iterates over time steps
            if len(given.shape) > 3:
                time_indices_shape = tf.convert_to_tensor(
                    [tf.shape(given)[0], tf.shape(self.input_q)[1], self.z_dim])
                given = tf.transpose(given, [2, 0, 1, 3])
            else:
                time_indices_shape = tf.convert_to_tensor(
                    [tf.shape(self.input_q)[1], self.z_dim])
                given = tf.transpose(given, [1, 0, 2])

            log_prob = tf.scan(fn=self.log_prob_step,
                               elems=(given, self.input_q),
                               initializer=tf.zeros(time_indices_shape),
                               back_prop=False)

            # move the time axis back to its original position
            if len(given.shape) > 3:
                log_prob = tf.transpose(log_prob, [1, 2, 0, 3])
            else:
                log_prob = tf.transpose(log_prob, [1, 0, 2])

            if group_ndims == 1:
                log_prob = tf.reduce_sum(log_prob, axis=-1)
            return log_prob

    def prob(self, given, group_ndims=0, name=None):
        with tf.name_scope(name=name, default_name='prob'):
            log_prob = self.log_prob(given, group_ndims, name)
            return tf.exp(log_prob)
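# A hedged usage sketch for `RecurrentDistribution` (the MLP callables and
# shapes below are hypothetical; the real model builds them from its own
# layers). Note that `mean_q_mlp` must accept a `reuse` keyword, because
# `sample_step` and `log_prob_step` call it with `reuse=tf.AUTO_REUSE`.
import tensorflow as tf

def _mean_mlp(x, reuse=None):
    with tf.variable_scope('mean_q_mlp', reuse=reuse):
        return tf.layers.dense(x, units=3)  # z_dim == 3

def _std_mlp(x):
    with tf.variable_scope('std_q_mlp', reuse=tf.AUTO_REUSE):
        # softplus keeps the standard deviation strictly positive
        return tf.nn.softplus(tf.layers.dense(x, units=3)) + 1e-6

input_q = tf.placeholder(tf.float32, [None, 100, 8])  # (batch, window_length, features)
q_dist = RecurrentDistribution(input_q, _mean_mlp, _std_mlp,
                               z_dim=3, window_length=100)
z = q_dist.sample(n_samples=16)                   # (16, batch, 100, 3) samples
log_q = q_dist.log_prob(z.tensor, group_ndims=1)  # sums out the z_dim axis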
def test_errors(self):
    with pytest.raises(TypeError,
                       match='`categorical` must be a Categorical '
                             'distribution'):
        _ = Mixture(Normal(0., 0.), [Normal(0., 0.)])

    with pytest.raises(ValueError,
                       match='Dynamic `categorical.n_categories` is not '
                             'supported'):
        _ = Mixture(Categorical(logits=tf.placeholder(tf.float32, [None])),
                    [Normal(0., 0.)])

    with pytest.raises(ValueError, match='`components` must not be empty'):
        _ = Mixture(Categorical(logits=tf.zeros([5])), [])

    with pytest.raises(ValueError,
                       match=r'`len\(components\)` != `categorical.'
                             r'n_categories`: 1 vs 5'):
        _ = Mixture(Categorical(logits=tf.zeros([5])), [Normal(0., 0.)])

    with pytest.raises(ValueError,
                       match='`dtype` of the 1-th component does not '
                             'agree with the first component'):
        _ = Mixture(Categorical(logits=tf.zeros([2])),
                    [Categorical(tf.zeros([2, 3]), dtype=tf.int32),
                     Categorical(tf.zeros([2, 3]), dtype=tf.float32)])

    with pytest.raises(ValueError,
                       match='`value_ndims` of the 1-th component does not '
                             'agree with the first component'):
        _ = Mixture(Categorical(logits=tf.zeros([2])),
                    [Categorical(tf.zeros([2, 3])),
                     OnehotCategorical(tf.zeros([2, 3]))])

    with pytest.raises(ValueError,
                       match='`is_continuous` of the 1-th component does '
                             'not agree with the first component'):
        _ = Mixture(Categorical(logits=tf.zeros([2])),
                    [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                     Normal(tf.zeros([2]), tf.zeros([2]))])

    with pytest.raises(ValueError,
                       match='the 0-th component is not re-parameterized'):
        _ = Mixture(Categorical(logits=tf.zeros([2])),
                    [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                     Normal(tf.zeros([2]), tf.zeros([2]))],
                    is_reparameterized=True)

    with pytest.raises(RuntimeError, match='.* is not re-parameterized'):
        m = Mixture(Categorical(logits=tf.zeros([2])),
                    [Normal(-1., 0.), Normal(1., 0.)])
        _ = m.sample(1, is_reparameterized=True)

    with pytest.raises(ValueError,
                       match='Batch shape of `categorical` does not '
                             'agree with the first component'):
        _ = Mixture(Categorical(logits=tf.zeros([1, 3, 2])),
                    [Normal(mean=tf.zeros([3]), logstd=0.),
                     Normal(mean=tf.zeros([3]), logstd=0.)])

    with pytest.raises(ValueError,
                       match='Batch shape of the 1-th component does not '
                             'agree with the first component'):
        _ = Mixture(Categorical(logits=tf.zeros([3, 2])),
                    [Normal(mean=tf.zeros([3]), logstd=0.),
                     Normal(mean=tf.zeros([4]), logstd=0.)])
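# For contrast with the error cases above, a minimal well-formed construction
# (a sketch; names and shapes are illustrative, and the import path assumes
# tfsnippet's layout): `categorical.n_categories` must equal `len(components)`,
# and all components must agree in dtype, value_ndims, is_continuous and
# batch shape.
import tensorflow as tf
from tfsnippet.distributions import Mixture, Categorical, Normal

well_formed = Mixture(
    Categorical(logits=tf.zeros([3, 2])),   # batch_shape [3], 2 categories
    [Normal(mean=tf.zeros([3]), logstd=0.),
     Normal(mean=tf.ones([3]), logstd=0.)],
    is_reparameterized=True,                # both Normal components support it
)
samples = well_formed.sample(4)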