Пример #1
0
        def create_dist(loc_and_scale):
            """Builds a diagonal Gaussian from a packed parameter tensor.

            The trailing axis of `loc_and_scale` is split at
            `output_spec.shape.num_elements()`. The leading slice is used as the
            mean, and the remaining slice is exponentiated to give the
            per-dimension scale. (Presumably the trailing axis has size
            2 * num_elements; TODO confirm against the caller.)

            Returns:
              `distribution_utils.scale_distribution_to_spec` applied to a
              `tfp.distributions.MultivariateNormalDiag` (built with
              `validate_args=True`) and `output_spec`.
            """
            split_at = output_spec.shape.num_elements()
            mean_part = loc_and_scale[..., :split_at]
            log_scale_part = loc_and_scale[..., split_at:]
            gaussian = tfp.distributions.MultivariateNormalDiag(
                loc=mean_part,
                scale_diag=tf.exp(log_scale_part),
                validate_args=True,
            )
            return distribution_utils.scale_distribution_to_spec(
                gaussian, output_spec)
Пример #2
0
    def testScaleDistribution(self):
        """Samples of a spec-scaled Normal(0, 4) stay within the [-2, 4] spec.

        Each of 1000 evaluated samples must be strictly above -2.00001 and
        strictly below 4.00001.
        """
        bounded_spec = tensor_spec.BoundedTensorSpec([1], tf.float32, -2, 4)
        base = tfp.distributions.Normal(0, 4)
        bounded = utils.scale_distribution_to_spec(base, bounded_spec)
        # Eager mode hands `self.evaluate` the bound `sample` method itself, not
        # a tensor. Presumably this is so each evaluation draws a fresh sample
        # (confirm against tf.test.TestCase.evaluate). Graph mode builds a
        # single sample op and re-runs it.
        draw = bounded.sample if tf.executing_eagerly() else bounded.sample()

        for _ in range(1000):
            value = self.evaluate(draw)
            self.assertGreater(value, -2.00001)
            self.assertLess(value, 4.00001)
Пример #3
0
 def distribution_builder(*args, **kwargs):
   """Builds a Normal or MultivariateNormalDiag, optionally scaled to the spec.

   Args:
     *args: Positional arguments forwarded to the distribution constructor.
     **kwargs: Keyword arguments forwarded to the distribution constructor.
       When `is_multivariate` is truthy they must include 'scale', which is
       forwarded as `scale_diag` instead.

   Returns:
     A `tfp.distributions.MultivariateNormalDiag` (if `is_multivariate`) or a
     `tfp.distributions.Normal`. When `self._scale_distribution` is truthy, the
     distribution is first passed through
     `distribution_utils.scale_distribution_to_spec(..., sample_spec)`.

   Raises:
     KeyError: If `is_multivariate` is truthy and 'scale' is not in `kwargs`.
   """
   if is_multivariate:
     # For backwards compatibility, and because MVNDiag does not support
     # `param_static_shapes`, even when using MVNDiag the spec continues to
     # use the terms 'loc' and 'scale'. Here we rename kwarg 'scale' to
     # 'scale_diag'. Since they have the same shape and dtype expectations,
     # this is okay. `kwargs` is a fresh dict created for this call, so
     # renaming the key in place cannot affect the caller's mapping.
     kwargs['scale_diag'] = kwargs.pop('scale')
     distribution = tfp.distributions.MultivariateNormalDiag(*args, **kwargs)
   else:
     distribution = tfp.distributions.Normal(*args, **kwargs)
   if self._scale_distribution:
     return distribution_utils.scale_distribution_to_spec(
         distribution, sample_spec)
   return distribution
Пример #4
0
 def distribution_builder(*args, **kwargs):
     """Builds a `tfp.distributions.Normal`, optionally scaled to `sample_spec`.

     All arguments are forwarded unchanged to the `Normal` constructor. If
     `self._scale_distribution` is falsy, the Normal is returned as is.
     Otherwise it is passed through
     `distribution_utils.scale_distribution_to_spec` with `sample_spec`.
     """
     normal = tfp.distributions.Normal(*args, **kwargs)
     if not self._scale_distribution:
         return normal
     return distribution_utils.scale_distribution_to_spec(normal, sample_spec)
Пример #5
0
 def distribution_builder(*args, **kwargs):
   """Builds a `MultivariateNormalDiag` and scales it to `sample_spec`.

   Positional and keyword arguments go straight to the
   `tfp.distributions.MultivariateNormalDiag` constructor. The result is always
   passed through `distribution_utils.scale_distribution_to_spec`.
   """
   mvn_diag = tfp.distributions.MultivariateNormalDiag(*args, **kwargs)
   scaled = distribution_utils.scale_distribution_to_spec(mvn_diag, sample_spec)
   return scaled
Пример #6
0
def _build_squash_to_spec_normal(spec, *args, **kwargs):
    """Returns a `tfp.distributions.Normal` passed through `scale_distribution_to_spec`.

    Args:
      spec: Spec handed to `scale_distribution_to_spec` (presumably a bounded
        tensor spec; confirm against callers).
      *args: Positional arguments forwarded to `tfp.distributions.Normal`.
      **kwargs: Keyword arguments forwarded to `tfp.distributions.Normal`.
    """
    return scale_distribution_to_spec(
        tfp.distributions.Normal(*args, **kwargs), spec)