Example #1
def canonicalize_signatures(signatures):
    """Converts `signatures` into a dictionary of concrete functions."""
    if signatures is None:
        return {}, {}
    if not isinstance(signatures, collections_abc.Mapping):
        signatures = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
        }
    concrete_signatures = {}
    wrapped_functions = {}
    for signature_key, function in signatures.items():
        original_function = signature_function = _get_signature(function)

        if signature_function is None:
            raise ValueError((
                "Expected a TensorFlow function to generate a signature for, but "
                "got {}. Only `tf.functions` with an input signature or "
                "concrete functions can be used as a signature."
            ).format(function))

        wrapped_functions[original_function] = signature_function = (
            wrapped_functions.get(original_function)
            or function_serialization.wrap_cached_variables(original_function))
        _validate_inputs(signature_function)

        # Re-wrap the function so that it returns a dictionary of Tensors. This
        # matches the format of 1.x-style signatures.
        # pylint: disable=cell-var-from-loop
        @def_function.function
        def signature_wrapper(**kwargs):
            structured_outputs = signature_function(**kwargs)
            return _normalize_outputs(structured_outputs,
                                      signature_function.name, signature_key)

        # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
        # always match keyword arguments.
        tensor_spec_signature = {}
        for keyword, tensor in zip(
                signature_function._arg_keywords,  # pylint: disable=protected-access
                signature_function.inputs):
            keyword = compat.as_str(keyword)
            tensor_spec_signature[
                keyword] = tensor_spec.TensorSpec.from_tensor(tensor,
                                                              name=keyword)
        final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
            **tensor_spec_signature)
        # pylint: disable=protected-access
        if len(final_concrete._arg_keywords) == 1:
            # If there is only one input to the signature, a very common case, then
            # ordering is unambiguous and we can let people pass a positional
            # argument. Since SignatureDefs are unordered (protobuf "map"), having
            # multiple arguments means we need to be keyword-only.
            final_concrete._num_positional_args = 1
        else:
            final_concrete._num_positional_args = 0
        # pylint: enable=protected-access
        concrete_signatures[signature_key] = final_concrete
        # pylint: enable=cell-var-from-loop
    return concrete_signatures, wrapped_functions
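
The function above is internal to `tf.saved_model`, so it is normally exercised through the public API rather than called directly. A minimal sketch, assuming TensorFlow 2.x and illustrative names (`Adder`, `add_one`, and the `/tmp/adder` path), of how a single `tf.function` reaches it via `tf.saved_model.save`:

import tensorflow as tf

# Sketch only: canonicalize_signatures is reached indirectly through
# tf.saved_model.save; the class and path below are illustrative.
class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def add_one(self, x):
        # Returning a dict keeps the output names stable in the SignatureDef.
        return {"y": x + 1.0}

module = Adder()
# A single function is wrapped under DEFAULT_SERVING_SIGNATURE_DEF_KEY before
# its concrete function is extracted.
tf.saved_model.save(module, "/tmp/adder", signatures=module.add_one)
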
Example #2
def canonicalize_signatures(signatures):
    """Converts `signatures` into a dictionary of concrete functions."""
    if signatures is None:
        return {}, {}
    if not isinstance(signatures, collections_abc.Mapping):
        signatures = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
        }
    num_normalized_signatures_counter = 0
    concrete_signatures = {}
    wrapped_functions = {}
    for signature_key, function in signatures.items():
        original_function = signature_function = _get_signature(function)
        if signature_function is None:
            raise ValueError(
                "Expected a TensorFlow function for which to generate a signature, "
                f"but got {function}. Only `tf.functions` with an input signature or "
                "concrete functions can be used as a signature.")

        wrapped_functions[original_function] = signature_function = (
            wrapped_functions.get(original_function)
            or function_serialization.wrap_cached_variables(original_function))
        _validate_inputs(signature_function)
        if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
            signature_name_changes = _get_signature_name_changes(
                signature_function)
            if signature_name_changes:
                num_normalized_signatures_counter += 1
                logging.warning(
                    "Function `%s` contains input name(s) %s with unsupported "
                    "characters which will be renamed to %s in the SavedModel.",
                    compat.as_str(signature_function.graph.name),
                    ", ".join(signature_name_changes.keys()),
                    ", ".join(signature_name_changes.values()))
        # Re-wrap the function so that it returns a dictionary of Tensors. This
        # matches the format of 1.x-style signatures.
        # pylint: disable=cell-var-from-loop
        @def_function.function
        def signature_wrapper(**kwargs):
            structured_outputs = signature_function(**kwargs)
            return _normalize_outputs(structured_outputs,
                                      signature_function.name, signature_key)

        tensor_spec_signature = {}
        if signature_function.structured_input_signature is not None:
            # The structured input signature may contain other non-tensor arguments.
            inputs = filter(
                lambda x: isinstance(x, tensor_spec.TensorSpec),
                nest.flatten(signature_function.structured_input_signature,
                             expand_composites=True))
        else:
            # Some functions don't define a structured input signature.
            inputs = signature_function.inputs

        for keyword, inp in zip(
                signature_function._arg_keywords,  # pylint: disable=protected-access
                inputs):
            keyword = compat.as_str(keyword)
            if isinstance(inp, tensor_spec.TensorSpec):
                spec = tensor_spec.TensorSpec(inp.shape,
                                              inp.dtype,
                                              name=keyword)
            else:
                spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
            tensor_spec_signature[keyword] = spec
        final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
            **tensor_spec_signature)
        # pylint: disable=protected-access
        if len(final_concrete._arg_keywords) == 1:
            # If there is only one input to the signature, a very common case, then
            # ordering is unambiguous and we can let people pass a positional
            # argument. Since SignatureDefs are unordered (protobuf "map"), having
            # multiple arguments means we need to be keyword-only.
            final_concrete._num_positional_args = 1
        else:
            final_concrete._num_positional_args = 0
        # pylint: enable=protected-access
        concrete_signatures[signature_key] = final_concrete
        # pylint: enable=cell-var-from-loop
    return concrete_signatures, wrapped_functions
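
This revision is reached the same way. A hedged sketch (TensorFlow 2.x assumed; `TwoHeads` and the save path are illustrative) of the mapping form, which takes the `Mapping` branch at the top of the function and makes the dict keys the SignatureDef keys:

import tensorflow as tf

class TwoHeads(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def square(self, x):
        return {"y": x * x}

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def cube(self, x):
        return {"y": x * x * x}

m = TwoHeads()
# Each dict key becomes a signature key in the exported SavedModel.
tf.saved_model.save(
    m, "/tmp/two_heads",
    signatures={"square": m.square, "cube": m.cube})
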
Example #3
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        new_obj = object_map[obj] = copy.copy(obj)
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = new_obj._create_resource()
        new_obj._resource_handle = new_resource
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_values.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        obj_to_copy = obj._primary if ds_values.is_distributed_variable(  # pylint: disable=protected-access
            obj) else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if ds_values.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
          for v in obj.values:
            object_map[v] = new_variable
            resource_map[v.handle] = new_variable.handle
            self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    # Note: some concrete functions can have been realized when tracing other
    # functions, and might closure-capture tensors from their parent functions.
    # This is normal, but it means those concrete functions can't be serialized
    # as their own independent endpoints, so we filter them out here.
    bad_functions = []
    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            ("Unable to save function {name} for the following reason(s):\n" +
             "\n".join(concrete_function.graph.saving_errors)).format(
                 name=concrete_function.name))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture) and
            capture.dtype not in _UNCOPIABLE_DTYPES and
            capture not in self.captured_tensor_node_ids):
          if hasattr(capture, "_cached_variable"):
            if concrete_function not in self.wrapped_functions:
              wrapped = self.wrapped_functions[concrete_function] = (
                  function_serialization.wrap_cached_variables(
                      concrete_function))
              self.function_name_map[compat.as_text(concrete_function.name)] = (
                  compat.as_text(wrapped.name))
            continue
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            bad_functions.append(concrete_function)
            continue
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    self.concrete_functions = [
        self.wrapped_functions.get(x, x) for x in self.concrete_functions
        if x not in bad_functions
    ]
    return object_map, resource_map, asset_info
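
`map_resources` runs inside the export path of `tf.saved_model.save`. A minimal sketch, assuming TensorFlow 2.x (the `WithState` class and save path are illustrative), that exercises the resource-variable branch above by saving and reloading a module that owns a `tf.Variable`:

import tensorflow as tf

class WithState(tf.Module):

  def __init__(self):
    super().__init__()
    # This variable is copied into the export graph during map_resources and
    # its handle recorded in resource_map.
    self.v = tf.Variable(3.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def scale(self, x):
    return {"y": self.v * x}

m = WithState()
tf.saved_model.save(m, "/tmp/with_state", signatures=m.scale)
restored = tf.saved_model.load("/tmp/with_state")
print(restored.scale(tf.constant(2.0)))  # -> {'y': tf.Tensor of 6.0}
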