Ejemplo n.º 1
0
 def __init__(self,
              input_q,
              mean_q_mlp,
              std_q_mlp,
              z_dim,
              window_length=100,
              is_reparameterized=True,
              check_numerics=True):
     """Build the recurrent distribution from its inputs and the two MLPs.

     Args:
         input_q: Input tensor. Its first two axes are swapped
             (``tf.transpose(input_q, [1, 0, 2])``) and the result is stored
             as ``self.input_q``. Presumably shaped
             (batch, window_length, features) -- TODO confirm with callers.
         mean_q_mlp: Stored as ``self.mean_q_mlp``. It is not called in this
             constructor; presumably it produces the mean of q.
         std_q_mlp: Stored as ``self.std_q_mlp``. It is not called in this
             constructor; presumably it produces the std of q.
         z_dim: Size of the last axis of the helper normal and of
             ``self.time_first_shape``.
         window_length: Size of the first axis of the helper normal and of
             ``self.time_first_shape`` (default 100).
         is_reparameterized: Forwarded to the base class and also stored.
         check_numerics: Stored as ``self._check_numerics``.
     """
     # Helper standard normal (mean 0, std 1) over a (window_length, z_dim)
     # grid. It must exist before the base-class constructor runs, because
     # its dtype / batch shape / value_ndims are passed to it below.
     self.normal = Normal(mean=tf.zeros([window_length, z_dim]),
                          std=tf.ones([window_length, z_dim]))
     super(RecurrentDistribution,
           self).__init__(dtype=self.normal.dtype,
                          is_continuous=True,
                          is_reparameterized=is_reparameterized,
                          batch_shape=self.normal.batch_shape,
                          batch_static_shape=self.normal.get_batch_shape(),
                          value_ndims=self.normal.value_ndims)
     self.std_q_mlp = std_q_mlp
     self.mean_q_mlp = mean_q_mlp
     self._check_numerics = check_numerics
     # Swap the first two axes so that the second axis of input_q comes first
     # (presumably time-major: (window, batch, features) -- TODO confirm).
     self.input_q = tf.transpose(input_q, [1, 0, 2])
     # NOTE(review): this overwrites the dtype given to the base class
     # (self.normal.dtype) with input_q's dtype; confirm this is intended.
     self._dtype = input_q.dtype
     # Re-assigned after super().__init__, which also received these values.
     self._is_reparameterized = is_reparameterized
     self._is_continuous = True
     self.z_dim = z_dim
     self.window_length = window_length
     # [window_length, <size of input_q's first axis>, z_dim] as one tensor.
     # The middle entry is dynamic (tf.shape) and presumably the batch size.
     self.time_first_shape = tf.convert_to_tensor(
         [self.window_length,
          tf.shape(input_q)[0], self.z_dim])
Ejemplo n.º 2
0
    def test_transform_on_transformed(self):
        """Chain `batch_ndims_to_value` and `expand_value_ndims`.

        Expanding the value ndims of an existing `BatchToValueDistribution`
        must yield a `BatchToValueDistribution` whose base is still the
        original `Normal`, with the value ndims accumulated. Expanding by 0
        must return the same object.
        """
        with self.test_session() as sess:
            normal = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
            self.assertEqual(normal.value_ndims, 0)
            self.assertEqual(normal.get_batch_shape().as_list(), [3, 4, 5])
            self.assertEqual(list(sess.run(normal.batch_shape)), [3, 4, 5])

            # ndims == 0 is a no-op
            distrib = normal.batch_ndims_to_value(0)
            self.assertIs(distrib, normal)

            # first transformation: wrap the Normal
            distrib = normal.batch_ndims_to_value(1)
            self.assertIsInstance(distrib, BatchToValueDistribution)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib.batch_shape)), [3, 4])
            self.assertIs(distrib.base_distribution, normal)

            # transforming the transformed distribution: the result must
            # still wrap the original Normal (not `distrib`)
            distrib2 = distrib.expand_value_ndims(1)
            self.assertIsInstance(distrib2, BatchToValueDistribution)
            self.assertEqual(distrib2.value_ndims, 2)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3])
            self.assertIs(distrib2.base_distribution, normal)

            # expanding by 0 returns the very same object, unchanged
            distrib2 = distrib.expand_value_ndims(0)
            self.assertIs(distrib2, distrib)
            self.assertEqual(distrib2.value_ndims, 1)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3, 4])
            self.assertIs(distrib2.base_distribution, normal)
