Example #1
    def test_transform_on_transformed(self):
        with self.test_session() as sess:
            normal = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
            self.assertEqual(normal.value_ndims, 0)
            self.assertEqual(normal.get_batch_shape().as_list(), [3, 4, 5])
            self.assertEqual(list(sess.run(normal.batch_shape)), [3, 4, 5])

            distrib = normal.batch_ndims_to_value(0)
            self.assertIs(distrib, normal)

            distrib = normal.batch_ndims_to_value(1)
            self.assertIsInstance(distrib, BatchToValueDistribution)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib.batch_shape)), [3, 4])
            self.assertIs(distrib.base_distribution, normal)

            distrib2 = distrib.expand_value_ndims(1)
            self.assertIsInstance(distrib2, BatchToValueDistribution)
            self.assertEqual(distrib2.value_ndims, 2)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3])
            self.assertIs(distrib2.base_distribution, normal)

            distrib2 = distrib.expand_value_ndims(0)
            self.assertIs(distrib2, distrib)
            self.assertEqual(distrib2.value_ndims, 1)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3, 4])
            self.assertIs(distrib.base_distribution, normal)
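
What this test exercises: `batch_ndims_to_value(n)` moves the last `n` batch
dimensions of a distribution into its value shape, so `log_prob` on the
wrapper sums over those dimensions. A minimal standalone sketch of that
identity (tfsnippet import paths assumed, and may vary by version):

import numpy as np
import tensorflow as tf
from tfsnippet.distributions import Normal

normal = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
d = normal.batch_ndims_to_value(2)  # batch [3, 4, 5] -> [3]; value_ndims 0 -> 2
x = np.zeros([3, 4, 5], dtype=np.float32)
with tf.Session() as sess:
    np.testing.assert_allclose(*sess.run(
        [d.log_prob(x), normal.log_prob(x, group_ndims=2)]))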
Example #2
    def __init__(self,
                 input_q,
                 mean_q_mlp,
                 std_q_mlp,
                 z_dim,
                 window_length=100,
                 is_reparameterized=True,
                 check_numerics=True):
        # Standard normal noise source: one unit per (time step, latent dim).
        self.normal = Normal(mean=tf.zeros([window_length, z_dim]),
                             std=tf.ones([window_length, z_dim]))
        super(RecurrentDistribution,
              self).__init__(dtype=self.normal.dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             batch_shape=self.normal.batch_shape,
                             batch_static_shape=self.normal.get_batch_shape(),
                             value_ndims=self.normal.value_ndims)
        self.std_q_mlp = std_q_mlp
        self.mean_q_mlp = mean_q_mlp
        self._check_numerics = check_numerics
        self.input_q = tf.transpose(
            input_q, [1, 0, 2])  # to time-major: (window, batch, feature)
        self._dtype = input_q.dtype
        self._is_reparameterized = is_reparameterized
        self._is_continuous = True
        self.z_dim = z_dim
        self.window_length = window_length
        self.time_first_shape = tf.convert_to_tensor(
            [self.window_length,
             tf.shape(input_q)[0], self.z_dim])
Example #3
    def test_invert_flow(self):
        with self.test_session() as sess:
            # test inverting an ordinary flow
            flow = QuadraticFlow(2., 5.)
            inv_flow = flow.invert()

            self.assertIsInstance(inv_flow, InvertFlow)
            self.assertEqual(inv_flow.x_value_ndims, 0)
            self.assertEqual(inv_flow.y_value_ndims, 0)
            self.assertFalse(inv_flow.require_batch_dims)

            test_x = np.arange(12, dtype=np.float32) + 1.
            test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.)

            self.assertFalse(flow._has_built)
            y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x))
            self.assertTrue(flow._has_built)

            np.testing.assert_allclose(sess.run(y), test_y)
            np.testing.assert_allclose(sess.run(log_det_y), test_log_det)
            invertible_flow_standard_check(self, inv_flow, sess, test_y)

            # test invert an InvertFlow
            inv_inv_flow = inv_flow.invert()
            self.assertIs(inv_inv_flow, flow)

            # test use with FlowDistribution
            normal = Normal(mean=1., std=2.)
            inv_flow = QuadraticFlow(2., 5.).invert()
            distrib = FlowDistribution(normal, inv_flow)
            distrib_log_prob = distrib.log_prob(test_x)
            np.testing.assert_allclose(*sess.run(
                [distrib_log_prob,
                 normal.log_prob(test_y) + test_log_det]))
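
The property under test: `flow.invert()` swaps the two transform directions,
so `inv_flow.inverse_transform` computes the original forward mapping, and
inverting an InvertFlow hands back the wrapped flow itself. A minimal sketch,
assuming the same flow API as above:

flow = QuadraticFlow(2., 5.)
inv_flow = flow.invert()          # InvertFlow wrapping `flow`
assert inv_flow.invert() is flow  # double inversion returns the original object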
Example #4
    def test_ndims_exceed_limit(self):
        normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)

        # 'less then' [sic]: the match string mirrors the library's error
        # message verbatim.
        with pytest.raises(ValueError,
                           match='`distribution.batch_shape.ndims` '
                                 'is less then `ndims`'):
            _ = normal.expand_value_ndims(3)
Example #5
    def test_with_normal(self):
        mean = np.random.normal(size=[4, 5]).astype(np.float64)
        logstd = np.random.normal(size=mean.shape).astype(np.float64)
        x = np.random.normal(size=[3, 4, 5])

        with self.test_session() as sess:
            normal = Normal(mean=mean, logstd=logstd)
            distrib = normal.batch_ndims_to_value(1)

            self.assertIsInstance(distrib, BatchToValueDistribution)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib.get_batch_shape().as_list(), [4])
            self.assertEqual(list(sess.run(distrib.batch_shape)), [4])
            self.assertEqual(distrib.dtype, tf.float64)
            self.assertTrue(distrib.is_continuous)
            self.assertTrue(distrib.is_reparameterized)
            self.assertIs(distrib.base_distribution, normal)

            log_prob = distrib.log_prob(x)
            log_prob2 = distrib.log_prob(x, group_ndims=1)
            self.assertEqual(get_static_shape(log_prob), (3, 4))
            self.assertEqual(get_static_shape(log_prob2), (3, ))
            np.testing.assert_allclose(*sess.run(
                [log_prob, normal.log_prob(x, group_ndims=1)]))
            np.testing.assert_allclose(*sess.run(
                [log_prob2, normal.log_prob(x, group_ndims=2)]))

            prob = distrib.prob(x)
            prob2 = distrib.prob(x, group_ndims=1)
            self.assertEqual(get_static_shape(prob), (3, 4))
            self.assertEqual(get_static_shape(prob2), (3, ))
            np.testing.assert_allclose(
                *sess.run([prob, normal.prob(x, group_ndims=1)]))
            np.testing.assert_allclose(
                *sess.run([prob2, normal.prob(x, group_ndims=2)]))

            sample = distrib.sample(3, compute_density=False)
            sample2 = distrib.sample(3, compute_density=True, group_ndims=1)
            log_prob = sample.log_prob()
            log_prob2 = sample2.log_prob()
            self.assertEqual(get_static_shape(log_prob), (3, 4))
            self.assertEqual(get_static_shape(log_prob2), (3, ))
            np.testing.assert_allclose(*sess.run(
                [log_prob, normal.log_prob(sample, group_ndims=1)]))
            np.testing.assert_allclose(*sess.run(
                [log_prob2, normal.log_prob(sample2, group_ndims=2)]))
Example #6
    def test_value_ndims_0(self):
        self.do_check_mixture(
            lambda: Normal(
                mean=np.random.normal(size=[4, 5]).astype(np.float64),
                logstd=np.random.normal(size=[4, 5]).astype(np.float64)
            ),
            value_ndims=0,
            batch_shape=[4, 5],
            is_continuous=True,
            dtype=tf.float64,
            logits_dtype=np.float64,
            is_reparameterized=True
        )
Example #7
    def test_ndims_equals_zero_and_negative(self):
        normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)

        self.assertIs(normal.batch_ndims_to_value(0), normal)
        self.assertIs(normal.expand_value_ndims(0), normal)

        with pytest.raises(ValueError,
                           match='`ndims` must be non-negative integers'):
            _ = normal.batch_ndims_to_value(-1)
        with pytest.raises(ValueError,
                           match='`ndims` must be non-negative integers'):
            _ = normal.expand_value_ndims(-1)
Example #8
class RecurrentDistribution(Distribution):
    """
    A multi-variable distribution integrated with recurrent structure.
    """
    @property
    def dtype(self):
        return self._dtype

    @property
    def is_continuous(self):
        return self._is_continuous

    @property
    def is_reparameterized(self):
        return self._is_reparameterized

    @property
    def value_shape(self):
        return self.normal.value_shape

    def get_value_shape(self):
        return self.normal.get_value_shape()

    @property
    def batch_shape(self):
        return self.normal.batch_shape

    def get_batch_shape(self):
        return self.normal.get_batch_shape()

    def sample_step(self, a, t):
        z_previous, mu_q_previous, std_q_previous = a
        noise_n, input_q_n = t
        input_q_n = tf.broadcast_to(input_q_n, [
            tf.shape(z_previous)[0],
            tf.shape(input_q_n)[0], input_q_n.shape[1]
        ])
        input_q = tf.concat([input_q_n, z_previous], axis=-1)
        mu_q = self.mean_q_mlp(
            input_q, reuse=tf.AUTO_REUSE)  # n_sample * batch_size * z_dim

        std_q = self.std_q_mlp(input_q)  # n_sample * batch_size * z_dim

        temp = tf.einsum('ik,ijk->ijk', noise_n,
                         std_q)  # n_sample * batch_size * z_dim
        mu_q = tf.broadcast_to(mu_q, tf.shape(temp))
        std_q = tf.broadcast_to(std_q, tf.shape(temp))
        z_n = temp + mu_q  # re-parameterization: z = mu + std * noise

        return z_n, mu_q, std_q

    # @global_reuse
    def log_prob_step(self, _, t):

        given_n, input_q_n = t
        if len(given_n.shape) > 2:
            input_q_n = tf.broadcast_to(input_q_n, [
                tf.shape(given_n)[0],
                tf.shape(input_q_n)[0], input_q_n.shape[1]
            ])
        input_q = tf.concat([given_n, input_q_n], axis=-1)
        mu_q = self.mean_q_mlp(input_q, reuse=tf.AUTO_REUSE)

        std_q = self.std_q_mlp(input_q)
        logstd_q = tf.log(std_q)
        precision = tf.exp(-2 * logstd_q)  # 1 / std^2
        if self._check_numerics:
            precision = tf.check_numerics(precision, "precision")
        # Gaussian log-density; -0.9189385332046727 == -0.5 * log(2 * pi).
        # The absolute residual is clipped at 1e8 to keep the square finite.
        log_prob_n = -0.9189385332046727 - logstd_q - 0.5 * precision * tf.square(
            tf.minimum(tf.abs(given_n - mu_q), 1e8))
        return log_prob_n

    def __init__(self,
                 input_q,
                 mean_q_mlp,
                 std_q_mlp,
                 z_dim,
                 window_length=100,
                 is_reparameterized=True,
                 check_numerics=True):
        # Standard normal noise source: one unit per (time step, latent dim).
        self.normal = Normal(mean=tf.zeros([window_length, z_dim]),
                             std=tf.ones([window_length, z_dim]))
        super(RecurrentDistribution,
              self).__init__(dtype=self.normal.dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             batch_shape=self.normal.batch_shape,
                             batch_static_shape=self.normal.get_batch_shape(),
                             value_ndims=self.normal.value_ndims)
        self.std_q_mlp = std_q_mlp
        self.mean_q_mlp = mean_q_mlp
        self._check_numerics = check_numerics
        self.input_q = tf.transpose(
            input_q, [1, 0, 2])  # to time-major: (window, batch, feature)
        self._dtype = input_q.dtype
        self._is_reparameterized = is_reparameterized
        self._is_continuous = True
        self.z_dim = z_dim
        self.window_length = window_length
        self.time_first_shape = tf.convert_to_tensor(
            [self.window_length,
             tf.shape(input_q)[0], self.z_dim])

    def sample(self,
               n_samples=1024,
               is_reparameterized=None,
               group_ndims=0,
               compute_density=False,
               name=None):

        from tfsnippet.stochastic import StochasticTensor
        if n_samples is None:
            n_samples = 1
            n_samples_is_none = True
        else:
            n_samples_is_none = False
        with tf.name_scope(name=name, default_name='sample'):
            noise = self.normal.sample(n_samples=n_samples)

            noise = tf.transpose(
                noise, [1, 0, 2])  # window_length * n_samples * z_dim
            # Note: the Normal sample above is used only for its shape; it is
            # overwritten here with truncated-normal noise.
            noise = tf.truncated_normal(tf.shape(noise))

            time_indices_shape = tf.convert_to_tensor(
                [n_samples, tf.shape(self.input_q)[1], self.z_dim])

            # Scan over time steps; back_prop=False blocks gradients through
            # the recurrence.
            samples = tf.scan(
                fn=self.sample_step,
                elems=(noise, self.input_q),
                initializer=(tf.zeros(time_indices_shape),
                             tf.zeros(time_indices_shape),
                             tf.ones(time_indices_shape)),
                back_prop=False)[
                    0]  # time_step * n_samples * batch_size * z_dim

            samples = tf.transpose(
                samples,
                [1, 2, 0, 3])  # n_samples * batch_size * time_step *  z_dim

            if n_samples_is_none:
                t = StochasticTensor(
                    distribution=self,
                    tensor=tf.reduce_mean(samples, axis=0),
                    n_samples=1,
                    group_ndims=group_ndims,
                    is_reparameterized=self.is_reparameterized)
            else:
                t = StochasticTensor(
                    distribution=self,
                    tensor=samples,
                    n_samples=n_samples,
                    group_ndims=group_ndims,
                    is_reparameterized=self.is_reparameterized)
            if compute_density:
                with tf.name_scope('compute_prob_and_log_prob'):
                    log_p = t.log_prob()
                    t._self_prob = tf.exp(log_p)
            return t

    def log_prob(self, given, group_ndims=0, name=None):
        with tf.name_scope(name=name, default_name='log_prob'):
            if len(given.shape) > 3:
                time_indices_shape = tf.convert_to_tensor([
                    tf.shape(given)[0],
                    tf.shape(self.input_q)[1], self.z_dim
                ])
                given = tf.transpose(given, [2, 0, 1, 3])
            else:
                time_indices_shape = tf.convert_to_tensor(
                    [tf.shape(self.input_q)[1], self.z_dim])
                given = tf.transpose(given, [1, 0, 2])
            log_prob = tf.scan(fn=self.log_prob_step,
                               elems=(given, self.input_q),
                               initializer=tf.zeros(time_indices_shape),
                               back_prop=False)
            if len(given.shape) > 3:
                log_prob = tf.transpose(log_prob, [1, 2, 0, 3])
            else:
                log_prob = tf.transpose(log_prob, [1, 0, 2])

            if group_ndims == 1:
                log_prob = tf.reduce_sum(log_prob, axis=-1)
            return log_prob

    def prob(self, given, group_ndims=0, name=None):
        with tf.name_scope(name=name, default_name='prob'):
            log_prob = self.log_prob(given, group_ndims, name)
            return tf.exp(log_prob)
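
A hypothetical construction sketch (nothing below is part of the class above):
`make_mlp` stands in for the user-supplied `mean_q_mlp` / `std_q_mlp`
callables, which must map the concatenated (input, previous z) tensor to
`z_dim` outputs; the softplus keeps the std positive.

import tensorflow as tf

def make_mlp(name, z_dim, activation=None):
    # Hypothetical helper: a single dense layer whose variables are shared
    # across all scan steps via AUTO_REUSE.
    def mlp(x, reuse=tf.AUTO_REUSE):
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            return tf.layers.dense(x, units=z_dim, activation=activation)
    return mlp

input_q = tf.placeholder(tf.float32, [None, 100, 8])  # (batch, window, features)
q_dist = RecurrentDistribution(
    input_q,
    mean_q_mlp=make_mlp('mean_q', z_dim=3),
    std_q_mlp=make_mlp('std_q', z_dim=3, activation=tf.nn.softplus),
    z_dim=3,
    window_length=100)
z = q_dist.sample(n_samples=16)                   # (16, batch, 100, 3)
log_q = q_dist.log_prob(z.tensor, group_ndims=1)  # (16, batch, 100)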
Example #9
    def test_errors(self):
        with pytest.raises(TypeError,
                           match='`categorical` must be a Categorical '
                                 'distribution'):
            _ = Mixture(Normal(0., 0.), [Normal(0., 0.)])

        with pytest.raises(ValueError,
                           match='Dynamic `categorical.n_categories` is not '
                                 'supported'):
            _ = Mixture(Categorical(logits=tf.placeholder(tf.float32, [None])),
                        [Normal(0., 0.)])

        with pytest.raises(ValueError, match='`components` must not be empty'):
            _ = Mixture(Categorical(logits=tf.zeros([5])), [])

        with pytest.raises(ValueError,
                           match=r'`len\(components\)` != `categorical.'
                                 r'n_categories`: 1 vs 5'):
            _ = Mixture(Categorical(logits=tf.zeros([5])), [Normal(0., 0.)])

        with pytest.raises(ValueError,
                           match='`dtype` of the 1-th component does not '
                                 'agree with the first component'):
            _ = Mixture(Categorical(logits=tf.zeros([2])),
                        [Categorical(tf.zeros([2, 3]), dtype=tf.int32),
                         Categorical(tf.zeros([2, 3]), dtype=tf.float32)])

        with pytest.raises(ValueError,
                           match='`value_ndims` of the 1-th component does not '
                                 'agree with the first component'):
            _ = Mixture(Categorical(logits=tf.zeros([2])),
                        [Categorical(tf.zeros([2, 3])),
                         OnehotCategorical(tf.zeros([2, 3]))])

        with pytest.raises(ValueError,
                           match='`is_continuous` of the 1-th component does '
                                 'not agree with the first component'):
            _ = Mixture(Categorical(logits=tf.zeros([2])),
                        [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                         Normal(tf.zeros([2]), tf.zeros([2]))])

        with pytest.raises(ValueError,
                           match='the 0-th component is not re-parameterized'):
            _ = Mixture(Categorical(logits=tf.zeros([2])),
                        [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                         Normal(tf.zeros([2]), tf.zeros([2]))],
                        is_reparameterized=True)

        with pytest.raises(RuntimeError,
                           match='.* is not re-parameterized'):
            m = Mixture(
                Categorical(logits=tf.zeros([2])),
                [Normal(-1., 0.), Normal(1., 0.)]
            )
            _ = m.sample(1, is_reparameterized=True)

        with pytest.raises(ValueError,
                           match='Batch shape of `categorical` does not '
                                 'agree with the first component'):
            _ = Mixture(
                Categorical(logits=tf.zeros([1, 3, 2])),
                [Normal(mean=tf.zeros([3]), logstd=0.),
                 Normal(mean=tf.zeros([3]), logstd=0.)]
            )

        with pytest.raises(ValueError,
                           match='Batch shape of the 1-th component does not '
                                 'agree with the first component'):
            _ = Mixture(
                Categorical(logits=tf.zeros([3, 2])),
                [Normal(mean=tf.zeros([3]), logstd=0.),
                 Normal(mean=tf.zeros([4]), logstd=0.)]
            )
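
For contrast, a construction that passes every check above (shapes taken from
the last case): `categorical.n_categories` equals `len(components)`, and the
components agree in dtype, value_ndims, is_continuous and batch shape.

m = Mixture(Categorical(logits=tf.zeros([3, 2])),
            [Normal(mean=tf.zeros([3]), logstd=0.),
             Normal(mean=tf.zeros([3]), logstd=0.)])
samples = m.sample(4)  # built without is_reparameterized=True, so sampling
                       # must not request re-parameterization (see above)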