예제 #1
0
def serialize_concrete_function(concrete_function, node_ids):
    """Build a `SavedConcreteFunction` proto for `concrete_function`.

    Args:
      concrete_function: The function to serialize. Its `captured_inputs`,
        `structured_input_signature`, `structured_outputs` and `name`
        attributes are read.
      node_ids: Mapping from a captured input to the node id it is bound to.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` with the canonicalized
      input signature, the output signature and the bound inputs filled in.

    Raises:
      KeyError: If a captured input has no entry in `node_ids`. The original
        lookup failure is attached as `__cause__`.
    """
    bound_inputs = []
    for capture in concrete_function.captured_inputs:
        # Guard only the lookup. A KeyError raised while evaluating
        # `captured_inputs` itself must not land in this handler, because
        # `capture` would be unbound when the message is formatted.
        try:
            bound_inputs.append(node_ids[capture])
        except KeyError as err:
            raise KeyError(
                f"Failed to add concrete function '{concrete_function.name}' to object-"
                f"based SavedModel as it captures tensor {capture!r} which is unsupported"
                " or not reachable from root. "
                "One reason could be that a stateful object or a variable that the "
                "function depends on is not assigned to an attribute of the serialized "
                "trackable object (see SaveTest.test_captures_unreachable_variable)."
            ) from err
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        nested_structure_coder.encode_structure(
            concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        nested_structure_coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
예제 #2
0
def _serialize_function_spec(function_spec):
    """Convert a FunctionSpec into a `saved_object_graph_pb2.FunctionSpec`.

    Args:
      function_spec: Object exposing `fullargspec`, `is_method`,
        `input_signature` and `jit_compile`.

    Returns:
      The populated `FunctionSpec` proto.

    Raises:
      NotImplementedError: If the spec describes a method whose argspec has
        no named arguments, i.e. there is no explicit `self`.
    """
    if function_spec.is_method and not function_spec.fullargspec.args:
        raise NotImplementedError(
            "Cannot serialize a method function without a named "
            "'self' argument.")

    spec_proto = saved_object_graph_pb2.FunctionSpec()

    # Annotations are dropped on purpose: they exist for optional type
    # checking during development and do not affect runtime behavior.
    # https://www.python.org/dev/peps/pep-3107/
    # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
    argspec_without_annotations = function_spec.fullargspec._replace(
        annotations={})
    spec_proto.fullargspec.CopyFrom(
        nested_structure_coder.encode_structure(argspec_without_annotations))

    spec_proto.is_method = function_spec.is_method
    spec_proto.input_signature.CopyFrom(
        nested_structure_coder.encode_structure(function_spec.input_signature))

    # See `tf.function` and the JitCompile proto for details.
    jit_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
    jit_compile_by_flag = {
        None: jit_enum.DEFAULT,
        True: jit_enum.ON,
        False: jit_enum.OFF,
    }
    spec_proto.jit_compile = jit_compile_by_flag.get(function_spec.jit_compile)

    return spec_proto
예제 #3
0
def composite_tensor_to_variants(value, type_spec=None, name=None):
    """Encode a composite tensor as a variant tensor.

    Args:
      value: The `CompositeTensor` value to encode.
      type_spec: Type information to embed in the encoding. Defaults to
        `value._type_spec` when omitted.
      name: Optional name for the operation.

    Returns:
      The result of the `CompositeTensorVariantFromComponents` op (documented
      upstream as a Tensor with shape=`()` and dtype=`tf.variant`).

    Raises:
      TypeError: If `value` is not a `CompositeTensor`.
      ValueError: If `type_spec` is not compatible with `value`.
    """
    if not isinstance(value, composite_tensor.CompositeTensor):
        raise TypeError("Expected `value` to be a CompositeTensor. "
                        f"Received {type(value)}.")

    if type_spec is None:
        type_spec = value._type_spec  # pylint: disable=protected-access
    if not type_spec.is_compatible_with(value):
        raise ValueError(
            f"`type_spec` {type_spec} is not compatible with `value` "
            f"{value!r}.")

    # The metadata carries the serialized TypeSpec alongside the components.
    encoded_spec = nested_structure_coder.encode_structure(type_spec)
    variant_metadata = (
        composite_tensor_variant_pb2.CompositeTensorVariantMetadata())
    variant_metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

    flat_components = nest.flatten(value, expand_composites=True)
    return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
        components=flat_components,
        metadata=variant_metadata.SerializeToString(),
        name=name)
 def testEncodeDecodeBoundedTensorSpecNoName(self):
     """Round-trips an unnamed `BoundedTensorSpec` with a vector maximum."""
     bounded_spec = tensor_spec.BoundedTensorSpec(
         (28, 28, 3), dtypes.float64, -2, (1, 1, 20))
     structure = [bounded_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     # Hand-build the proto the coder is expected to produce.
     expected = struct_pb2.StructuredValue()
     spec_proto = expected.list_value.values.add().bounded_tensor_spec_value
     for dim_size in (28, 28, 3):
         spec_proto.shape.dim.add().size = dim_size
     spec_proto.name = ""
     spec_proto.dtype = dtypes.float64.as_datatype_enum
     minimum_proto = tensor_util.make_tensor_proto(
         [-2], dtype=dtypes.float64, shape=[])
     spec_proto.minimum.CopyFrom(minimum_proto)
     maximum_proto = tensor_util.make_tensor_proto(
         [1, 1, 20], dtype=dtypes.float64, shape=[3])
     spec_proto.maximum.CopyFrom(maximum_proto)
     self.assertEqual(expected, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testEncodeDecodeSparseTensorSpec(self):
     """Round-trips a `SparseTensorSpec` and checks the encoded proto text."""
     sparse_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
     structure = [sparse_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     # Text-format rendering of the proto the coder should emit.
     expected_pbtxt = r"""
   list_value {
     values {
       type_spec_value {
         type_spec_class: SPARSE_TENSOR_SPEC
         type_spec_class_name: 'SparseTensorSpec'
         num_flat_components: 3
         type_state {
           tuple_value {
             # spec._shape
             values {
               tensor_shape_value {
                 dim { size: 10 }
                 dim { size: 20 }
               }
             }
             # spec._dtype
             values { tensor_dtype_value: DT_FLOAT }
           }
         }
       }
     }
   }
 """
     expected_proto = struct_pb2.StructuredValue()
     text_format.Parse(expected_pbtxt, expected_proto)
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testBool(self):
     """Round-trips a list containing a single `False`."""
     structure = [False]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     bool_entry = expected_proto.list_value.values.add()
     bool_entry.bool_value = False
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
예제 #7
0
def _build_composite_tensor_info_internal(tensor):
  """Build a `TensorInfo` proto describing a composite tensor.

  The proto records the tensor's encoded `TypeSpec` plus one `TensorInfo`
  entry per flattened component.
  """
  info = meta_graph_pb2.TensorInfo()
  encoded_spec = nested_structure_coder.encode_structure(
      tensor._type_spec)  # pylint: disable=protected-access
  info.composite_tensor.type_spec.CopyFrom(encoded_spec.type_spec_value)

  flat_components = nest.flatten(tensor, expand_composites=True)
  for flat_component in flat_components:
    component_slot = info.composite_tensor.components.add()
    component_slot.CopyFrom(build_tensor_info_internal(flat_component))
  return info
 def testDtype(self):
     """Round-trips a list containing a single dtype."""
     structure = [dtypes.int64]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     dtype_entry = expected_proto.list_value.values.add()
     dtype_entry.tensor_dtype_value = dtypes.int64.as_datatype_enum
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
예제 #9
0
def create_config(pattern: Pattern,
                  table: str,
                  conditions: Sequence[patterns_pb2.Condition] = ()):
    """Build a `StructuredWriterConfig` for `pattern` targeting `table`.

    The pattern's nesting is encoded separately from its leaves: every leaf is
    replaced by `None` to obtain the structure, while the leaves themselves
    are passed flattened. The priority is fixed at 1.0.
    """

    def _blank_leaf(_):
        return None

    placeholder_structure = tree.map_structure(_blank_leaf, pattern)
    flat_pattern = tree.flatten(pattern)
    encoded_structure = nested_structure_coder.encode_structure(
        placeholder_structure)
    return patterns_pb2.StructuredWriterConfig(
        flat=flat_pattern,
        pattern_structure=encoded_structure,
        table=table,
        priority=1.0,
        conditions=conditions)
 def testEncodeDecodeList(self):
     """Round-trips a flat list of floats."""
     structure = [1.5, 2.5, 3.0]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     for number in (1.5, 2.5, 3.0):
         expected_proto.list_value.values.add().float64_value = number
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testNone(self):
     """Round-trips a list mixing a float with `None`."""
     structure = [1.0, None]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     expected_values = expected_proto.list_value.values
     expected_values.add().float64_value = 1.0
     # `None` is represented by an explicit (empty) NoneValue message.
     expected_values.add().none_value.CopyFrom(struct_pb2.NoneValue())
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
예제 #12
0
def to_proto(spec):
    """Encode a nested spec as a `struct_pb2.StructuredValue` proto.

    Args:
      spec: Nested list/tuple or dict of TensorSpecs, describing the shape of
        the non-batched Tensors.

    Returns:
      A `struct_pb2.StructuredValue` proto.
    """
    # Normalize through `from_spec` first so the coder receives tensor specs.
    normalized_spec = from_spec(spec)
    encoded = nested_structure_coder.encode_structure(normalized_spec)
    return encoded
 def testEncodeDecodeDict(self):
     """Round-trips a dict with an int and a mixed int/float list."""
     structure = {"a": 3, "b": [7, 2.5]}
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     expected_fields = expected_proto.dict_value.fields
     expected_fields["a"].int64_value = 3
     b_values = expected_fields["b"].list_value.values
     b_values.add().int64_value = 7
     b_values.add().float64_value = 2.5
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     # Integers must come back as Python ints.
     self.assertIsInstance(round_tripped["a"], int)
     self.assertEqual(structure, round_tripped)
    def testRegisteredTypeSpec(self):
        """Encoding a registered TypeSpec warns once and round-trips."""
        structure = {"x": RegisteredTypeSpec()}
        warning_text = ("Encoding a StructuredValue with type "
                        "NestedStructureTest.RegisteredTypeSpec; loading "
                        "this StructuredValue will require that this type "
                        "be imported and registered")

        self.assertTrue(nested_structure_coder.can_encode(structure))
        with warnings.catch_warnings(record=True) as caught:
            encoded = nested_structure_coder.encode_structure(structure)
            self.assertLen(caught, 1)
            first_message = str(caught[0].message)
            self.assertIn(warning_text, first_message)

        round_tripped = nested_structure_coder.decode_proto(encoded)
        self.assertEqual(structure, round_tripped)
 def testEmptyStructures(self):
     """Round-trips an empty list, dict and tuple nested in a list."""
     structure = [[], {}, ()]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     expected_values = expected_proto.list_value.values
     # CopyFrom of an empty message marks the matching oneof field as set.
     expected_values.add().list_value.CopyFrom(struct_pb2.ListValue())
     expected_values.add().dict_value.CopyFrom(struct_pb2.DictValue())
     expected_values.add().tuple_value.CopyFrom(struct_pb2.TupleValue())
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testEncodeDecodeTuple(self):
     """Round-trips a tuple holding a string and a list with a nested tuple."""
     structure = ("hello", [3, (2, 1)])
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     outer_values = expected_proto.tuple_value.values
     outer_values.add().string_value = "hello"
     inner_list_values = outer_values.add().list_value.values
     inner_list_values.add().int64_value = 3
     inner_tuple_values = inner_list_values.add().tuple_value.values
     for number in (2, 1):
         inner_tuple_values.add().int64_value = number
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
    def test_composite_tensor(self):
        """Checks that the operator under test behaves as a `CompositeTensor`.

        The operator is flattened and re-packed with `nest`, passed into a
        `def_function.function`, carried through a `while_loop`, and finally
        has its `TypeSpec` encoded with `nested_structure_coder`.

        NOTE(review): `shapes_info`, `dtype` and `use_placeholder` are not
        defined in this block; presumably they are closed over from an
        enclosing test-factory function -- confirm against the full file.
        """
        with self.session(graph=ops.Graph()) as sess:
            # Run in a fresh graph with the default graph seed.
            sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
            operator, mat = self.operator_and_matrix(
                shapes_info, dtype, use_placeholder=use_placeholder)
            self.assertIsInstance(operator, composite_tensor.CompositeTensor)

            # Flatten to components and rebuild; the rebuilt object must have
            # the same type as the original operator.
            flat = nest.flatten(operator, expand_composites=True)
            unflat = nest.pack_sequence_as(operator,
                                           flat,
                                           expand_composites=True)
            self.assertIsInstance(unflat, type(operator))

            # Input the operator to a `tf.function`.
            x = self.make_x(operator, adjoint=False)
            op_y = def_function.function(lambda op: op.matmul(x))(unflat)
            mat_y = math_ops.matmul(mat, x)

            # Static shapes are only compared when no placeholders are used.
            if not use_placeholder:
                self.assertAllEqual(mat_y.shape, op_y.shape)

            # Test while_loop.
            def body(op):
                # The trailing comma makes this a 1-tuple, matching the
                # 1-tuple `loop_vars` below.
                return type(op)(**op.parameters),

            # The condition is always true, so `maximum_iterations` is what
            # bounds the loop.
            op_out, = while_v2.while_loop(cond=lambda _: True,
                                          body=body,
                                          loop_vars=(operator, ),
                                          maximum_iterations=3)
            loop_y = op_out.matmul(x)

            # All three products are checked against the dense matmul result.
            op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
            self.assertAC(op_y_, mat_y_)
            self.assertAC(loop_y_, mat_y_)

            # Ensure that the `TypeSpec` can be encoded.
            nested_structure_coder.encode_structure(operator._type_spec)  # pylint: disable=protected-access
 def testEncodeDecodeTensorSpecWithNoName(self):
     """Round-trips a `TensorSpec` that was created without a name."""
     unnamed_spec = tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)
     structure = [unnamed_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected_proto = struct_pb2.StructuredValue()
     spec_proto = expected_proto.list_value.values.add().tensor_spec_value
     for dim_size in (1, 2, 3):
         spec_proto.shape.dim.add().size = dim_size
     spec_proto.name = ""
     spec_proto.dtype = dtypes.int64.as_datatype_enum
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testEncodeDecodeTensorShape(self):
     """Round-trips a list holding a `TensorShape` followed by a string."""
     structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     expected_list = expected.list_value
     expected_tensor_shape = expected_list.values.add().tensor_shape_value
     for dim_size in (1, 2, 3):
         expected_tensor_shape.dim.add().size = dim_size
     # Plain assignment on purpose: the previous chained form
     # `expected_tensor_shape = ...string_value = "hello"` also rebound the
     # shape-proto variable to a str.
     expected_list.values.add().string_value = "hello"
     self.assertEqual(expected, encoded)

     decoded = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, decoded)
 def testEncodeDataSetSpec(self):
     """Round-trips a `DatasetSpec` wrapping ragged, sparse and dense specs."""
     element_spec = {
         "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
         "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
         "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
     }
     structure = [dataset_ops.DatasetSpec(element_spec)]
     self.assertTrue(nested_structure_coder.can_encode(structure))

     # Only the round trip is checked; the proto contents are not pinned.
     encoded = nested_structure_coder.encode_structure(structure)
     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
    def testEncodeDecodeExtensionTypeSpec(self):
        """Round-trips the `Spec` of a locally defined `ExtensionType`."""

        # The explicit `__name__` is what appears (with a `.Spec` suffix) as
        # `type_spec_class_name` in the expected proto below.
        class Zoo(extension_type.ExtensionType):
            __name__ = "tf.nested_structure_coder_test.Zoo"
            zookeepers: typing.Tuple[str, ...]
            animals: typing.Mapping[str, ops.Tensor]

        zoo_spec = Zoo.Spec(zookeepers=["Zoey", "Zack"],
                            animals={"tiger": tensor_spec.TensorSpec([16])})
        structure = [zoo_spec]

        self.assertTrue(nested_structure_coder.can_encode(structure))
        encoded = nested_structure_coder.encode_structure(structure)

        # Text-format rendering of the proto the coder should emit.
        expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: EXTENSION_TYPE_SPEC
          type_spec_class_name: "tf.nested_structure_coder_test.Zoo.Spec"
          num_flat_components: 1
          type_state {
            tuple_value {
              values {
                tuple_value {
                  values { string_value: "zookeepers" }
                  values { tuple_value {
                    values { string_value: "Zoey" }
                    values { string_value: "Zack" } } } } }
              values {
                tuple_value {
                  values { string_value: "animals" }
                  values { dict_value {
                    fields {
                      key: "tiger"
                      value { tensor_spec_value {
                        shape { dim { size: 16 } }
                        dtype: DT_FLOAT } } } } } } } } } } } }
    """
        expected_proto = struct_pb2.StructuredValue()
        text_format.Parse(expected_pbtxt, expected_proto)
        self.assertEqual(expected_proto, encoded)

        round_tripped = nested_structure_coder.decode_proto(encoded)
        self.assertEqual(structure, round_tripped)
 def testEncodeDecodeRaggedTensorSpec(self):
     """Round-trips a `RaggedTensorSpec` and checks the encoded proto text."""
     ragged_spec = ragged_tensor.RaggedTensorSpec(
         [1, 2, 3], dtypes.int64, 2, dtypes.int32)
     structure = [ragged_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     # Text-format rendering of the proto the coder should emit.
     expected_pbtxt = r"""
   list_value {
     values {
       type_spec_value {
         type_spec_class: RAGGED_TENSOR_SPEC
         type_spec_class_name: 'RaggedTensorSpec'
         num_flat_components: 3
         type_state {
           tuple_value {
             # spec._shape
             values {
               tensor_shape_value {
                 dim { size: 1 }
                 dim { size: 2 }
                 dim { size: 3 }
               }
             }
             # spec._dtype
             values { tensor_dtype_value: DT_INT64 }
             # spec._ragged_rank
             values { int64_value: 2 }
             # spec._row_splits_dtype
             values { tensor_dtype_value: DT_INT32 }
           }
         }
       }
     }
   }
 """
     expected_proto = struct_pb2.StructuredValue()
     text_format.Parse(expected_pbtxt, expected_proto)
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
 def testEncodeDecodeNamedTuple(self):
     """Round-trips a namedtuple, comparing fields and class name."""
     point_type = collections.namedtuple("NamedTuple", ["x", "y"])
     original = point_type(x=[1, 2], y="hello")
     self.assertTrue(nested_structure_coder.can_encode(original))
     encoded = nested_structure_coder.encode_structure(original)

     expected_proto = struct_pb2.StructuredValue()
     named_tuple_proto = expected_proto.named_tuple_value
     named_tuple_proto.name = "NamedTuple"

     x_pair = named_tuple_proto.values.add()
     x_pair.key = "x"
     x_values = x_pair.value.list_value.values
     x_values.add().int64_value = 1
     x_values.add().int64_value = 2

     y_pair = named_tuple_proto.values.add()
     y_pair.key = "y"
     y_pair.value.string_value = "hello"
     self.assertEqual(expected_proto, encoded)

     # The decoded value is compared by field contents and by class name
     # rather than by direct equality with the original.
     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(original._asdict(), round_tripped._asdict())
     self.assertEqual(original.__class__.__name__,
                      round_tripped.__class__.__name__)
예제 #24
0
def composite_tensor_from_variant(encoded, type_spec, name=None):
    """Decode a scalar variant tensor back into a composite value.

    Args:
      encoded: A Tensor returned by `composite_tensor_to_variants`.
      type_spec: The `TypeSpec` of the original value. It determines the number
        and dtypes of the component tensors to decode, and must be compatible
        with the `TypeSpec` serialized in `encoded`.
      name: Optional name for the operation.

    Returns:
      The decoded components packed according to `type_spec`.

    Raises:
      TypeError: If `encoded` is not a Tensor with dtype=variant.
      InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
    """
    if not isinstance(encoded, ops.Tensor):
        raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
    if encoded.dtype != dtypes.variant:
        raise TypeError("Expected `encoded` to have dtype=variant, got "
                        f"{encoded!r}.")
    # The static shape must be compatible with a scalar.
    encoded.shape.assert_is_compatible_with(())

    encoded_spec = nested_structure_coder.encode_structure(type_spec)
    variant_metadata = (
        composite_tensor_variant_pb2.CompositeTensorVariantMetadata())
    variant_metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

    # One dtype per flattened component of the spec.
    flat_specs = nest.flatten(type_spec, expand_composites=True)
    component_dtypes = [flat_spec.dtype for flat_spec in flat_specs]

    decoded_components = (
        gen_composite_tensor_ops.CompositeTensorVariantToComponents(
            encoded=encoded,
            metadata=variant_metadata.SerializeToString(),
            Tcomponents=component_dtypes,
            name=name))
    return nest.pack_sequence_as(
        type_spec, decoded_components, expand_composites=True)
 def testEncodeDecodeBoundedTensorSpec(self):
     """Round-trips a named `BoundedTensorSpec` with scalar bounds."""
     bounded_spec = tensor_spec.BoundedTensorSpec(
         [1, 2, 3], dtypes.int64, 0, 10, "hello-0-10")
     structure = [bounded_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     # Hand-build the proto the coder is expected to produce.
     expected_proto = struct_pb2.StructuredValue()
     spec_proto = (
         expected_proto.list_value.values.add().bounded_tensor_spec_value)
     for dim_size in (1, 2, 3):
         spec_proto.shape.dim.add().size = dim_size
     spec_proto.name = "hello-0-10"
     spec_proto.dtype = dtypes.int64.as_datatype_enum
     minimum_proto = tensor_util.make_tensor_proto(
         [0], dtype=dtypes.int64, shape=[])
     spec_proto.minimum.CopyFrom(minimum_proto)
     maximum_proto = tensor_util.make_tensor_proto(
         [10], dtype=dtypes.int64, shape=[])
     spec_proto.maximum.CopyFrom(maximum_proto)
     self.assertEqual(expected_proto, encoded)

     round_tripped = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, round_tripped)
예제 #26
0
def as_composite(obj):
    """Returns a `CompositeTensor` equivalent to the given object.

  Note that the returned object will have any `Variable`,
  `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
  closes over converted to tensors at the time this function is called. The
  type of the returned object will be a subclass of both `CompositeTensor` and
  `type(obj)`.  For this reason, one should be careful about using
  `as_composite()`, especially for `tf.Module` objects.

  For example, when the composite tensor is created even as part of a
  `tf.Module`, it "fixes" the values of the `DeferredTensor` and `tf.Variable`
  objects it uses:

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
      self._dct = tfp.experimental.as_composite(self._d)

    @tf.function
    def mean(self):
      return self._dct.mean()

  m = M()
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  ```

  If, however, the creation of the composite is deferred to a method
  call, then the Variable and DeferredTensor will be properly captured
  and respected by the Module and its `SavedModel` (if it is serialized).

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

    @tf.function
    def d(self):
      return tfp.experimental.as_composite(self._d)

  m = M()
  m.d().mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)
  m.d().mean()
  >>> <tf.Tensor: numpy=3.0>
  ```

  Note: This method is best-effort and based on a heuristic for what the
  tensor parameters are and what the non-tensor parameters are. Things might be
  broken, especially for meta-distributions like `TransformedDistribution` or
  `Independent`. (We try to raise NotImplementedError in such cases.) If you'd
  benefit from better coverage, please file an issue on github or send an email
  to `[email protected]`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.
      If `obj` already is a `CompositeTensor` it is returned unchanged.

  Raises:
    NotImplementedError: If a parameter for which `tensor_util.is_ref` is true
      cannot be converted to a tensor, or if the `_type_spec` of the resulting
      object cannot be encoded by `nested_structure_coder`.
  """
    # Already a composite: nothing to convert.
    if isinstance(obj, CompositeTensor):
        return obj
    cls = _make_convertible(type(obj))
    # Copy the parameters so the edits below do not touch `obj.parameters`.
    kwargs = dict(obj.parameters)

    def mk_err_msg(suffix=''):
        # Shared prefix for every NotImplementedError raised below.
        return (
            'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
            '`[email protected]` or file an issue on github if you '
            'would benefit from this working. {}'.format(
                obj, type(obj), suffix))

    # Objects that do not declare (or do not implement) the list of tensor
    # parameters are treated as having none.
    try:
        composite_tensor_params = obj._composite_tensor_params  # pylint: disable=protected-access
    except (AttributeError, NotImplementedError):
        composite_tensor_params = ()
    for k in composite_tensor_params:
        # Use dtype inference from ctor.
        if k in kwargs and kwargs[k] is not None:
            # Prefer the attribute on `obj`; fall back to the ctor argument.
            v = getattr(obj, k, kwargs[k])
            try:
                kwargs[k] = tf.convert_to_tensor(v, name=k)
            except (ValueError, TypeError) as e:
                # Conversion is best-effort here: keep the raw value.
                kwargs[k] = v
    # Only existing keys are reassigned inside this loop, so the dict's size
    # does not change while it is being iterated.
    for k, v in kwargs.items():

        def composite_helper(v):
            # If we have a parameters attribute, then we may be able to convert to
            # a composite tensor by guessing which of the parameters are tensors.  In
            # essence, we duck-type based on this attribute.
            if hasattr(v, 'parameters'):
                return as_composite(v)
            return v

        kwargs[k] = tf.nest.map_structure(composite_helper, v)
        # NOTE: the check below inspects the original `v`, not the mapped
        # value just stored; when it passes, that stored value is overwritten.
        # Unfortunately, tensor_util.is_ref(v) returns true for a
        # tf.linalg.LinearOperator even though that is not ideal behavior.
        if tensor_util.is_ref(v) and not isinstance(v,
                                                    tf.linalg.LinearOperator):
            try:
                kwargs[k] = tf.convert_to_tensor(v, name=k)
            except TypeError as e:
                raise NotImplementedError(
                    mk_err_msg(
                        '(Unable to convert dependent entry \'{}\' of object '
                        '\'{}\': {})'.format(k, obj, str(e))))
    result = cls(**kwargs)
    # Encode the type spec purely as a check; the encoded proto is discarded.
    try:
        nested_structure_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
    except nested_structure_coder.NotEncodableError as e:
        raise NotImplementedError(
            mk_err_msg('(Unable to serialize: {})'.format(str(e))))
    return result
예제 #27
0
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
    """Return the serialized `StructuredValue` for `func`'s positional specs.

    Only the first element of `func.structured_input_signature` is encoded;
    the second element is ignored.
    """
    positional_specs = func.structured_input_signature[0:2][0]
    encoded_specs = nested_structure_coder.encode_structure(positional_specs)
    return encoded_specs.SerializeToString()
예제 #28
0
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
    """Return the serialized `StructuredValue` describing `func`'s outputs.

    Each leaf of `func.structured_outputs` is mapped to its type spec before
    the structure is encoded.
    """
    specs = nest.map_structure(
        type_spec.type_spec_from_value, func.structured_outputs)
    return nested_structure_coder.encode_structure(specs).SerializeToString()
예제 #29
0
    def __init__(self,
                 name: str,
                 sampler: reverb_types.SelectorType,
                 remover: reverb_types.SelectorType,
                 max_size: int,
                 rate_limiter: rate_limiters.RateLimiter,
                 max_times_sampled: int = 0,
                 extensions: Sequence[TableExtensionBase] = (),
                 signature: Optional[reverb_types.SpecNest] = None):
        """Create a Table and its underlying `pybind.Table`.

        Args:
          name: Name of the priority table.
          sampler: The strategy to use when selecting samples.
          remover: The strategy to use when selecting which items to remove.
          max_size: The maximum number of items the table may hold. When an
            item is inserted into a full table, `remover` picks the item to
            evict first.
          rate_limiter: Manages the data flow by limiting the sample and
            insert calls.
          max_times_sampled: Maximum number of times an item can be sampled
            before it is deleted. Any value < 1 means there is no limit.
          extensions: Optional sequence of extensions used to add extra
            features to the table.
          signature: Optional nested structure of `tf.TypeSpec` objects
            describing the schema of items in this table. Every leaf must be
            a `tensor_spec.TensorSpec`.

        Raises:
          ValueError: If name is empty.
          ValueError: If max_size <= 0.
          ValueError: If `signature` contains a leaf that is not a
            `tensor_spec.TensorSpec`.
        """
        if not name:
            raise ValueError('name must be nonempty')
        if max_size <= 0:
            raise ValueError('max_size (%d) must be a positive integer' %
                             max_size)
        self._sampler = sampler
        self._remover = remover
        self._rate_limiter = rate_limiter
        self._extensions = extensions
        self._signature = signature

        # Merge the c++ extensions into a single list.
        internal_extensions = []
        for table_extension in extensions:
            internal_extensions.extend(
                table_extension.build_internal_extensions(name))

        # A falsy signature is passed to the backing table as None.
        signature_proto_str = None
        if signature:
            for leaf_spec in tree.flatten(signature):
                if not isinstance(leaf_spec, tensor_spec.TensorSpec):
                    raise ValueError(
                        f'Unsupported signature spec: {leaf_spec}')
            encoded_signature = nested_structure_coder.encode_structure(
                signature)
            signature_proto_str = encoded_signature.SerializeToString()

        self.internal_table = pybind.Table(
            name=name,
            sampler=sampler,
            remover=remover,
            max_size=max_size,
            max_times_sampled=max_times_sampled,
            rate_limiter=rate_limiter.internal_limiter,
            extensions=internal_extensions,
            signature=signature_proto_str)
 def testUnregisteredTypeSpec(self):
     """An unregistered TypeSpec is reported as, and fails as, unencodable."""
     structure = {"x": UnregisteredTypeSpec()}
     encodable = nested_structure_coder.can_encode(structure)
     self.assertFalse(encodable)
     with self.assertRaises(nested_structure_coder.NotEncodableError):
         nested_structure_coder.encode_structure(structure)