Ejemplo n.º 3
0
    def test_invert_flow(self):
        """Exercise `Flow.invert()` and the resulting `InvertFlow`.

        Checks the attributes of the inverted flow. It also checks that the
        inverted flow's `inverse_transform` reproduces the original flow's
        forward transform and lazily builds the wrapped flow, that
        inverting twice returns the original flow, and that an inverted
        flow works inside a `FlowDistribution`.
        """
        with self.test_session() as sess:
            # --- inverting an ordinary (non-inverted) flow ---
            base_flow = QuadraticFlow(2., 5.)
            inverted = base_flow.invert()

            self.assertIsInstance(inverted, InvertFlow)
            self.assertEqual(inverted.x_value_ndims, 0)
            self.assertEqual(inverted.y_value_ndims, 0)
            self.assertFalse(inverted.require_batch_dims)

            x_np = np.arange(12, dtype=np.float32) + 1.
            expected_y, expected_log_det = quadratic_transform(
                npyops, x_np, 2., 5.)

            # the wrapped flow is only built on first use
            self.assertFalse(base_flow._has_built)
            y, log_det_y = inverted.inverse_transform(tf.constant(x_np))
            self.assertTrue(base_flow._has_built)

            np.testing.assert_allclose(sess.run(y), expected_y)
            np.testing.assert_allclose(sess.run(log_det_y), expected_log_det)
            invertible_flow_standard_check(self, inverted, sess, expected_y)

            # --- inverting an InvertFlow yields the original flow ---
            self.assertIs(inverted.invert(), base_flow)

            # --- an inverted flow inside a FlowDistribution ---
            normal = Normal(mean=1., std=2.)
            fresh_inverted = QuadraticFlow(2., 5.).invert()
            flow_distrib = FlowDistribution(normal, fresh_inverted)
            distrib_log_prob = flow_distrib.log_prob(x_np)
            expected_log_prob = normal.log_prob(expected_y) + expected_log_det
            np.testing.assert_allclose(
                *sess.run([distrib_log_prob, expected_log_prob]))
Ejemplo n.º 4
0
    def test_ndims_exceed_limit(self):
        """`expand_value_ndims` rejects more ndims than the batch rank (2)."""
        normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)
        # The pattern must match the library's message verbatim
        # (including its wording "less then").
        expected_message = ('`distribution.batch_shape.ndims` '
                            'is less then `ndims`')

        with pytest.raises(ValueError, match=expected_message):
            normal.expand_value_ndims(3)
Ejemplo n.º 5
0
    def test_ndims_equals_zero_and_negative(self):
        """Zero `ndims` is a no-op; negative `ndims` raises ValueError."""
        normal = Normal(mean=tf.zeros([3, 4]), logstd=0.)
        transforms = (normal.batch_ndims_to_value, normal.expand_value_ndims)

        # ndims == 0 must hand back the distribution itself
        for transform in transforms:
            self.assertIs(transform(0), normal)

        # ndims < 0 is rejected by both methods
        for transform in transforms:
            with pytest.raises(ValueError,
                               match='`ndims` must be non-negative integers'):
                transform(-1)
Ejemplo n.º 6
0
 def test_value_ndims_0(self):
     """Run the mixture checks with float64 scalar-valued Normal components."""
     def make_component():
         # Fresh random parameters on every call, mean drawn before logstd.
         component_mean = np.random.normal(size=[4, 5]).astype(np.float64)
         component_logstd = np.random.normal(size=[4, 5]).astype(np.float64)
         return Normal(mean=component_mean, logstd=component_logstd)

     self.do_check_mixture(
         make_component,
         value_ndims=0,
         batch_shape=[4, 5],
         is_continuous=True,
         dtype=tf.float64,
         logits_dtype=np.float64,
         is_reparameterized=True,
     )
