def _output_distribution_spec(self, sample_spec, network_name):
        """Builds a DistributionSpec for a Normal-family output distribution.

        Args:
          sample_spec: Spec of the values sampled from the distribution. A
            shape of rank > 0 selects `MultivariateNormalDiag`; a scalar shape
            selects `Normal`.
          network_name: Prefix for the parameter spec names, which are formed
            as `network_name + '_' + <parameter name>`.

        Returns:
          A `distribution_spec.DistributionSpec` whose builder constructs the
          distribution and, when `self._scale_distribution` is truthy, passes
          it through `distribution_utils.scale_distribution_to_spec`.
        """
        multivariate = sample_spec.shape.ndims > 0
        param_shapes = tfp.distributions.Normal.param_static_shapes(
            sample_spec.shape)

        param_specs = {}
        for param_name, param_shape in param_shapes.items():
            param_specs[param_name] = tensor_spec.TensorSpec(
                shape=param_shape,
                dtype=sample_spec.dtype,
                name=network_name + '_' + param_name)

        def distribution_builder(*args, **kwargs):
            if not multivariate:
                dist = tfp.distributions.Normal(*args, **kwargs)
            else:
                # The spec keeps the 'loc'/'scale' names (for backwards
                # compatibility, and because MVNDiag does not support
                # `param_static_shapes`), so the 'scale' kwarg is renamed to the
                # 'scale_diag' kwarg MVNDiag expects. Both have the same shape
                # and dtype expectations. `kwargs` is a fresh dict per call, so
                # renaming in place does not leak to the caller.
                kwargs['scale_diag'] = kwargs.pop('scale')
                dist = tfp.distributions.MultivariateNormalDiag(
                    *args, **kwargs)
            if not self._scale_distribution:
                return dist
            return distribution_utils.scale_distribution_to_spec(
                dist, sample_spec)

        return distribution_spec.DistributionSpec(
            distribution_builder, param_specs, sample_spec=sample_spec)
Example #2
0
    def _get_normal_distribution_spec(self, sample_spec):
        """Returns a DistributionSpec for a Normal-family distribution.

        Parameter specs come from `Normal.parameter_properties()`. Each is
        shaped by its property's `shape_fn` applied to `sample_spec.shape` and
        typed with `sample_spec.dtype`. A scalar `sample_spec` builds a
        `Normal`; a rank > 0 `sample_spec` builds a `MultivariateNormalDiag`.
        """
        multivariate = sample_spec.shape.ndims > 0
        properties_by_name = tfp.distributions.Normal.parameter_properties()

        param_specs = {}
        for param_name, prop in properties_by_name.items():
            param_specs[param_name] = tensor_spec.TensorSpec(
                shape=prop.shape_fn(sample_spec.shape),
                dtype=sample_spec.dtype)

        def distribution_builder(*args, **kwargs):
            if not multivariate:
                return tfp.distributions.Normal(*args, **kwargs)
            # The spec keeps the 'loc'/'scale' names (for backwards
            # compatibility, and because MVNDiag does not support
            # `param_static_shapes`), so the 'scale' kwarg is renamed to the
            # 'scale_diag' kwarg MVNDiag expects. Both have the same shape and
            # dtype expectations. `kwargs` is a fresh dict per call.
            kwargs['scale_diag'] = kwargs.pop('scale')
            return tfp.distributions.MultivariateNormalDiag(*args, **kwargs)

        return distribution_spec.DistributionSpec(
            distribution_builder, param_specs, sample_spec=sample_spec)
