def __init__(self,
             address: str,
             name: str = "",
             list_registered_methods=False,
             timeout_in_ms=0):
  self._client_handle, methods = gen_rpc_ops.rpc_client(
      shared_name=name,
      server_address=address,
      list_registered_methods=list_registered_methods,
      timeout_in_ms=timeout_in_ms)
  if context.executing_eagerly():
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._client_handle, handle_device=self._client_handle.device)
  else:
    raise NotImplementedError(
        "Client creation is supported only in eager mode.")
  self._server_address = address
  self._method_registry = {}
  for method in methods.numpy():
    m = rpc_pb2.RegisteredMethod()
    m.ParseFromString(method)
    output_specs = nested_structure_coder.decode_proto(m.output_specs)
    input_specs = nested_structure_coder.decode_proto(m.input_specs)
    self._method_registry[m.method] = output_specs
    # TODO(ishark): Perhaps doc string can also be taken as input during
    # function registration.
    doc_string = "RPC Call for " + m.method + " method to server " + address
    self._add_method(m.method, output_specs, input_specs,
                     self._client_handle, doc_string)
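# A hedged usage sketch for the public entry point that wraps the constructor
# above (tf.distribute.experimental.rpc, TF >= 2.8 assumed). The address is a
# placeholder and a matching rpc.Server must already be listening there.
import tensorflow as tf

client = tf.distribute.experimental.rpc.Client.create(
    "grpc", address="localhost:50051", name="example_client")
# Each method registered on the server becomes a callable attribute on
# `client`; its output structure is recovered from the RegisteredMethod proto
# via nested_structure_coder.decode_proto, as the constructor above shows.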
def testDecodeUnknownTensorSpec(self):
  encoded = struct_pb2.StructuredValue()
  encoded.type_spec_value.type_spec_class = 0
  encoded.type_spec_value.type_spec_class_name = "FutureTensorSpec"
  with self.assertRaisesRegex(
      ValueError, "The type 'FutureTensorSpec' is not supported"):
    nested_structure_coder.decode_proto(encoded)
def testEncodeDecodeBoundedTensorSpecNoName(self):
  structure = [
      tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2,
                                    (1, 1, 20))
  ]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value
  expected_tensor_spec.shape.dim.add().size = 28
  expected_tensor_spec.shape.dim.add().size = 28
  expected_tensor_spec.shape.dim.add().size = 3
  expected_tensor_spec.name = ""
  expected_tensor_spec.dtype = dtypes.float64.as_datatype_enum
  expected_tensor_spec.minimum.CopyFrom(
      tensor_util.make_tensor_proto([-2], dtype=dtypes.float64, shape=[]))
  expected_tensor_spec.maximum.CopyFrom(
      tensor_util.make_tensor_proto([1, 1, 20], dtype=dtypes.float64,
                                    shape=[3]))
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`."""
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [
        graph.get_tensor_by_name(component.name)
        for component in tensor_info.composite_tensor.components
    ]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")
def testEncodeDecodeSparseTensorSpec(self):
  structure = [sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: SPARSE_TENSOR_SPEC
          type_spec_class_name: 'SparseTensorSpec'
          num_flat_components: 3
          type_state {
            tuple_value {
              # spec._shape
              values {
                tensor_shape_value {
                  dim { size: 10 }
                  dim { size: 20 }
                }
              }
              # spec._dtype
              values { tensor_dtype_value: DT_FLOAT }
            }
          }
        }
      }
    }
  """
  expected = struct_pb2.StructuredValue()
  text_format.Parse(expected_pbtxt, expected)
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def lookup_tensor_or_sparse_or_composite_tensor(tensor_info):
  """Returns the remapped tensor corresponding to TensorInfo."""
  encoding = tensor_info.WhichOneof('encoding')
  if encoding == 'coo_sparse':
    return tf.SparseTensor(
        lookup_remapped_tensor(tensor_info.coo_sparse.indices_tensor_name),
        lookup_remapped_tensor(tensor_info.coo_sparse.values_tensor_name),
        lookup_remapped_tensor(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == 'composite_tensor':
    components = [lookup_remapped_tensor(info.name)
                  for info in tensor_info.composite_tensor.components]
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    # StructureCoder.decode_proto was migrated after TF 2.7 to
    # nested_structure_coder.decode_proto.
    try:
      spec = nested_structure_coder.decode_proto(spec_proto)
    except AttributeError:
      struct_coder = nested_structure_coder.StructureCoder()
      spec = struct_coder.decode_proto(spec_proto)
    return spec._from_components(components)  # pylint: disable=protected-access
  elif encoding == 'name':
    return lookup_remapped_tensor(tensor_info.name)
  else:
    raise ValueError('Unsupported TensorInfo encoding %s' % encoding)
def testBool(self):
  structure = [False]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().bool_value = False
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testDtype(self):
  structure = [dtypes.int64]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  list_value = expected.list_value.values.add()
  list_value.tensor_dtype_value = dtypes.int64.as_datatype_enum
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testNone(self):
  structure = [1.0, None]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().float64_value = 1.0
  expected.list_value.values.add().none_value.CopyFrom(
      struct_pb2.NoneValue())
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeList(self):
  structure = [1.5, 2.5, 3.0]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().float64_value = 1.5
  expected.list_value.values.add().float64_value = 2.5
  expected.list_value.values.add().float64_value = 3.0
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeDict(self):
  structure = dict(a=3, b=[7, 2.5])
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.dict_value.fields["a"].int64_value = 3
  list_value = expected.dict_value.fields["b"].list_value
  list_value.values.add().int64_value = 7
  list_value.values.add().float64_value = 2.5
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertIsInstance(decoded["a"], int)
  self.assertEqual(structure, decoded)
def _setup_functions_structures(self):
  """Setup structure for inputs and outputs of restored functions."""
  for name, proto in sorted(self._proto.concrete_functions.items()):
    concrete_function = self._concrete_functions[name]
    # By setting the structured_outputs directly, we can rely on this
    # function_lib.ConcreteFunction object to perform the output repacking
    # logic. The only limitation of that logic is that it only works
    # with output that is convertible to Tensors and the conversion
    # always happens. For example tf.TensorShape([2, 3]) will be
    # converted to Tensor representing [2, 3].
    original_outputs = nested_structure_coder.decode_proto(
        proto.output_signature)
    # The original_outputs here had Tensors converted to TensorSpecs, so
    # the restored function's structured_outputs field will not be
    # exactly the same. Fortunately the repacking logic cares only about
    # the structure; and the unpacking logic cares only about structure
    # and types.
    concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
    concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
        nested_structure_coder.decode_proto(
            proto.canonicalized_input_signature))
    concrete_function._initialize_function_spec()  # pylint: disable=protected-access
def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)
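# A tiny self-contained illustration of the `self`-stripping step above.
# Plain `inspect` stands in for tf_inspect, which mirrors its FullArgSpec;
# the class and method names are made up for the example.
import inspect


class _Example:

  def method(self, x, y=1):
    return x + y


argspec = inspect.getfullargspec(_Example.method)
assert argspec.args == ["self", "x", "y"]
assert argspec.args[1:] == ["x", "y"]  # the non-method argspec drops `self`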
def testRegisteredTypeSpec(self):
  expected_warning = ("Encoding a StructuredValue with type "
                      "NestedStructureTest.RegisteredTypeSpec; loading "
                      "this StructuredValue will require that this type "
                      "be imported and registered")
  structure = {"x": RegisteredTypeSpec()}

  self.assertTrue(nested_structure_coder.can_encode(structure))
  with warnings.catch_warnings(record=True) as w:
    encoded = nested_structure_coder.encode_structure(structure)
    self.assertLen(w, 1)
    self.assertIn(expected_warning, str(w[0].message))
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEmptyStructures(self):
  structure = [list(), dict(), tuple()]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().list_value.CopyFrom(
      struct_pb2.ListValue())
  expected.list_value.values.add().dict_value.CopyFrom(
      struct_pb2.DictValue())
  expected.list_value.values.add().tuple_value.CopyFrom(
      struct_pb2.TupleValue())
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeTuple(self):
  structure = ("hello", [3, (2, 1)])
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected.tuple_value.values.add().string_value = "hello"
  list_value = expected.tuple_value.values.add().list_value
  list_value.values.add().int64_value = 3
  tuple_value = list_value.values.add().tuple_value
  tuple_value.values.add().int64_value = 2
  tuple_value.values.add().int64_value = 1
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDataSetSpec(self):
  structure = [
      dataset_ops.DatasetSpec({
          "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
          "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
          "t": tensor_spec.TensorSpec([10, 8], dtypes.string)
      })
  ]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeTensorSpecWithNoName(self):
  structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_spec = expected_list.values.add().tensor_spec_value
  expected_tensor_spec.shape.dim.add().size = 1
  expected_tensor_spec.shape.dim.add().size = 2
  expected_tensor_spec.shape.dim.add().size = 3
  expected_tensor_spec.name = ""
  expected_tensor_spec.dtype = dtypes.int64.as_datatype_enum
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeTensorShape(self):
  structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_shape = expected_list.values.add().tensor_shape_value
  expected_tensor_shape.dim.add().size = 1
  expected_tensor_shape.dim.add().size = 2
  expected_tensor_shape.dim.add().size = 3
  expected_list.values.add().string_value = "hello"
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def composite_tensor_info_to_type_spec(tensor_info):
  """Converts a `TensorInfo` for a composite tensor to a `TypeSpec` object."""
  if nested_structure_coder is None or struct_pb2 is None:
    raise ValueError("This version of TensorFlow does not support "
                     "composite tensors.")
  if tensor_info.WhichOneof("encoding") != "composite_tensor":
    raise ValueError("Expected a TensorInfo with encoding=composite_tensor")
  spec_proto = struct_pb2.StructuredValue(
      type_spec_value=tensor_info.composite_tensor.type_spec)
  # StructureCoder.decode_proto was migrated after TF 2.7 to
  # nested_structure_coder.decode_proto.
  try:
    return nested_structure_coder.decode_proto(spec_proto)
  except AttributeError:
    struct_coder = nested_structure_coder.StructureCoder()
    return struct_coder.decode_proto(spec_proto)
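# A minimal round-trip sketch of the TypeSpec (de)serialization this helper
# performs, assuming TF >= 2.8 where `decode_proto` is a module-level
# function rather than a StructureCoder method:
import tensorflow as tf
from tensorflow.python.saved_model import nested_structure_coder

spec = tf.SparseTensorSpec(shape=[10, 20], dtype=tf.float32)
proto = nested_structure_coder.encode_structure(spec)  # a StructuredValue
assert nested_structure_coder.decode_proto(proto) == spec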
def testBuildTensorInfoRagged(self):
  x = ragged_factory_ops.constant([[1, 2], [3]])
  x_tensor_info = utils.build_tensor_info(x)
  # Check components.
  self.assertEqual(x.values.name,
                   x_tensor_info.composite_tensor.components[0].name)
  self.assertEqual(types_pb2.DT_INT32,
                   x_tensor_info.composite_tensor.components[0].dtype)
  self.assertEqual(x.row_splits.name,
                   x_tensor_info.composite_tensor.components[1].name)
  self.assertEqual(types_pb2.DT_INT64,
                   x_tensor_info.composite_tensor.components[1].dtype)
  # Check type_spec.
  spec_proto = struct_pb2.StructuredValue(
      type_spec_value=x_tensor_info.composite_tensor.type_spec)
  spec = nested_structure_coder.decode_proto(spec_proto)
  self.assertEqual(spec, x._type_spec)
def testEncodeDecodeExtensionTypeSpec(self):

  class Zoo(extension_type.ExtensionType):
    __name__ = "tf.nested_structure_coder_test.Zoo"
    zookeepers: typing.Tuple[str, ...]
    animals: typing.Mapping[str, ops.Tensor]

  structure = [
      Zoo.Spec(
          zookeepers=["Zoey", "Zack"],
          animals={"tiger": tensor_spec.TensorSpec([16])})
  ]

  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: EXTENSION_TYPE_SPEC
          type_spec_class_name: "tf.nested_structure_coder_test.Zoo.Spec"
          num_flat_components: 1
          type_state {
            tuple_value {
              values {
                tuple_value {
                  values { string_value: "zookeepers" }
                  values {
                    tuple_value {
                      values { string_value: "Zoey" }
                      values { string_value: "Zack" }
                    }
                  }
                }
              }
              values {
                tuple_value {
                  values { string_value: "animals" }
                  values {
                    dict_value {
                      fields {
                        key: "tiger"
                        value {
                          tensor_spec_value {
                            shape { dim { size: 16 } }
                            dtype: DT_FLOAT
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  """
  expected = struct_pb2.StructuredValue()
  text_format.Parse(expected_pbtxt, expected)
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()

  def _get_tensor(name):
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))

  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [
        _get_tensor(component.name)
        for component in tensor_info.composite_tensor.components
    ]
    return nest.pack_sequence_as(spec, components, expand_composites=True)
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Expected "
                     "`coo_sparse`, `composite_tensor`, or `name` for a "
                     "dense tensor.")
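# A hedged usage sketch for the dense-tensor ("name") encoding, using the
# TF1-compat public endpoints (graph mode assumed):
import tensorflow.compat.v1 as tf

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
  info = tf.saved_model.build_tensor_info(x)
  recovered = tf.saved_model.get_tensor_from_tensor_info(info, graph=g)
  assert recovered is x  # lookup by name returns the same graph tensor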
def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
  """Infers input shape of layer from SavedModel functions."""
  call_fn_id = self._search_for_child_node(
      layer_node_id, ['call_and_return_all_conditional_losses'])
  if call_fn_id is None:
    return None

  concrete_functions = (
      self._proto.nodes[call_fn_id].function.concrete_functions)
  if not concrete_functions:
    return None
  call_fn_name = concrete_functions[0]
  call_fn_proto = self._proto.concrete_functions[call_fn_name]
  structured_input_signature = nested_structure_coder.decode_proto(
      call_fn_proto.canonicalized_input_signature)
  inputs = structured_input_signature[0][0]
  if convert_to_shapes:
    return nest.map_structure(lambda spec: spec.shape, inputs)
  else:
    return inputs
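# A quick illustration of the spec-to-shape mapping in the branch above.
# tf.nest is the public counterpart of the internal `nest` module; the input
# structure here is made up for the example.
import tensorflow as tf

inputs = {"x": tf.TensorSpec([None, 28, 28], tf.float32)}
shapes = tf.nest.map_structure(lambda spec: spec.shape, inputs)
assert shapes["x"].as_list() == [None, 28, 28]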
def testEncodeDecodeRaggedTensorSpec(self):
  structure = [
      ragged_tensor.RaggedTensorSpec([1, 2, 3], dtypes.int64, 2,
                                     dtypes.int32)
  ]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: RAGGED_TENSOR_SPEC
          type_spec_class_name: 'RaggedTensorSpec'
          num_flat_components: 3
          type_state {
            tuple_value {
              # spec._shape
              values {
                tensor_shape_value {
                  dim { size: 1 }
                  dim { size: 2 }
                  dim { size: 3 }
                }
              }
              # spec._dtype
              values { tensor_dtype_value: DT_INT64 }
              # spec._ragged_rank
              values { int64_value: 2 }
              # spec._row_splits_dtype
              values { tensor_dtype_value: DT_INT32 }
            }
          }
        }
      }
    }
  """
  expected = struct_pb2.StructuredValue()
  text_format.Parse(expected_pbtxt, expected)
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def from_serialized_proto(cls, proto_string: bytes) -> 'TableInfo':
  """Constructs a TableInfo from a serialized `schema_pb2.TableInfo`."""
  proto = schema_pb2.TableInfo.FromString(proto_string)
  if proto.HasField('signature'):
    signature = nested_structure_coder.decode_proto(proto.signature)
  else:
    signature = None
  return cls(
      name=proto.name,
      sampler_options=proto.sampler_options,
      remover_options=proto.remover_options,
      max_size=proto.max_size,
      max_times_sampled=proto.max_times_sampled,
      rate_limiter_info=proto.rate_limiter_info,
      signature=signature,
      current_size=proto.current_size,
      num_episodes=proto.num_episodes,
      num_deleted_episodes=proto.num_deleted_episodes,
      num_unique_samples=proto.num_unique_samples,
      table_worker_time=proto.table_worker_time,
  )
def testEncodeDecodeNamedTuple(self):
  named_tuple_type = collections.namedtuple("NamedTuple", ["x", "y"])
  named_tuple = named_tuple_type(x=[1, 2], y="hello")
  self.assertTrue(nested_structure_coder.can_encode(named_tuple))
  encoded = nested_structure_coder.encode_structure(named_tuple)
  expected = struct_pb2.StructuredValue()
  expected_named_tuple = expected.named_tuple_value
  expected_named_tuple.name = "NamedTuple"
  key_value_pair = expected_named_tuple.values.add()
  key_value_pair.key = "x"
  list_value = key_value_pair.value.list_value
  list_value.values.add().int64_value = 1
  list_value.values.add().int64_value = 2
  key_value_pair = expected_named_tuple.values.add()
  key_value_pair.key = "y"
  key_value_pair.value.string_value = "hello"
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(named_tuple._asdict(), decoded._asdict())
  self.assertEqual(named_tuple.__class__.__name__,
                   decoded.__class__.__name__)
def testEncodeDecodeBoundedTensorSpec(self):
  structure = [
      tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10,
                                    "hello-0-10")
  ]
  self.assertTrue(nested_structure_coder.can_encode(structure))
  encoded = nested_structure_coder.encode_structure(structure)
  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_spec = expected_list.values.add().bounded_tensor_spec_value
  expected_tensor_spec.shape.dim.add().size = 1
  expected_tensor_spec.shape.dim.add().size = 2
  expected_tensor_spec.shape.dim.add().size = 3
  expected_tensor_spec.name = "hello-0-10"
  expected_tensor_spec.dtype = dtypes.int64.as_datatype_enum
  expected_tensor_spec.minimum.CopyFrom(
      tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[]))
  expected_tensor_spec.maximum.CopyFrom(
      tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[]))
  self.assertEqual(expected, encoded)
  decoded = nested_structure_coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones. Gradients are re-registered under new names.
  Ops that reference the gradients are updated to reflect the new registered
  names.

  Args:
    library: FunctionDefLibrary proto message.
    saved_object_graph: SavedObjectGraph proto message. If not passed in,
      concrete function structured signatures and outputs will not be set.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: An object that will be wrapped on newly created
      functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(fdef.signature.name
                               for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered
  # with some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(b/205023033): Make this Graph creation unnecessary when executing
  # eagerly by fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name.
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type.
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)
      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                     new_gradient_op_types)

    # Set up function signatures and outputs.
    #
    # When concrete functions are created normally (i.e. when they're
    # originally created and not loaded via saved model), the inputs and
    # outputs are calculated based on the values passed in by the user and
    # returned from the original function, respectively. We don't have access
    # to those anymore at restore time, so we must instead pass them to the
    # FuncGraph explicitly.
    structured_input_signature = None
    structured_outputs = None
    if (saved_object_graph is not None and
        fdef.signature.name in saved_object_graph.concrete_functions):
      # TODO(b/204324043): Offload the deserialization of the protos to the
      # first class objects by passing the actual protos. This is blocked on
      # importing `nested_structure_coder` in function.py causing a circular
      # dependency.
      proto = saved_object_graph.concrete_functions[fdef.signature.name]
      structured_input_signature = nested_structure_coder.decode_proto(
          proto.canonicalized_input_signature)
      structured_outputs = nested_structure_coder.decode_proto(
          proto.output_signature)

    # There is no need to copy all functions into the function def graph. It
    # leads to a O(n^2) increase of memory when importing functions, and the
    # extra function definitions are a no-op since they were already imported
    # as functions before and are passed in explicitly (due to the
    # topological sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(
          copy,
          structured_input_signature=structured_input_signature,
          structured_outputs=structured_outputs)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients).
    _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients)

    for dep in function_deps[fdef.signature.name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). A ConcreteFunction that is part of a saved
    # function is set up later by recreate_function(); a bare
    # ConcreteFunction is set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in copy.attr:
      del copy.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel
      # integration is fixed. Currently it's leaking memory to maintain bug
      # compatibility with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if fdef.signature.name in gradients_to_register:
      gradient_op_type = gradients_to_register[fdef.signature.name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions
def unpack_pattern(config: Config) -> Pattern:
  if not config.HasField('pattern_structure'):
    return config.flat
  structure = nested_structure_coder.decode_proto(config.pattern_structure)
  return tree.unflatten_as(structure, config.flat)
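# An illustration of the flatten/unflatten contract `unpack_pattern` relies
# on, using dm-tree's `tree` package directly; the structure and flat values
# here are made up for the example (`Config` itself comes from the
# surrounding module).
import tree

structure = {'a': None, 'b': (None, None)}
flat = [1, 2, 3]
assert tree.unflatten_as(structure, flat) == {'a': 1, 'b': (2, 3)}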