Exemplo n.º 1
0
 def test_dtypes_multinormal(self, dtype):
     # A Tanh-wrapped multinormal must report `dtype` on the distribution
     # itself and on everything it produces.
     squashed = prob.Tanh(make_multinormal([1], [1], dtype=dtype))
     self.assertEqual(dtype, squashed.dtype)
     # Producers are evaluated lazily, one per assertion, so the call order
     # is log_prob -> mode -> sample.
     producers = (
         lambda: squashed.log_prob(0),
         squashed.mode,
         squashed.sample,
     )
     for produce in producers:
         self.assertEqual(dtype, produce().dtype)
Exemplo n.º 2
0
 def test_tanh(self):
     # Compare the Tanh wrapper's forward / inverse / log-det-Jacobian
     # against NumPy reference values on a (2, 5, 10) grid over [-3, 3].
     squashed = prob.Tanh(make_normal([], [], dtype=np.float64))
     inputs = np.linspace(-3., 3., 100).reshape([2, 5, 10]).astype(np.float64)
     expected_forward = np.tanh(inputs)
     # Reference value compared against log_det_jacob:
     # log(1 - tanh(x)^2), i.e. log|d tanh(x) / dx|.
     expected_log_det = np.log1p(-np.square(np.tanh(inputs)))
     self.assertAllClose(
         expected_forward, squashed.forward(inputs), atol=0., rtol=1e-6)
     self.assertAllClose(
         inputs, squashed.inverse(expected_forward), atol=0., rtol=1e-4)
     self.assertAllClose(
         expected_log_det, squashed.log_det_jacob(inputs), atol=0., rtol=1e-4)
Exemplo n.º 3
0
 def test_shapes_normal(self, mean_shape, scale_shape):
     # The expected batch shape is the broadcast of the two parameter shapes;
     # the test expects event_ndims == 0 and an empty event shape.
     expected_batch = utils.broadcast_shape(
         tf.TensorShape(mean_shape), tf.TensorShape(scale_shape))
     squashed = prob.Tanh(
         make_normal(mean_shape, scale_shape, dtype=tf.float32))
     self.assertEqual(0, squashed.event_ndims)
     self.assertArrayEqual(expected_batch, squashed.shape)
     self.assertArrayEqual(expected_batch, squashed.batch_shape())
     self.assertArrayEqual([], squashed.event_shape())
     log_prob = squashed.log_prob(np.zeros(expected_batch))
     self.assertArrayEqual(expected_batch, log_prob.shape)
     for produce in (squashed.mode, squashed.sample):
         self.assertArrayEqual(expected_batch, produce().shape)
Exemplo n.º 4
0
 def test_shapes_multinormal(self, mean_shape, scale_shape):
     # Broadcast the parameter shapes via NumPy; the trailing axis is the
     # event axis and everything before it is the batch shape.
     full_shape = (np.ones(mean_shape) * np.ones(scale_shape)).shape
     batch_shape, event_shape = full_shape[:-1], full_shape[-1:]
     squashed = prob.Tanh(
         make_multinormal(mean_shape, scale_shape, dtype=tf.float32))
     self.assertEqual(1, squashed.event_ndims)
     self.assertArrayEqual(full_shape, squashed.shape)
     self.assertArrayEqual(batch_shape, squashed.batch_shape())
     self.assertArrayEqual(event_shape, squashed.event_shape())
     # log_prob is expected to have the batch shape only.
     log_prob = squashed.log_prob(np.zeros(full_shape))
     self.assertArrayEqual(batch_shape, log_prob.shape)
     for produce in (squashed.mode, squashed.sample):
         self.assertArrayEqual(full_shape, produce().shape)
Exemplo n.º 5
0
    def call(self, inputs, training=True):
        '''Build the action distribution for a batch of latent vectors.

        Args:
            inputs (tf.Tensor): Expecting a latent vector in shape
                (b, latent), tf.float32
            training (bool, optional): Training mode, forwarded to both
                sub-models. Defaults to True.

        Returns:
            A MultiNormal distribution over the action space, wrapped in
            Tanh when `self.squash` is truthy.
        '''
        # Run the two heads (mean first, then the raw scale output).
        mean = self._mean_model(inputs, training=training)
        raw_scale = self._logstd_model(inputs, training=training)
        # softplus is non-negative; the 1e-5 offset keeps the scale
        # strictly positive.
        scale = tf.math.softplus(raw_scale) + 1e-5
        # Reshape to the action space shape (-1 = batch dim).
        target_shape = [-1, *self.action_shape]
        mean = tf.reshape(mean, target_shape)
        scale = tf.reshape(scale, target_shape)
        # Multivariate Gaussian, optionally tanh-squashed.
        gaussian = ub_prob.MultiNormal(mean, scale)
        return ub_prob.Tanh(gaussian) if self.squash else gaussian
Exemplo n.º 6
0
 def test_bijector_init_no_exception(self):
     # Wrapping a Normal distribution in Tanh must construct without raising.
     base = prob.Normal(mean=1.0, scale=2.0)
     prob.Tanh(base)
Exemplo n.º 7
0
 def test_bijector_init_exception(self):
     # Tanh must raise ValueError when given a plain `object()`.
     self.assertRaises(ValueError, prob.Tanh, object())