Example #3
0
    def setUp(self):
        """Creates the specs, sampled inputs and output spec used by tests."""
        super(MaskSplitterNetworkTest, self).setUp()
        observation_spec = tensor_spec.BoundedTensorSpec(
            (2, ), tf.float32, 0, 5)
        mask_spec = tensor_spec.BoundedTensorSpec((3, ), tf.int32, 0, 1)
        self._observation_and_mask_spec = {
            'observation': observation_spec,
            'mask': mask_spec,
        }
        self._observation_spec = observation_spec
        self._mask_spec = mask_spec
        self._state_spec = tensor_spec.BoundedTensorSpec(
            (1, ), tf.int64, 0, 10)

        def splitter_fn(observation_and_mask):
            # Splits the dict input into an (observation, mask) pair.
            return (observation_and_mask['observation'],
                    observation_and_mask['mask'])

        self._splitter_fn = splitter_fn

        batch_dims = (4, )
        self._observation_and_mask = tensor_spec.sample_spec_nest(
            self._observation_and_mask_spec, outer_dims=batch_dims)
        self._network_state = tensor_spec.sample_spec_nest(
            self._state_spec, outer_dims=batch_dims)

        # Default parameters forwarded to the DistributionSpec below.
        categorical_params = tfp.distributions.Categorical(
            logits=[0, 5]).parameters
        self._output_spec = distribution_spec.DistributionSpec(
            tfp.distributions.Categorical,
            self._observation_spec,
            sample_spec=tensor_spec.BoundedTensorSpec((1, ), tf.int64, 0, 1),
            **categorical_params)
  def _output_distribution_spec(self, output_shape, sample_spec):
    """Returns a Categorical DistributionSpec parameterized by logits.

    Args:
      output_shape: Shape of the float32 `logits` parameter spec.
      sample_spec: Spec of the sampled values; its dtype is also forwarded to
        the DistributionSpec as `dtype`.
    """
    logits_spec = tensor_spec.TensorSpec(shape=output_shape, dtype=tf.float32)
    return distribution_spec.DistributionSpec(
        tfp.distributions.Categorical,
        dict(logits=logits_spec),
        sample_spec=sample_spec,
        dtype=sample_spec.dtype)
Example #5
0
  def _get_normal_distribution_spec(self, sample_spec):
    """Returns a DistributionSpec for a `Normal` over `sample_spec`.

    Each parameter spec takes its shape from `Normal.param_static_shapes` and
    its dtype from `sample_spec`.
    """
    def _to_spec(param_shape):
      return tensor_spec.TensorSpec(shape=param_shape, dtype=sample_spec.dtype)

    static_shapes = tfp.distributions.Normal.param_static_shapes(
        sample_spec.shape)
    param_specs = tf.nest.map_structure(_to_spec, static_shapes)
    return distribution_spec.DistributionSpec(
        tfp.distributions.Normal, param_specs, sample_spec=sample_spec)
Example #6
0
    def _get_normal_distribution_spec(self, sample_spec):
        """Returns a DistributionSpec for a `Normal` over `sample_spec`.

        Parameter specs come from `Normal.parameter_properties()`. Each is
        shaped by its property's `shape_fn` applied to `sample_spec.shape` and
        typed with `sample_spec.dtype`.
        """
        properties_by_name = tfp.distributions.Normal.parameter_properties()
        param_specs = {}
        for param_name, prop in properties_by_name.items():
            param_specs[param_name] = tensor_spec.TensorSpec(
                shape=prop.shape_fn(sample_spec.shape),
                dtype=sample_spec.dtype)
        return distribution_spec.DistributionSpec(
            tfp.distributions.Normal, param_specs, sample_spec=sample_spec)
