def serialize_concrete_function(concrete_function, node_ids):
  """Build a SavedConcreteFunction proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize. Every entry of its
      `captured_inputs` must be a key of `node_ids`.
    node_ids: Mapping from captured objects to their node ids in the object
      graph.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` holding the encoded
    canonicalized input signature, the encoded output signature and the node
    ids of the captured (bound) inputs.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`.
  """
  bound_inputs = []
  for capture in concrete_function.captured_inputs:
    # Only the lookup can fail; keep the try body to that single expression.
    try:
      node_id = node_ids[capture]
    except KeyError as err:
      raise KeyError(
          f"Failed to add concrete function '{concrete_function.name}' to object-"
          f"based SavedModel as it captures tensor {capture!r} which is unsupported"
          " or not reachable from root. "
          "One reason could be that a stateful object or a variable that the "
          "function depends on is not assigned to an attribute of the serialized "
          "trackable object (see SaveTest.test_captures_unreachable_variable)."
      ) from err
    bound_inputs.append(node_id)

  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      nested_structure_coder.encode_structure(
          concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      nested_structure_coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto
def _serialize_function_spec(function_spec):
  """Convert a FunctionSpec object into a `saved_object_graph_pb2.FunctionSpec`.

  Args:
    function_spec: The FunctionSpec to serialize.

  Returns:
    The populated `saved_object_graph_pb2.FunctionSpec` proto.

  Raises:
    NotImplementedError: If `function_spec` is a method whose fullargspec has
      no named positional arguments (i.e. no named 'self').
  """
  if function_spec.is_method:
    if not function_spec.fullargspec.args:
      raise NotImplementedError(
          "Cannot serialize a method function without a named "
          "'self' argument.")

  # Annotations are deliberately dropped: they exist mainly for optional type
  # checking during development and do not affect runtime behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  argspec_without_annotations = function_spec.fullargspec._replace(
      annotations={})

  jit_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  # See `tf.function` and the JitCompile proto for details.
  jit_mode_by_flag = {
      None: jit_enum.DEFAULT,
      True: jit_enum.ON,
      False: jit_enum.OFF,
  }

  spec_proto = saved_object_graph_pb2.FunctionSpec()
  spec_proto.fullargspec.CopyFrom(
      nested_structure_coder.encode_structure(argspec_without_annotations))
  spec_proto.is_method = function_spec.is_method
  spec_proto.input_signature.CopyFrom(
      nested_structure_coder.encode_structure(function_spec.input_signature))
  spec_proto.jit_compile = jit_mode_by_flag.get(function_spec.jit_compile)
  return spec_proto
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` value to encode. Must be a `CompositeTensor`.
    type_spec: Information about the value's type that should be included in
      the encoding. When None, `value._type_spec` is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(
        f"`type_spec` {type_spec} is not compatible with `value` "
        f"{value!r}.")

  # The encoded TypeSpec is handed to the op as serialized metadata next to
  # the flattened component tensors.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=metadata.SerializeToString(),
      name=name)
def testEncodeDecodeBoundedTensorSpecNoName(self):
  """An unnamed BoundedTensorSpec round-trips and encodes with an empty name."""
  spec = tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2,
                                       (1, 1, 20))
  structure = [spec]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  bounded = expected.list_value.values.add().bounded_tensor_spec_value
  for dim_size in (28, 28, 3):
    bounded.shape.dim.add().size = dim_size
  bounded.name = ""
  bounded.dtype = dtypes.float64.as_datatype_enum
  lower = tensor_util.make_tensor_proto([-2], dtype=dtypes.float64, shape=[])
  upper = tensor_util.make_tensor_proto(
      [1, 1, 20], dtype=dtypes.float64, shape=[3])
  bounded.minimum.CopyFrom(lower)
  bounded.maximum.CopyFrom(upper)
  self.assertEqual(expected, encoded)

  round_tripped = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, round_tripped)
def testEncodeDecodeSparseTensorSpec(self):
  """A SparseTensorSpec encodes as a SPARSE_TENSOR_SPEC type_spec_value."""
  structure = [sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)

  # text_format.Parse fills in and returns the message it is given.
  expected = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: SPARSE_TENSOR_SPEC
            type_spec_class_name: 'SparseTensorSpec'
            num_flat_components: 3
            type_state {
              tuple_value {
                # spec._shape
                values {
                  tensor_shape_value {
                    dim { size: 10 }
                    dim { size: 20 }
                  }
                }
                # spec._dtype
                values { tensor_dtype_value: DT_FLOAT }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(expected, encoded)

  round_tripped = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, round_tripped)
def testBool(self):
  """A list holding False encodes as a list with one bool_value entry."""
  payload = [False]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  entry = want.list_value.values.add()
  entry.bool_value = False
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def _build_composite_tensor_info_internal(tensor):
  """Build a TensorInfo proto describing a CompositeTensor.

  The proto stores the composite's encoded TypeSpec plus one TensorInfo per
  component produced by flattening `tensor` with `expand_composites=True`.
  """
  encoded_spec = nested_structure_coder.encode_structure(
      tensor._type_spec)  # pylint: disable=protected-access
  info = meta_graph_pb2.TensorInfo()
  composite = info.composite_tensor
  composite.type_spec.CopyFrom(encoded_spec.type_spec_value)
  leaves = nest.flatten(tensor, expand_composites=True)
  for leaf in leaves:
    slot = composite.components.add()
    slot.CopyFrom(build_tensor_info_internal(leaf))
  return info
def testDtype(self):
  """A DType encodes as its tensor_dtype_value enum and round-trips."""
  payload = [dtypes.int64]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  want.list_value.values.add().tensor_dtype_value = (
      dtypes.int64.as_datatype_enum)
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def create_config(pattern: Pattern,
                  table: str,
                  conditions: Sequence[patterns_pb2.Condition] = ()):
  """Build a StructuredWriterConfig for `pattern` writing into `table`.

  Args:
    pattern: Nested pattern; its flattened leaves become `flat`.
    table: Name of the table the config targets.
    conditions: Conditions copied into the config.

  Returns:
    A `patterns_pb2.StructuredWriterConfig` with priority 1.0.
  """
  # Keep only the nesting of `pattern`: every leaf is replaced by None.
  skeleton = tree.map_structure(lambda _: None, pattern)
  flat_pattern = tree.flatten(pattern)
  encoded_skeleton = nested_structure_coder.encode_structure(skeleton)
  return patterns_pb2.StructuredWriterConfig(
      flat=flat_pattern,
      pattern_structure=encoded_skeleton,
      table=table,
      priority=1.0,
      conditions=conditions)
def testEncodeDecodeList(self):
  """A flat list of floats encodes as float64_value entries."""
  numbers = [1.5, 2.5, 3.0]
  self.assertTrue(nested_structure_coder.can_encode(numbers))
  proto = nested_structure_coder.encode_structure(numbers)

  want = struct_pb2.StructuredValue()
  for number in numbers:
    want.list_value.values.add().float64_value = number
  self.assertEqual(want, proto)

  self.assertEqual(numbers, nested_structure_coder.decode_proto(proto))
def testNone(self):
  """None inside a list encodes as an empty NoneValue entry."""
  payload = [1.0, None]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  items = want.list_value.values
  items.add().float64_value = 1.0
  items.add().none_value.CopyFrom(struct_pb2.NoneValue())
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def to_proto(spec):
  """Encode a nested spec as a `struct_pb2.StructuredValue` proto.

  Args:
    spec: Nested list/tuple or dict of TensorSpecs, describing the shape of the
      non-batched Tensors.

  Returns:
    A `struct_pb2.StructuredValue` proto.
  """
  # Normalize through from_spec first so the coder always receives TensorSpecs.
  normalized_spec = from_spec(spec)
  encoded = nested_structure_coder.encode_structure(normalized_spec)
  return encoded
def testEncodeDecodeDict(self):
  """A dict with nested list encodes to dict_value and decodes ints as int."""
  payload = {"a": 3, "b": [7, 2.5]}
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  fields = want.dict_value.fields
  fields["a"].int64_value = 3
  b_items = fields["b"].list_value.values
  b_items.add().int64_value = 7
  b_items.add().float64_value = 2.5
  self.assertEqual(want, proto)

  round_tripped = nested_structure_coder.decode_proto(proto)
  self.assertIsInstance(round_tripped["a"], int)
  self.assertEqual(payload, round_tripped)
def testRegisteredTypeSpec(self):
  """Encoding a registered TypeSpec warns once and still round-trips."""
  expected_warning = ("Encoding a StructuredValue with type "
                      "NestedStructureTest.RegisteredTypeSpec; loading "
                      "this StructuredValue will require that this type "
                      "be imported and registered")
  structure = {"x": RegisteredTypeSpec()}

  self.assertTrue(nested_structure_coder.can_encode(structure))
  with warnings.catch_warnings(record=True) as w:
    # record=True does not change the active filters. Force "always" so the
    # warning is recorded regardless of global filter configuration or
    # warnings already issued from the same location.
    warnings.simplefilter("always")
    encoded = nested_structure_coder.encode_structure(structure)
  # "always" can also surface unrelated warnings, so count only matching ones.
  matching = [rec for rec in w if expected_warning in str(rec.message)]
  self.assertLen(matching, 1)

  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEmptyStructures(self):
  """Empty list, dict and tuple each encode as their empty container value."""
  payload = [[], {}, ()]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  items = want.list_value.values
  items.add().list_value.CopyFrom(struct_pb2.ListValue())
  items.add().dict_value.CopyFrom(struct_pb2.DictValue())
  items.add().tuple_value.CopyFrom(struct_pb2.TupleValue())
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def testEncodeDecodeTuple(self):
  """Tuples and lists nested in each other keep their container kinds."""
  payload = ("hello", [3, (2, 1)])
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  outer = want.tuple_value.values
  outer.add().string_value = "hello"
  inner_list = outer.add().list_value.values
  inner_list.add().int64_value = 3
  innermost = inner_list.add().tuple_value.values
  for number in (2, 1):
    innermost.add().int64_value = number
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def test_composite_tensor(self):
  """Checks the operator behaves as a CompositeTensor in graph mode.

  Covers flatten/pack round-trip, passing the operator into a `tf.function`,
  carrying it through `while_loop`, and encoding its `TypeSpec`.

  NOTE(review): `shapes_info`, `dtype` and `use_placeholder` are free
  variables here; presumably they are bound by an enclosing test-factory
  function that is not visible in this chunk — confirm.
  """
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    self.assertIsInstance(operator, composite_tensor.CompositeTensor)

    # Flattening to components and packing back must rebuild the same type.
    flat = nest.flatten(operator, expand_composites=True)
    unflat = nest.pack_sequence_as(operator, flat, expand_composites=True)
    self.assertIsInstance(unflat, type(operator))

    # Input the operator to a `tf.function`.
    x = self.make_x(operator, adjoint=False)
    op_y = def_function.function(lambda op: op.matmul(x))(unflat)
    mat_y = math_ops.matmul(mat, x)

    # Static shapes are only comparable when inputs are not placeholders.
    if not use_placeholder:
      self.assertAllEqual(mat_y.shape, op_y.shape)

    # Test while_loop.
    def body(op):
      # Trailing comma: loop body must return a tuple matching `loop_vars`.
      return type(op)(**op.parameters),
    op_out, = while_v2.while_loop(
        cond=lambda _: True,
        body=body,
        loop_vars=(operator,),
        maximum_iterations=3)
    loop_y = op_out.matmul(x)

    op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
    self.assertAC(op_y_, mat_y_)
    self.assertAC(loop_y_, mat_y_)

    # Ensure that the `TypeSpec` can be encoded.
    nested_structure_coder.encode_structure(operator._type_spec)  # pylint: disable=protected-access
def testEncodeDecodeTensorSpecWithNoName(self):
  """An unnamed TensorSpec encodes with an empty name and round-trips."""
  payload = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  spec_value = want.list_value.values.add().tensor_spec_value
  for dim_size in (1, 2, 3):
    spec_value.shape.dim.add().size = dim_size
  spec_value.name = ""
  spec_value.dtype = dtypes.int64.as_datatype_enum
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def testEncodeDecodeTensorShape(self):
  """A TensorShape followed by a string encodes and round-trips."""
  structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_shape = expected_list.values.add().tensor_shape_value
  expected_tensor_shape.dim.add().size = 1
  expected_tensor_shape.dim.add().size = 2
  expected_tensor_shape.dim.add().size = 3
  # Second element: the plain string. Plain assignment only; do not rebind
  # `expected_tensor_shape`, which names the first element's proto.
  expected_list.values.add().string_value = "hello"
  self.assertEqual(expected, encoded)

  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDataSetSpec(self):
  """A DatasetSpec over ragged, sparse and dense specs round-trips."""
  element_spec = {
      "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
      "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
      "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
  }
  payload = [dataset_ops.DatasetSpec(element_spec)]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)
  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def testEncodeDecodeExtensionTypeSpec(self):
  """An ExtensionType spec encodes its fields as (name, value) tuple pairs."""

  class Zoo(extension_type.ExtensionType):
    __name__ = "tf.nested_structure_coder_test.Zoo"
    zookeepers: typing.Tuple[str, ...]
    animals: typing.Mapping[str, ops.Tensor]

  zoo_spec = Zoo.Spec(
      zookeepers=["Zoey", "Zack"],
      animals={"tiger": tensor_spec.TensorSpec([16])})
  payload = [zoo_spec]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: EXTENSION_TYPE_SPEC
            type_spec_class_name: "tf.nested_structure_coder_test.Zoo.Spec"
            num_flat_components: 1
            type_state {
              tuple_value {
                values {
                  tuple_value {
                    values { string_value: "zookeepers" }
                    values {
                      tuple_value {
                        values { string_value: "Zoey" }
                        values { string_value: "Zack" }
                      }
                    }
                  }
                }
                values {
                  tuple_value {
                    values { string_value: "animals" }
                    values {
                      dict_value {
                        fields {
                          key: "tiger"
                          value {
                            tensor_spec_value {
                              shape { dim { size: 16 } }
                              dtype: DT_FLOAT
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def testEncodeDecodeRaggedTensorSpec(self):
  """A RaggedTensorSpec encodes shape, dtype, ragged rank and splits dtype."""
  ragged_spec = ragged_tensor.RaggedTensorSpec([1, 2, 3], dtypes.int64, 2,
                                               dtypes.int32)
  payload = [ragged_spec]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: RAGGED_TENSOR_SPEC
            type_spec_class_name: 'RaggedTensorSpec'
            num_flat_components: 3
            type_state {
              tuple_value {
                # spec._shape
                values {
                  tensor_shape_value {
                    dim { size: 1 }
                    dim { size: 2 }
                    dim { size: 3 }
                  }
                }
                # spec._dtype
                values { tensor_dtype_value: DT_INT64 }
                # spec._ragged_rank
                values { int64_value: 2 }
                # spec._row_splits_dtype
                values { tensor_dtype_value: DT_INT32 }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def testEncodeDecodeNamedTuple(self):
  """A namedtuple encodes its name and ordered (key, value) pairs."""
  point_type = collections.namedtuple("NamedTuple", ["x", "y"])
  point = point_type(x=[1, 2], y="hello")
  self.assertTrue(nested_structure_coder.can_encode(point))
  proto = nested_structure_coder.encode_structure(point)

  want = struct_pb2.StructuredValue()
  named = want.named_tuple_value
  named.name = "NamedTuple"
  x_pair = named.values.add()
  x_pair.key = "x"
  x_items = x_pair.value.list_value.values
  for number in (1, 2):
    x_items.add().int64_value = number
  y_pair = named.values.add()
  y_pair.key = "y"
  y_pair.value.string_value = "hello"
  self.assertEqual(want, proto)

  # Decoding produces a fresh namedtuple class, so compare contents and name.
  round_tripped = nested_structure_coder.decode_proto(proto)
  self.assertEqual(point._asdict(), round_tripped._asdict())
  self.assertEqual(type(point).__name__, type(round_tripped).__name__)
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. This is used to determine
      the number and types of the component tensors that comprise the decoded
      value. Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with a
      scalar shape `()`.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  encoded.shape.assert_is_compatible_with(())

  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  # The op needs the dtype of every flat component up front.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
def testEncodeDecodeBoundedTensorSpec(self):
  """A named BoundedTensorSpec with scalar bounds encodes and round-trips."""
  bounded_spec = tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10,
                                               "hello-0-10")
  payload = [bounded_spec]
  self.assertTrue(nested_structure_coder.can_encode(payload))
  proto = nested_structure_coder.encode_structure(payload)

  want = struct_pb2.StructuredValue()
  bounded = want.list_value.values.add().bounded_tensor_spec_value
  for dim_size in (1, 2, 3):
    bounded.shape.dim.add().size = dim_size
  bounded.name = "hello-0-10"
  bounded.dtype = dtypes.int64.as_datatype_enum
  lower = tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[])
  upper = tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[])
  bounded.minimum.CopyFrom(lower)
  bounded.maximum.CopyFrom(upper)
  self.assertEqual(want, proto)

  self.assertEqual(payload, nested_structure_coder.decode_proto(proto))
def as_composite(obj):
  """Returns a `CompositeTensor` equivalent to the given object.

  Note that the returned object will have any `Variable`,
  `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
  closes over converted to tensors at the time this function is called. The
  type of the returned object will be a subclass of both `CompositeTensor` and
  `type(obj)`.  For this reason, one should be careful about using
  `as_composite()`, especially for `tf.Module` objects.

  For example, when the composite tensor is created even as part of a
  `tf.Module`, it "fixes" the values of the `DeferredTensor` and `tf.Variable`
  objects it uses:

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
      self._dct = tfp.experimental.as_composite(self._d)

    @tf.function
    def mean(self):
      return self._dct.mean()

  m = M()
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  ```

  If, however, the creation of the composite is deferred to a method
  call, then the Variable and DeferredTensor will be properly captured
  and respected by the Module and its `SavedModel` (if it is serialized).

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

    @tf.function
    def d(self):
      return tfp.experimental.as_composite(self._d)

  m = M()
  m.d().mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)
  m.d().mean()
  >>> <tf.Tensor: numpy=3.0>
  ```

  Note: This method is best-effort and based on a heuristic for what the
  tensor parameters are and what the non-tensor parameters are. Things might be
  broken, especially for meta-distributions like `TransformedDistribution` or
  `Independent`. (We try to raise NotImplementedError in such cases.) If you'd
  benefit from better coverage, please file an issue on github or send an email
  to `tfprobability@tensorflow.org`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.

  Raises:
    NotImplementedError: If a dependent parameter cannot be converted to a
      tensor, or if the resulting object's `TypeSpec` cannot be encoded.
  """
  if isinstance(obj, CompositeTensor):
    return obj

  cls = _make_convertible(type(obj))
  kwargs = dict(obj.parameters)

  def mk_err_msg(suffix=''):
    return (
        'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
        '`tfprobability@tensorflow.org` or file an issue on github if you '
        'would benefit from this working. {}'.format(
            obj, type(obj), suffix))

  def composite_helper(v):
    # If we have a parameters attribute, then we may be able to convert to
    # a composite tensor by guessing which of the parameters are tensors. In
    # essence, we duck-type based on this attribute.
    if hasattr(v, 'parameters'):
      return as_composite(v)
    return v

  try:
    composite_tensor_params = obj._composite_tensor_params  # pylint: disable=protected-access
  except (AttributeError, NotImplementedError):
    composite_tensor_params = ()

  for k in composite_tensor_params:
    # Use dtype inference from ctor.
    if k in kwargs and kwargs[k] is not None:
      v = getattr(obj, k, kwargs[k])
      try:
        kwargs[k] = tf.convert_to_tensor(v, name=k)
      except (ValueError, TypeError):
        # Best effort: keep values that cannot be converted unchanged.
        kwargs[k] = v

  # Only values are reassigned below (no keys added or removed), so iterating
  # the live dict is safe.
  for k, v in kwargs.items():
    kwargs[k] = tf.nest.map_structure(composite_helper, v)
    # Unfortunately, tensor_util.is_ref(v) returns true for a
    # tf.linalg.LinearOperator even though that is not ideal behavior.
    if tensor_util.is_ref(v) and not isinstance(v, tf.linalg.LinearOperator):
      try:
        kwargs[k] = tf.convert_to_tensor(v, name=k)
      except TypeError as e:
        raise NotImplementedError(
            mk_err_msg(
                '(Unable to convert dependent entry \'{}\' of object '
                '\'{}\': {})'.format(k, obj, str(e)))) from e

  result = cls(**kwargs)
  try:
    nested_structure_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
  except nested_structure_coder.NotEncodableError as e:
    raise NotImplementedError(
        mk_err_msg('(Unable to serialize: {})'.format(str(e)))) from e
  return result
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Serialize the positional-argument specs of `func`'s input signature.

  Returns:
    The `StructuredValue` encoding of the first element of
    `func.structured_input_signature`, serialized to bytes.
  """
  # Unpacking (rather than indexing) also checks that the signature is a pair.
  positional_specs, _unused = func.structured_input_signature
  encoded_specs = nested_structure_coder.encode_structure(positional_specs)
  return encoded_specs.SerializeToString()
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Serialize the TypeSpecs of `func`'s structured outputs.

  Returns:
    Bytes of the `StructuredValue` encoding of a nest of TypeSpecs built from
    `func.structured_outputs`.
  """
  outputs = func.structured_outputs
  encoded_specs = nested_structure_coder.encode_structure(
      nest.map_structure(type_spec.type_spec_from_value, outputs))
  return encoded_specs.SerializeToString()
def __init__(self,
             name: str,
             sampler: reverb_types.SelectorType,
             remover: reverb_types.SelectorType,
             max_size: int,
             rate_limiter: rate_limiters.RateLimiter,
             max_times_sampled: int = 0,
             extensions: Sequence[TableExtensionBase] = (),
             signature: Optional[reverb_types.SpecNest] = None):
  """Constructor of the Table.

  Args:
    name: Name of the priority table.
    sampler: The strategy to use when selecting samples.
    remover: The strategy to use when selecting which items to remove.
    max_size: The maximum number of items which the replay is allowed to hold.
      When an item is inserted into an already full priority table the
      `remover` is used for selecting which item to remove before proceeding
      with the new insert.
    rate_limiter: Manages the data flow by limiting the sample and insert
      calls.
    max_times_sampled: Maximum number of times an item can be sampled before
      it is deleted. Any value < 1 is ignored and means there is no limit.
    extensions: Optional sequence of extensions used to add extra features to
      the table.
    signature: Optional nested structure containing `tf.TypeSpec` objects,
      describing the schema of items in this table. Every leaf must be a
      `tensor_spec.TensorSpec`.

  Raises:
    ValueError: If name is empty.
    ValueError: If max_size <= 0.
    ValueError: If `signature` contains a leaf that is not a
      `tensor_spec.TensorSpec`.
  """
  if not name:
    raise ValueError('name must be nonempty')
  if max_size <= 0:
    raise ValueError('max_size (%d) must be a positive integer' % max_size)

  # Validate and encode the signature before storing any state or building
  # extensions, so a bad signature fails fast.
  if signature:
    for s in tree.flatten(signature):
      if not isinstance(s, tensor_spec.TensorSpec):
        raise ValueError(f'Unsupported signature spec: {s}')
    signature_proto_str = nested_structure_coder.encode_structure(
        signature).SerializeToString()
  else:
    signature_proto_str = None

  self._sampler = sampler
  self._remover = remover
  self._rate_limiter = rate_limiter
  self._extensions = extensions
  self._signature = signature

  # Merge the c++ extensions into a single list.
  internal_extensions = []
  for extension in extensions:
    internal_extensions.extend(extension.build_internal_extensions(name))

  self.internal_table = pybind.Table(
      name=name,
      sampler=sampler,
      remover=remover,
      max_size=max_size,
      max_times_sampled=max_times_sampled,
      rate_limiter=rate_limiter.internal_limiter,
      extensions=internal_extensions,
      signature=signature_proto_str)
def testUnregisteredTypeSpec(self):
  """An unregistered TypeSpec is reported unencodable and fails to encode."""
  unencodable = {"x": UnregisteredTypeSpec()}
  self.assertFalse(nested_structure_coder.can_encode(unencodable))
  self.assertRaises(nested_structure_coder.NotEncodableError,
                    nested_structure_coder.encode_structure, unencodable)