def coercible_tensor(d,
                     convert_to_tensor_fn=tfd.Distribution.sample,
                     return_value=False):
  r"""Make a distribution convertible to a Tensor via `convert_to_tensor_fn`.

  The distribution is wrapped in `dtc._TensorCoercible`, its concrete value is
  materialized once through `_value()`, and shape accessors are copied onto
  the wrapper (and, for sequence-valued distributions, onto the value too).

  This code is copied from: `distribution_layers.py` tensorflow_probability

  Arguments:
    d: the `tfd.Distribution` instance to wrap.
    convert_to_tensor_fn: conversion specification, resolved by
      `_get_convert_to_tensor_fn` before use. Defaults to
      `tfd.Distribution.sample`.
    return_value: if True, also return the concrete value obtained from the
      wrapped distribution.

  Returns:
    The wrapped distribution, or the tuple `(distribution, value)` when
    `return_value` is True.

  Raises:
    AssertionError: if `d` is not a `tfd.Distribution`.
  """
  # `collections.Sequence` was removed in Python 3.10; the ABC lives in
  # `collections.abc`. Imported locally so the block does not rely on the
  # file-level import list.
  from collections.abc import Sequence

  assert isinstance(d, tfd.Distribution), \
    "dist must be instance of tensorflow_probability.Distribution"
  convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)
  # Wraps the distribution to return both dist and concrete value.
  value_is_seq = isinstance(d.dtype, Sequence)
  if value_is_seq:
    # A sequence dtype means the conversion yields several tensors, so they
    # are bundled in a TensorTuple.
    maybe_composite_convert_to_tensor_fn = (
        lambda dist: tensor_tuple.TensorTuple(convert_to_tensor_fn(dist)))
  else:
    maybe_composite_convert_to_tensor_fn = convert_to_tensor_fn
  distribution = dtc._TensorCoercible(
      distribution=d,
      convert_to_tensor_fn=maybe_composite_convert_to_tensor_fn)
  ### prepare the value
  value = distribution._value()
  # Back-reference from the concrete value to the distribution it came from.
  value._tfp_distribution = distribution
  if value_is_seq:
    # The last element of the tuple supplies shape/dtype for both the
    # composite value and the wrapper.
    last = value[-1]
    value.shape = last.shape
    value.get_shape = last.get_shape
    value.dtype = last.dtype
    distribution.shape = last.shape
    distribution.get_shape = last.get_shape
  else:
    distribution.shape = value.shape
    distribution.get_shape = value.get_shape
  ### return
  if return_value:
    return distribution, value
  return distribution
"""Pretend user-side class for `ConvertToCompositeTensorTest .""" def __init__(self, sequence): self._sequence = tuple(sequence) def __getitem__(self, key): return self._sequence[key] def __len__(self): return len(self._sequence) def __iter__(self): return iter(self._sequence) tf.register_tensor_conversion_function( MyTuple, conversion_func=lambda x, *_, **__: tensor_tuple.TensorTuple(x)) @test_util.test_all_tf_execution_regimes class CustomConvertToCompositeTensorTest(test_util.TestCase): def test_iter(self): x = MyTuple((1, [2., 3.], [[4, 5], [6, 7]])) y = ops.convert_to_tensor_or_composite(value=x) # TODO(jsimsa): The behavior of `is_tensor` for composite tensors have # changed (from True to False) and this check needs to be disabled so that # both TAP presubmits (running at HEAD) and Kokoro presubmit (using TF # nightly) pass. Re-enable this check when TF nightly picks up this change. # self.assertTrue(tf.is_tensor(y)) self.assertIsInstance(y, tensor_tuple.TensorTuple) self.assertLen(y, 3) for x_, y_ in zip(x, y):