Ejemplo n.º 7
0
    def test_with_normal(self):
        """Check `Normal.batch_ndims_to_value(1)` attributes and densities.

        The wrapper evaluated with ``group_ndims=k`` must agree with the base
        Normal evaluated with ``group_ndims=k + 1``. This is checked for
        `log_prob`, `prob`, and the log-densities of drawn samples.
        """
        mean = np.random.normal(size=[4, 5]).astype(np.float64)
        logstd = np.random.normal(size=mean.shape).astype(np.float64)
        x = np.random.normal(size=[3, 4, 5])

        with self.test_session() as sess:
            def assert_same(actual, expected):
                # evaluate both tensors in the same session run
                np.testing.assert_allclose(*sess.run([actual, expected]))

            normal = Normal(mean=mean, logstd=logstd)
            wrapped = normal.batch_ndims_to_value(1)

            # static attributes of the wrapper
            self.assertIsInstance(wrapped, BatchToValueDistribution)
            self.assertEqual(wrapped.value_ndims, 1)
            self.assertEqual(wrapped.get_batch_shape().as_list(), [4])
            self.assertEqual(list(sess.run(wrapped.batch_shape)), [4])
            self.assertEqual(wrapped.dtype, tf.float64)
            self.assertTrue(wrapped.is_continuous)
            self.assertTrue(wrapped.is_reparameterized)
            self.assertIs(wrapped.base_distribution, normal)

            # log_prob
            lp_plain = wrapped.log_prob(x)
            lp_grouped = wrapped.log_prob(x, group_ndims=1)
            self.assertEqual(get_static_shape(lp_plain), (3, 4))
            self.assertEqual(get_static_shape(lp_grouped), (3, ))
            assert_same(lp_plain, normal.log_prob(x, group_ndims=1))
            assert_same(lp_grouped, normal.log_prob(x, group_ndims=2))

            # prob
            p_plain = wrapped.prob(x)
            p_grouped = wrapped.prob(x, group_ndims=1)
            self.assertEqual(get_static_shape(p_plain), (3, 4))
            self.assertEqual(get_static_shape(p_grouped), (3, ))
            assert_same(p_plain, normal.prob(x, group_ndims=1))
            assert_same(p_grouped, normal.prob(x, group_ndims=2))

            # log-densities attached to samples
            s_plain = wrapped.sample(3, compute_density=False)
            s_grouped = wrapped.sample(3, compute_density=True, group_ndims=1)
            slp_plain = s_plain.log_prob()
            slp_grouped = s_grouped.log_prob()
            self.assertEqual(get_static_shape(slp_plain), (3, 4))
            self.assertEqual(get_static_shape(slp_grouped), (3, ))
            assert_same(slp_plain, normal.log_prob(s_plain, group_ndims=1))
            assert_same(slp_grouped,
                        normal.log_prob(s_grouped, group_ndims=2))
Ejemplo n.º 8
0
    def test_errors(self):
        """Every invalid `Mixture` construction/usage raises the right error.

        Each case is (exception type, message regex, factory). The factories
        are lambdas, so every object is created inside its own
        `pytest.raises` context, in the listed order.
        """
        cases = [
            (TypeError,
             '`categorical` must be a Categorical '
             'distribution',
             lambda: Mixture(Normal(0., 0.), [Normal(0., 0.)])),

            (ValueError,
             'Dynamic `categorical.n_categories` is not '
             'supported',
             lambda: Mixture(
                 Categorical(logits=tf.placeholder(tf.float32, [None])),
                 [Normal(0., 0.)])),

            (ValueError,
             '`components` must not be empty',
             lambda: Mixture(Categorical(logits=tf.zeros([5])), [])),

            (ValueError,
             r'`len\(components\)` != `categorical.'
             r'n_categories`: 1 vs 5',
             lambda: Mixture(Categorical(logits=tf.zeros([5])),
                             [Normal(0., 0.)])),

            (ValueError,
             '`dtype` of the 1-th component does not '
             'agree with the first component',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([2])),
                 [Categorical(tf.zeros([2, 3]), dtype=tf.int32),
                  Categorical(tf.zeros([2, 3]), dtype=tf.float32)])),

            (ValueError,
             '`value_ndims` of the 1-th component does not '
             'agree with the first component',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([2])),
                 [Categorical(tf.zeros([2, 3])),
                  OnehotCategorical(tf.zeros([2, 3]))])),

            (ValueError,
             '`is_continuous` of the 1-th component does '
             'not agree with the first component',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([2])),
                 [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                  Normal(tf.zeros([2]), tf.zeros([2]))])),

            (ValueError,
             'the 0-th component is not re-parameterized',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([2])),
                 [Categorical(tf.zeros([2, 3]), dtype=tf.float32),
                  Normal(tf.zeros([2]), tf.zeros([2]))],
                 is_reparameterized=True)),

            # construction succeeds; requesting a reparameterized sample fails
            (RuntimeError,
             '.* is not re-parameterized',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([2])),
                 [Normal(-1., 0.), Normal(1., 0.)]
             ).sample(1, is_reparameterized=True)),

            (ValueError,
             'Batch shape of `categorical` does not '
             'agree with the first component',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([1, 3, 2])),
                 [Normal(mean=tf.zeros([3]), logstd=0.),
                  Normal(mean=tf.zeros([3]), logstd=0.)])),

            (ValueError,
             'Batch shape of the 1-th component does not '
             'agree with the first component',
             lambda: Mixture(
                 Categorical(logits=tf.zeros([3, 2])),
                 [Normal(mean=tf.zeros([3]), logstd=0.),
                  Normal(mean=tf.zeros([4]), logstd=0.)])),
        ]

        for error_type, pattern, build in cases:
            with pytest.raises(error_type, match=pattern):
                build()