Example #7
0
    def _output_distribution_spec(self, output_shape, sample_spec,
                                  network_name):
        """Returns a DistributionSpec whose builder is `GumbelLayer`.

        Args:
          output_shape: Shape of the float32 `logits` parameter spec.
          sample_spec: Spec of the sampled values; its dtype is forwarded to
            the DistributionSpec as `dtype`.
          network_name: Prefix for the parameter spec names
            (`<network_name>_temp` and `<network_name>_logits`).
        """
        # Scalar-shaped spec; dtype is left at TensorSpec's default.
        temperature_spec = tensor_spec.TensorSpec(
            (), name=network_name + '_temp')
        logits_spec = tensor_spec.TensorSpec(
            shape=output_shape,
            dtype=tf.float32,
            name=network_name + '_logits')
        param_specs = {'temperature': temperature_spec, 'logits': logits_spec}
        return distribution_spec.DistributionSpec(
            GumbelLayer,
            param_specs,
            sample_spec=sample_spec,
            dtype=sample_spec.dtype)
  def _output_distribution_spec(self, sample_spec):
    """Returns a DistributionSpec for a `Normal` over `sample_spec`.

    When `self._scale_distribution` is truthy, the built distribution is passed
    through `common.scale_distribution_to_spec`.
    """
    def _to_spec(param_shape):
      return tensor_spec.TensorSpec(shape=param_shape, dtype=sample_spec.dtype)

    static_shapes = tfp.distributions.Normal.param_static_shapes(
        sample_spec.shape)
    param_specs = tf.nest.map_structure(_to_spec, static_shapes)

    def distribution_builder(*args, **kwargs):
      normal = tfp.distributions.Normal(*args, **kwargs)
      if not self._scale_distribution:
        return normal
      return common.scale_distribution_to_spec(normal, sample_spec)

    return distribution_spec.DistributionSpec(
        distribution_builder, param_specs, sample_spec=sample_spec)
Example #9
0
  def testBuildsDistribution(self):
    """Checks that stored parameters are kept and re-applied when building."""
    expected_distribution = tfd.Categorical([0.2, 0.3, 0.5], validate_args=True)
    spec = distribution_spec.DistributionSpec(
        tfd.Categorical,
        tensor_spec.TensorSpec((3,), dtype=tf.float32),
        sample_spec=tensor_spec.TensorSpec((1,), dtype=tf.int32),
        **expected_distribution.parameters)

    # The spec keeps the parameters it was constructed with.
    self.assertEqual(expected_distribution.parameters['logits'],
                     spec.distribution_parameters['logits'])

    # Explicit kwargs override the stored logits; validate_args is retained.
    built = spec.build_distribution(logits=[0.1, 0.4, 0.5])
    self.assertIsInstance(built, tfd.Categorical)
    self.assertTrue(built.parameters['validate_args'])
    self.assertEqual([0.1, 0.4, 0.5], built.parameters['logits'])
Example #10
0
  def _output_distribution_spec(self, sample_spec, network_name):
    """Returns a DistributionSpec for a rescaled `MultivariateNormalDiag`.

    The `loc` and `scale_diag` specs share `sample_spec`'s shape and dtype and
    are named `network_name + '_' + <parameter name>`. The builder always
    passes the distribution through
    `distribution_utils.scale_distribution_to_spec`.
    """
    param_specs = {}
    for param_name in ('loc', 'scale_diag'):
      param_specs[param_name] = tensor_spec.TensorSpec(
          shape=sample_spec.shape,
          dtype=sample_spec.dtype,
          name=network_name + '_' + param_name)

    def distribution_builder(*args, **kwargs):
      mvn = tfp.distributions.MultivariateNormalDiag(*args, **kwargs)
      return distribution_utils.scale_distribution_to_spec(mvn, sample_spec)

    return distribution_spec.DistributionSpec(
        distribution_builder, param_specs, sample_spec=sample_spec)
  def _output_distribution_spec(self, sample_spec, network_name):
    """Returns a DistributionSpec for a `Normal` over `sample_spec`.

    Parameter specs take their shapes from `Normal.param_static_shapes`, their
    dtype from `sample_spec`, and are named
    `network_name + '_' + <parameter name>`. When `self._scale_distribution`
    is truthy, the built distribution is passed through
    `distribution_utils.scale_distribution_to_spec`.
    """
    static_shapes = tfp.distributions.Normal.param_static_shapes(
        sample_spec.shape)

    param_specs = {}
    for param_name, param_shape in static_shapes.items():
      param_specs[param_name] = tensor_spec.TensorSpec(
          shape=param_shape,
          dtype=sample_spec.dtype,
          name=network_name + '_' + param_name)

    def distribution_builder(*args, **kwargs):
      normal = tfp.distributions.Normal(*args, **kwargs)
      if not self._scale_distribution:
        return normal
      return distribution_utils.scale_distribution_to_spec(normal, sample_spec)

    return distribution_spec.DistributionSpec(
        distribution_builder, param_specs, sample_spec=sample_spec)