def test_transform_on_transformed(self):
    """Check re-grouping of batch axes on an already transformed distribution.

    Starts from a `Normal` with batch shape ``[3, 4, 5]``, converts batch
    axes into value axes via `batch_ndims_to_value`, and then applies
    `expand_value_ndims` on the transformed result, asserting the reported
    `value_ndims`, static batch shape and dynamic batch shape at each stage.
    """
    with self.test_session() as sess:

        def check_shapes(dist, value_ndims, batch_shape):
            # value_ndims, then static batch shape, then dynamic batch shape.
            self.assertEqual(dist.value_ndims, value_ndims)
            self.assertEqual(dist.get_batch_shape().as_list(), batch_shape)
            self.assertEqual(list(sess.run(dist.batch_shape)), batch_shape)

        base = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
        check_shapes(base, 0, [3, 4, 5])

        # Converting zero axes hands back the very same object.
        untouched = base.batch_ndims_to_value(0)
        self.assertIs(untouched, base)

        # One batch axis becomes a value axis.
        grouped = base.batch_ndims_to_value(1)
        self.assertIsInstance(grouped, BatchToValueDistribution)
        check_shapes(grouped, 1, [3, 4])
        self.assertIs(grouped.base_distribution, base)

        # Expanding the transformed distribution by one more axis.
        regrouped = grouped.expand_value_ndims(1)
        self.assertIsInstance(regrouped, BatchToValueDistribution)
        check_shapes(regrouped, 2, [3])
        # NOTE(review): this inspects `grouped`, not `regrouped`; possibly
        # the latter was intended -- confirm before changing.
        self.assertIs(grouped.base_distribution, base)

        # Expanding by zero axes hands back the very same object.
        regrouped = grouped.expand_value_ndims(0)
        self.assertIs(regrouped, grouped)
        self.assertEqual(grouped.value_ndims, 1)
        check_shapes(regrouped, 1, [3, 4])
        self.assertIs(grouped.base_distribution, base)
