예제 #1
0
def coercible_tensor(d,
                     convert_to_tensor_fn=tfd.Distribution.sample,
                     return_value=False):
    """Make a distribution convertible to a Tensor via `convert_to_tensor_fn`.

    This code is adapted from `distribution_layers.py` in
    tensorflow_probability.

    Args:
      d: the `tfd.Distribution` instance to wrap.
      convert_to_tensor_fn: a callable, or any value accepted by
        `_get_convert_to_tensor_fn`, that produces the concrete value of the
        wrapped distribution. Defaults to `tfd.Distribution.sample`.
      return_value: if True, also return the concrete value.

    Returns:
      The wrapping `dtc._TensorCoercible` distribution, or a
      `(distribution, value)` tuple when `return_value` is True.

    Raises:
      TypeError: if `d` is not a `tfd.Distribution`.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Imported locally so this block stays self-contained.
    from collections.abc import Sequence

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(d, tfd.Distribution):
        raise TypeError(
            "dist must be instance of tensorflow_probability.Distribution")
    convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)
    # Wrap the distribution so it carries both the distribution and a concrete
    # value. When `d.dtype` is a sequence, the converted value is packed into a
    # `TensorTuple` composite.
    value_is_seq = isinstance(d.dtype, Sequence)
    maybe_composite_convert_to_tensor_fn = (
        (lambda dist: tensor_tuple.TensorTuple(convert_to_tensor_fn(dist)))
        if value_is_seq else convert_to_tensor_fn)
    distribution = dtc._TensorCoercible(
        distribution=d,
        convert_to_tensor_fn=maybe_composite_convert_to_tensor_fn)
    # Materialize the concrete value and link it back to its distribution.
    value = distribution._value()
    value._tfp_distribution = distribution
    if value_is_seq:
        # Expose the shape/dtype of the last component on both the
        # TensorTuple and the distribution.
        last = value[-1]
        value.shape = last.shape
        value.get_shape = last.get_shape
        value.dtype = last.dtype
        distribution.shape = last.shape
        distribution.get_shape = last.get_shape
    else:
        distribution.shape = value.shape
        distribution.get_shape = value.get_shape
    if return_value:
        return distribution, value
    return distribution
예제 #2
0
    """Pretend user-side class for `ConvertToCompositeTensorTest ."""
    def __init__(self, sequence):
        self._sequence = tuple(sequence)

    def __getitem__(self, key):
        return self._sequence[key]

    def __len__(self):
        return len(self._sequence)

    def __iter__(self):
        return iter(self._sequence)


# Teach TensorFlow to convert `MyTuple` instances into `TensorTuple`
# composites; any extra positional/keyword arguments passed to the conversion
# function are ignored.
tf.register_tensor_conversion_function(
    MyTuple,
    conversion_func=(
        lambda value, *unused_args, **unused_kwargs:
        tensor_tuple.TensorTuple(value)),
)


@test_util.test_all_tf_execution_regimes
class CustomConvertToCompositeTensorTest(test_util.TestCase):
    def test_iter(self):
        x = MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
        y = ops.convert_to_tensor_or_composite(value=x)
        # TODO(jsimsa): The behavior of `is_tensor` for composite tensors have
        # changed (from True to False) and this check needs to be disabled so that
        # both TAP presubmits (running at HEAD) and Kokoro presubmit (using TF
        # nightly) pass. Re-enable this check when TF nightly picks up this change.
        # self.assertTrue(tf.is_tensor(y))
        self.assertIsInstance(y, tensor_tuple.TensorTuple)
        self.assertLen(y, 3)
        for x_, y_ in zip(x, y):