Пример #1
0
def coercible_tensor(d: tfd.Distribution,
                     convert_to_tensor_fn=tfd.Distribution.sample,
                     return_value: bool = False) -> tfd.Distribution:
    r"""Wrap `d` so that it can be used where a Tensor is expected.

    Adapted from `distribution_layers.py` in tensorflow_probability.

    Args:
      d: the distribution to wrap; must be a `tfd.Distribution`.
      convert_to_tensor_fn: how the distribution is turned into a tensor.
        It is first passed through `_get_convert_to_tensor_fn`. Defaults
        to `tfd.Distribution.sample`.
      return_value: when true, the concrete value is returned as well.

    Returns:
      The `dtc._TensorCoercible` wrapper around `d`, or the pair
      `(wrapper, value)` when `return_value` is true. Note that the return
      annotation only describes the first case.

    Raises:
      AssertionError: if `d` is not a `tfd.Distribution`.
    """
    assert isinstance(d, tfd.Distribution), \
      "dist must be instance of tensorflow_probability.Distribution"

    to_tensor = _get_convert_to_tensor_fn(convert_to_tensor_fn)
    if inspect.isfunction(to_tensor):
        # When the converter is a function defined directly on the base
        # `tfd.Distribution` class, look the same name up on the concrete
        # class of `d` so that an override there (if any) is the one used.
        base_members = list(tfd.Distribution.__dict__.values())
        if to_tensor in base_members:
            to_tensor = getattr(type(d), to_tensor.__name__)

    # Build the tensor-coercible wrapper around the original distribution.
    coercible = dtc._TensorCoercible(
        distribution=d, convert_to_tensor_fn=to_tensor)

    # Obtain the concrete value and cross-link it with the wrapper: the value
    # remembers its distribution, the wrapper exposes the value's shape.
    tensor = coercible._value()
    tensor._tfp_distribution = coercible
    coercible.shape = tensor.shape
    coercible.get_shape = tensor.get_shape

    return (coercible, tensor) if return_value else coercible
Пример #2
0
def coercible_tensor(d,
                     convert_to_tensor_fn=tfd.Distribution.sample,
                     return_value=False):
    r"""Make a distribution convertible to a Tensor via `convert_to_tensor_fn`.

    Adapted from `distribution_layers.py` in tensorflow_probability.

    Args:
      d: the distribution to wrap; must be a `tfd.Distribution`.
      convert_to_tensor_fn: how the distribution is turned into a tensor.
        It is first passed through `_get_convert_to_tensor_fn`. Defaults
        to `tfd.Distribution.sample`.
      return_value: when true, the concrete value is returned as well.

    Returns:
      The `dtc._TensorCoercible` wrapper around `d`, or the pair
      `(wrapper, value)` when `return_value` is true.

    Raises:
      AssertionError: if `d` is not a `tfd.Distribution`.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Imported locally so this function does not depend on
    # the submodule having been loaded elsewhere.
    from collections.abc import Sequence

    assert isinstance(d, tfd.Distribution), \
      "dist must be instance of tensorflow_probability.Distribution"
    convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)

    # A sequence-valued `dtype` means the converted value has several
    # components; in that case the converter output is wrapped in a
    # `tensor_tuple.TensorTuple`.
    value_is_seq = isinstance(d.dtype, Sequence)
    maybe_composite_convert_to_tensor_fn = (
        (lambda dist: tensor_tuple.TensorTuple(convert_to_tensor_fn(dist)))
        if value_is_seq else convert_to_tensor_fn)

    # Wrap the distribution so both it and its concrete value are available.
    distribution = dtc._TensorCoercible(
        distribution=d,
        convert_to_tensor_fn=maybe_composite_convert_to_tensor_fn)

    # Prepare the value and cross-link it with the wrapper.
    value = distribution._value()
    value._tfp_distribution = distribution
    if value_is_seq:
        # For a multi-component value, the last component supplies the
        # shape/dtype exposed on both the value and the wrapper.
        last = value[-1]
        value.shape = last.shape
        value.get_shape = last.get_shape
        value.dtype = last.dtype
        distribution.shape = last.shape
        distribution.get_shape = last.get_shape
    else:
        distribution.shape = value.shape
        distribution.get_shape = value.get_shape

    if return_value:
        return distribution, value
    return distribution
Пример #3
0
 def _fn(*fargs, **fkwargs):
     """Build the distribution and return it together with its tensor value.

     All arguments are forwarded to `make_distribution_fn`. The result is
     wrapped in a `_TensorCoercible` and converted with
     `tf.convert_to_tensor`; the pair `(wrapper, tensor)` is returned.
     """
     inner_dist = make_distribution_fn(*fargs, **fkwargs)
     coercible = dtc._TensorCoercible(  # pylint: disable=protected-access
         distribution=inner_dist,
         convert_to_tensor_fn=convert_to_tensor_fn)
     tensor = tf.convert_to_tensor(coercible)
     # TODO(b/120153609): Keras is incorrectly presuming everything is a
     # `tf.Tensor`. Closing this bug entails ensuring Keras only accesses
     # `tf.Tensor` properties after calling `tf.convert_to_tensor`.
     # Until then, mirror the tensor's shape accessors on the wrapper.
     for attr in ("shape", "get_shape"):
         setattr(coercible, attr, getattr(tensor, attr))
     return coercible, tensor
Пример #4
0
 def call(self, x):
   """Resolve the dotted path `self.attr_name` on `x` and return the result.

   Each component of `self.attr_name` (split on '.') is looked up with
   `getattr` in turn. If the final object is a `Distribution` that is not
   already a `_TensorCoercible`, it is wrapped in one (using
   `self.convert_to_tensor_fn`), converted to a tensor, and the wrapper is
   returned instead.
   """
   target = x
   for name in self.attr_name.split('.'):
     target = getattr(target, name)

   # Anything that is not a bare distribution is returned untouched.
   if not isinstance(target, Distribution) or \
       isinstance(target, dtc._TensorCoercible):
     return target

   # Special case: a distribution was reached, so make it tensor-coercible.
   coercible = dtc._TensorCoercible(
       distribution=target, convert_to_tensor_fn=self.convert_to_tensor_fn)
   tensor = tf.convert_to_tensor(value=coercible)
   # Cross-link: the tensor remembers its distribution, and the wrapper
   # exposes the tensor's shape accessors.
   tensor._tfp_distribution = coercible
   coercible.shape = tensor.shape
   coercible.get_shape = tensor.get_shape
   return coercible