class RecurrentDistribution(Distribution):
    """
    A multi-variable distribution integrated with recurrent structure.

    Sampling and density evaluation are both unrolled with ``tf.scan`` over
    the leading axis of ``self.input_q`` (the constructor argument `input_q`
    transposed with permutation ``[1, 0, 2]``).  At every step the mean and
    standard deviation are produced by the user supplied callables
    `mean_q_mlp` and `std_q_mlp`.
    """

    @property
    def dtype(self):
        """Data type, taken from `input_q` at construction time."""
        return self._dtype

    @property
    def is_continuous(self):
        """Always ``True`` (set in the constructor)."""
        return self._is_continuous

    @property
    def is_reparameterized(self):
        """The `is_reparameterized` flag given to the constructor."""
        return self._is_reparameterized

    @property
    def value_shape(self):
        """Dynamic value shape, delegated to the internal `Normal`."""
        return self.normal.value_shape

    def get_value_shape(self):
        """Static value shape, delegated to the internal `Normal`."""
        return self.normal.get_value_shape()

    @property
    def batch_shape(self):
        """Dynamic batch shape, delegated to the internal `Normal`."""
        return self.normal.batch_shape

    def get_batch_shape(self):
        """Static batch shape, delegated to the internal `Normal`."""
        return self.normal.get_batch_shape()

    def sample_step(self, a, t):
        """
        One ``tf.scan`` step of :meth:`sample`.

        Args:
            a: Accumulator ``(z_previous, mu_q_previous, std_q_previous)``;
                only ``z_previous`` is used.
            t: Current scan elements ``(noise_n, input_q_n)``.

        Returns:
            tuple: ``(z_n, mu_q, std_q)`` for the current step, where
                ``z_n = noise_n * std_q + mu_q``.
        """
        z_previous = a[0]
        noise_n, input_q_n = t

        # Tile the step input along a new leading axis sized like
        # `z_previous` so it can be concatenated with it.
        input_q_n = tf.broadcast_to(input_q_n, [
            tf.shape(z_previous)[0],
            tf.shape(input_q_n)[0],
            input_q_n.shape[1],
        ])
        input_q = tf.concat([input_q_n, z_previous], axis=-1)

        mu_q = self.mean_q_mlp(
            input_q, reuse=tf.AUTO_REUSE)  # n_sample * batch_size * z_dim
        std_q = self.std_q_mlp(input_q)  # n_sample * batch_size * z_dim

        # Scale the noise by the standard deviation.
        scaled_noise = tf.einsum(
            'ik,ijk->ijk', noise_n, std_q)  # n_sample * batch_size * z_dim
        mu_q = tf.broadcast_to(mu_q, tf.shape(scaled_noise))
        std_q = tf.broadcast_to(std_q, tf.shape(scaled_noise))
        z_n = scaled_noise + mu_q
        return z_n, mu_q, std_q

    # @global_reuse
    def log_prob_step(self, _, t):
        """
        One ``tf.scan`` step of :meth:`log_prob`.

        Args:
            _: Accumulator from the previous step (ignored).
            t: Current scan elements ``(given_n, input_q_n)``.

        Returns:
            The element-wise Gaussian log-density of ``given_n`` under the
            mean / standard deviation computed for this step.
        """
        given_n, input_q_n = t
        if len(given_n.shape) > 2:
            # `given_n` carries an extra leading axis; tile the step input
            # along it so the two can be concatenated.
            input_q_n = tf.broadcast_to(input_q_n, [
                tf.shape(given_n)[0],
                tf.shape(input_q_n)[0],
                input_q_n.shape[1],
            ])

        # NOTE(review): `sample_step` feeds the MLPs with
        # concat([input_q_n, z_previous]) whereas this step uses
        # concat([given_n, input_q_n]) -- a different operand order and the
        # current rather than the previous value.  Left unchanged; confirm
        # whether this asymmetry is intended.
        input_q = tf.concat([given_n, input_q_n], axis=-1)
        mu_q = self.mean_q_mlp(input_q, reuse=tf.AUTO_REUSE)
        std_q = self.std_q_mlp(input_q)

        logstd_q = tf.log(std_q)
        precision = tf.exp(-2 * logstd_q)
        if self._check_numerics:
            precision = tf.check_numerics(precision, "precision")

        # -0.9189385332046727 is -0.5 * log(2 * pi).  The absolute deviation
        # is clipped at 1e8 before squaring.
        log_prob_n = -0.9189385332046727 - logstd_q - 0.5 * precision * tf.square(
            tf.minimum(tf.abs(given_n - mu_q), 1e8))
        return log_prob_n

    def __init__(self, input_q, mean_q_mlp, std_q_mlp, z_dim,
                 window_length=100, is_reparameterized=True,
                 check_numerics=True):
        """
        Construct the :class:`RecurrentDistribution`.

        Args:
            input_q: Rank-3 tensor; it is stored transposed with permutation
                ``[1, 0, 2]`` and scanned along the resulting leading axis.
                Presumably ``(batch, window_length, feature)`` -- confirm
                against the caller.
            mean_q_mlp: Callable invoked as
                ``mean_q_mlp(x, reuse=tf.AUTO_REUSE)`` producing the mean.
            std_q_mlp: Callable invoked as ``std_q_mlp(x)`` producing the
                standard deviation.
            z_dim (int): Size of the last axis of the samples.
            window_length (int): Number of scan steps the internal `Normal`
                is sized for.  (default 100)
            is_reparameterized (bool): Reported through
                :attr:`is_reparameterized`.  (default True)
            check_numerics (bool): Whether to wrap the precision computed in
                :meth:`log_prob_step` with ``tf.check_numerics``.
                (default True)
        """
        # Standard normal used for shape bookkeeping and the base-class
        # metadata.
        self.normal = Normal(mean=tf.zeros([window_length, z_dim]),
                             std=tf.ones([window_length, z_dim]))
        super(RecurrentDistribution, self).__init__(
            dtype=self.normal.dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            batch_shape=self.normal.batch_shape,
            batch_static_shape=self.normal.get_batch_shape(),
            value_ndims=self.normal.value_ndims)
        self.std_q_mlp = std_q_mlp
        self.mean_q_mlp = mean_q_mlp
        self._check_numerics = check_numerics
        self.input_q = tf.transpose(input_q, [1, 0, 2])
        self._dtype = input_q.dtype
        self._is_reparameterized = is_reparameterized
        self._is_continuous = True
        self.z_dim = z_dim
        self.window_length = window_length
        self.time_first_shape = tf.convert_to_tensor(
            [self.window_length, tf.shape(input_q)[0], self.z_dim])

    def sample(self, n_samples=1024, is_reparameterized=None, group_ndims=0,
               compute_density=False, name=None):
        """
        Draw samples by scanning :meth:`sample_step` over ``self.input_q``.

        Args:
            n_samples (int or None): Number of samples.  If None, a single
                sample is drawn and the sample axis is averaged away.
                (default 1024)
            is_reparameterized: Accepted for interface compatibility.
                NOTE(review): currently ignored; the returned tensor always
                reports ``self.is_reparameterized``.
            group_ndims (int): Forwarded to the returned `StochasticTensor`.
            compute_density (bool): If true, eagerly build the log-density
                and store ``exp`` of it on the returned tensor.
            name (str): Optional name scope (default "sample").

        Returns:
            tfsnippet.stochastic.StochasticTensor: The samples.
        """
        from tfsnippet.stochastic import StochasticTensor

        n_samples_is_none = n_samples is None
        if n_samples_is_none:
            n_samples = 1

        with tf.name_scope(name=name, default_name='sample'):
            noise = self.normal.sample(n_samples=n_samples)
            noise = tf.transpose(
                noise, [1, 0, 2])  # window_length * n_samples * z_dim
            # NOTE(review): the draw above is only used for its shape; the
            # noise actually consumed is truncated-normal.
            noise = tf.truncated_normal(tf.shape(noise))

            time_indices_shape = tf.convert_to_tensor(
                [n_samples, tf.shape(self.input_q)[1], self.z_dim])
            samples = tf.scan(
                fn=self.sample_step,
                elems=(noise, self.input_q),
                initializer=(tf.zeros(time_indices_shape),
                             tf.zeros(time_indices_shape),
                             tf.ones(time_indices_shape)),
                back_prop=False
            )[0]  # time_step * n_samples * batch_size * z_dim
            samples = tf.transpose(
                samples,
                [1, 2, 0, 3])  # n_samples * batch_size * time_step * z_dim

            if n_samples_is_none:
                # Collapse the singleton sample axis.
                tensor = tf.reduce_mean(samples, axis=0)
                reported_n_samples = 1
            else:
                tensor = samples
                reported_n_samples = n_samples

            t = StochasticTensor(
                distribution=self,
                tensor=tensor,
                n_samples=reported_n_samples,
                group_ndims=group_ndims,
                is_reparameterized=self.is_reparameterized)

            if compute_density:
                with tf.name_scope('compute_prob_and_log_prob'):
                    log_p = t.log_prob()
                    t._self_prob = tf.exp(log_p)
            return t

    def log_prob(self, given, group_ndims=0, name=None):
        """
        Compute the log-density of `given`.

        Args:
            given: Rank-3 tensor, or rank-4 tensor with an extra leading
                axis.  The time axis (second to last) is moved to the front
                for the scan and moved back afterwards.
            group_ndims (int): Number of trailing axes of the element-wise
                log-density to sum over.  0 leaves it unreduced.
                (default 0)
            name (str): Optional name scope (default "log_prob").

        Returns:
            tf.Tensor: The log-density.
        """
        with tf.name_scope(name=name, default_name='log_prob'):
            # Transposing keeps the rank, so decide once up front.
            has_sample_axis = len(given.shape) > 3

            if has_sample_axis:
                time_indices_shape = tf.convert_to_tensor([
                    tf.shape(given)[0],
                    tf.shape(self.input_q)[1],
                    self.z_dim,
                ])
                given = tf.transpose(given, [2, 0, 1, 3])
            else:
                time_indices_shape = tf.convert_to_tensor(
                    [tf.shape(self.input_q)[1], self.z_dim])
                given = tf.transpose(given, [1, 0, 2])

            log_prob = tf.scan(fn=self.log_prob_step,
                               elems=(given, self.input_q),
                               initializer=tf.zeros(time_indices_shape),
                               back_prop=False)

            # Move the time axis back behind the leading axes.
            if has_sample_axis:
                log_prob = tf.transpose(log_prob, [1, 2, 0, 3])
            else:
                log_prob = tf.transpose(log_prob, [1, 0, 2])

            if group_ndims == 1:
                log_prob = tf.reduce_sum(log_prob, axis=-1)
            elif isinstance(group_ndims, int) and group_ndims > 1:
                # Previously any value other than 1 was silently ignored;
                # sum over the last `group_ndims` axes instead.
                log_prob = tf.reduce_sum(
                    log_prob, axis=list(range(-group_ndims, 0)))
            return log_prob

    def prob(self, given, group_ndims=0, name=None):
        """
        Compute the density of `given` as ``exp(log_prob(given))``.

        Args:
            given: See :meth:`log_prob`.
            group_ndims (int): See :meth:`log_prob`.  (default 0)
            name (str): Optional name scope (default "prob").

        Returns:
            tf.Tensor: The density.
        """
        with tf.name_scope(name=name, default_name='prob'):
            log_prob = self.log_prob(given, group_ndims, name)
            return tf.exp(log_prob)