Ejemplo n.º 1
0
    def testReprWorksCorrectlyScalar(self):
        """Checks `repr` output for distributions whose shapes are known.

        Covers a default-named float16 Normal and a custom-named float32 Chi2
        with two `df` values. In graph mode it additionally checks that an
        Exponential whose rate has unknown shape renders `batch_shape=?`.
        """
        normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
        expected_normal_repr = ("<tfp.distributions.Normal 'Normal'"
                                " batch_shape=[] event_shape=[] dtype=float16>")
        self.assertEqual(repr(normal), expected_normal_repr)

        # A deliberately silly custom name, to show the name is echoed back.
        chi2 = tfd.Chi2(df=np.float32([1., 2.]), name='silly')
        expected_chi2_repr = ("<tfp.distributions.Chi2 'silly'"
                              " batch_shape=[2] event_shape=[] dtype=float32>")
        self.assertEqual(repr(chi2), expected_chi2_repr)

        # Eager mode has no notion of partially known shapes, so the
        # unknown-shape case can only be exercised in graph mode.
        if not tf.executing_eagerly():
            unknown_rate = tf1.placeholder_with_default(1., shape=None)
            exp = tfd.Exponential(rate=unknown_rate)
            self.assertEqual(
                repr(exp),
                "<tfp.distributions.Exponential 'Exponential'"
                " batch_shape=? event_shape=[] dtype=float32>")
Ejemplo n.º 2
0
    def testStrWorksCorrectlyScalar(self):
        """Checks `str` output for distributions built with validate_args=True.

        Covers a default-named scalar Normal and a custom-named Chi2 with two
        `df` values. In graph mode it additionally checks an Exponential whose
        rate has unknown shape; its string carries no `batch_shape` entry.
        """
        normal = tfd.Normal(loc=np.float16(0), scale=1, validate_args=True)
        normal_str = str(normal)
        self.assertEqual(
            normal_str,
            'tfp.distributions.Normal("Normal", batch_shape=[], '
            'event_shape=[], dtype=float16)')

        # A deliberately silly custom name, to show the name is echoed back.
        chi2 = tfd.Chi2(
            df=np.float32([1., 2.]), name='silly', validate_args=True)
        chi2_str = str(chi2)
        self.assertEqual(
            chi2_str,
            'tfp.distributions.Chi2("silly", batch_shape=[2], '
            'event_shape=[], dtype=float32)')

        # Eager mode has no notion of partially known shapes, so the
        # unknown-shape case can only be exercised in graph mode.
        if not tf.executing_eagerly():
            unknown_rate = tf1.placeholder_with_default(1., shape=None)
            exp = tfd.Exponential(rate=unknown_rate, validate_args=True)
            # With an unknown batch shape, no `batch_shape` entry is printed.
            self.assertEqual(
                str(exp),
                'tfp.distributions.Exponential("Exponential", '
                'event_shape=[], dtype=float32)')
Ejemplo n.º 3
0
 def test_default_event_space_bijector(self):
     """Checks BatchReshape's default event-space bijector against the base's.

     Builds a Chi2 with a four-element `df` and a BatchReshape of it to batch
     shape [2, 2, 1]. It applies each distribution's default event-space
     bijector to a tensor of tens shaped like that distribution's batch shape.
     It then asserts that the base result, reshaped to [2, 2, 1], equals the
     BatchReshape result exactly.
     """
     dist = tfd.Chi2([1., 2., 3., 6.], validate_args=True)
     batch_shape = [2, 2, 1]
     reshape_dist = tfd.BatchReshape(dist, batch_shape, validate_args=True)
     x = self.evaluate(dist._experimental_default_event_space_bijector()(
         10. * tf.ones(dist.batch_shape)))
     x_reshape = self.evaluate(
         reshape_dist._experimental_default_event_space_bijector()(
             10. * tf.ones(reshape_dist.batch_shape)))
     # `x` is already a NumPy array (the output of `self.evaluate`), so reshape
     # it with NumPy instead of wrapping it in a fresh TF op (a new graph node
     # in graph mode) that `assertAllEqual` would have to evaluate again.
     self.assertAllEqual(np.reshape(x, batch_shape), x_reshape)