def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  NOTE(review): this file contains a second `def canonicalize_signatures`
  after this one. If both end up in the same module, the later definition
  replaces this one at import time and this version is never called; consider
  deleting one of them.

  Args:
    signatures: `None`, a single function, or a mapping from signature key to
      function. A non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A `(concrete_signatures, wrapped_functions)` pair:
      concrete_signatures: maps each signature key to a concrete function
        obtained from a `def_function.function` wrapper whose outputs are
        passed through `_normalize_outputs`.
      wrapped_functions: maps each function returned by `_get_signature` to
        the result of `function_serialization.wrap_cached_variables` on it; a
        function shared by several keys is wrapped only once.
    Both dictionaries are empty when `signatures` is `None`.

  Raises:
    ValueError: if `_get_signature` returns `None` for an entry.
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
    }

  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError((
          "Expected a TensorFlow function to generate a signature for, but "
          "got {}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature."
      ).format(function))

    # Reuse an earlier wrapper if this function already backs another key.
    signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)

    # Re-wrap the function so that it returns a dictionary of Tensors, the
    # output format of 1.x-style signatures. The closure reads the loop
    # variables `signature_function` and `signature_key`. The wrapper is
    # converted to a concrete function later in this same iteration
    # (presumably tracing it then), before either name is rebound.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      outputs = signature_function(**kwargs)
      return _normalize_outputs(outputs, signature_function.name,
                                signature_key)
    # pylint: enable=cell-var-from-loop

    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their
    # names always match keyword arguments.
    arg_keywords = signature_function._arg_keywords  # pylint: disable=protected-access
    tensor_spec_signature = {}
    for raw_keyword, tensor in zip(arg_keywords, signature_function.inputs):
      keyword = compat.as_str(raw_keyword)
      # Each spec takes its keyword as its name.
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)

    # pylint: disable=protected-access
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(
        **tensor_spec_signature)
    # A single input is a very common case. Its ordering is unambiguous, so
    # callers may also pass it positionally. SignatureDefs are unordered (a
    # protobuf "map"), so with several inputs every argument is keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete

  return concrete_signatures, wrapped_functions
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: `None`, a single function, or a mapping from signature key to
      function. A non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A `(concrete_signatures, wrapped_functions)` pair:
      concrete_signatures: maps each signature key to a concrete function
        obtained from a `def_function.function` wrapper whose outputs are
        passed through `_normalize_outputs`.
      wrapped_functions: maps each function returned by `_get_signature` to
        the result of `function_serialization.wrap_cached_variables` on it; a
        function shared by several keys is wrapped only once.
    Both dictionaries are empty when `signatures` is `None`.

  Raises:
    ValueError: if `_get_signature` returns `None` for an entry.
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
    }

  # Counts the functions for which an input-renaming warning has been logged.
  # Warnings stop once it reaches _NUM_DISPLAY_NORMALIZED_SIGNATURES.
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          "Expected a TensorFlow function for which to generate a signature, "
          f"but got {function}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature.")

    # Reuse an earlier wrapper if this function already backs another key.
    signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)

    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      name_changes = _get_signature_name_changes(signature_function)
      if name_changes:
        num_normalized_signatures_counter += 1
        logging.warning(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(name_changes.keys()),
            ", ".join(name_changes.values()))

    # Re-wrap the function so that it returns a dictionary of Tensors, the
    # output format of 1.x-style signatures. The closure reads the loop
    # variables `signature_function` and `signature_key`. The wrapper is
    # converted to a concrete function later in this same iteration
    # (presumably tracing it then), before either name is rebound.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      outputs = signature_function(**kwargs)
      return _normalize_outputs(outputs, signature_function.name,
                                signature_key)
    # pylint: enable=cell-var-from-loop

    structured_signature = signature_function.structured_input_signature
    if structured_signature is None:
      # Structured input signature isn't always defined for some functions.
      inputs = signature_function.inputs
    else:
      # The structured signature may also hold non-tensor arguments; keep only
      # its TensorSpec leaves (flattened with `expand_composites=True`).
      inputs = (
          leaf
          for leaf in nest.flatten(structured_signature, expand_composites=True)
          if isinstance(leaf, tensor_spec.TensorSpec))

    arg_keywords = signature_function._arg_keywords  # pylint: disable=protected-access
    tensor_spec_signature = {}
    for raw_keyword, value in zip(arg_keywords, inputs):
      keyword = compat.as_str(raw_keyword)
      # Every spec is renamed to its keyword. A spec input keeps its
      # shape/dtype; a tensor input is converted via `from_tensor`.
      if isinstance(value, tensor_spec.TensorSpec):
        tensor_spec_signature[keyword] = tensor_spec.TensorSpec(
            value.shape, value.dtype, name=keyword)
      else:
        tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
            value, name=keyword)

    # pylint: disable=protected-access
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(
        **tensor_spec_signature)
    # A single input is a very common case. Its ordering is unambiguous, so
    # callers may also pass it positionally. SignatureDefs are unordered (a
    # protobuf "map"), so with several inputs every argument is keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete

  return concrete_signatures, wrapped_functions
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  This method also mutates the instance:
    * It records node ids in `self.captured_tensor_node_ids`.
    * It appends a `_CapturedConstant` node to `self.nodes` (and registers it
      in `self.node_ids`) for each uncaptured tensor capture that has a
      constant value.
    * It wraps functions that capture cached variables, updating
      `self.wrapped_functions` and `self.function_name_map`.
    * It rebuilds `self.concrete_functions`, substituting wrappers and
      dropping functions that capture a tensor which cannot be copied.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.

  Raises:
    ValueError: if a concrete function's graph is not saveable; the message
      lists that graph's `saving_errors`.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})

  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.CapturableResource):
      new_obj = object_map[obj] = copy.copy(obj)
      # pylint: disable=protected-access
      with ops.device(obj._resource_device):
        new_resource = new_obj._create_resource()
      new_obj._resource_handle = new_resource
      # pylint: enable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif (ds_values.is_distributed_variable(obj) or
          resource_variable_ops.is_resource_variable(obj)):
      is_distributed = ds_values.is_distributed_variable(obj)
      # A distributed variable gets one graph copy, made from its primary
      # component; every component in `obj.values` maps to that copy.
      obj_to_copy = obj._primary if is_distributed else obj  # pylint: disable=protected-access
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(
          obj_to_copy)
      if is_distributed:
        self.captured_tensor_node_ids[obj] = node_id
        for v in obj.values:
          object_map[v] = new_variable
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.Asset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  # Note: some concrete functions can have been realized when tracing other
  # functions, and might closure-capture tensors from their parent functions.
  # This is normal, but it means those concrete functions can't be serialized
  # as their own independent endpoints, so we filter them out here.
  bad_functions = []
  for concrete_function in self.concrete_functions:
    if not concrete_function.graph.saveable:
      # Interpolate only the header. The saving errors are free-form text and
      # may contain "{" or "}". Running str.format over them would raise and
      # hide the real reason.
      raise ValueError(
          f"Unable to save function {concrete_function.name} for the "
          "following reason(s):\n" +
          "\n".join(concrete_function.graph.saving_errors))
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture) and
          capture.dtype not in _UNCOPIABLE_DTYPES and
          capture not in self.captured_tensor_node_ids):
        if hasattr(capture, "_cached_variable"):
          # Captures of cached variables are handled by wrapping the function
          # once; the wrapper replaces it in self.concrete_functions below.
          if concrete_function not in self.wrapped_functions:
            wrapped = self.wrapped_functions[concrete_function] = (
                function_serialization.wrap_cached_variables(
                    concrete_function))
            self.function_name_map[compat.as_text(concrete_function.name)] = (
                compat.as_text(wrapped.name))
          continue
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          # The captured tensor cannot be copied as a constant, so this
          # function cannot be saved as an independent endpoint.
          bad_functions.append(concrete_function)
          continue
        # Copy the constant into the current graph and track it as a new node.
        copied_tensor = constant_op.constant(capture_constant_value)
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  self.concrete_functions = [
      self.wrapped_functions.get(x, x)
      for x in self.concrete_functions
      if x not in bad_functions
  ]
  return object_map, resource_map, asset_info