Пример #1
0
def serialize_concrete_function(concrete_function, node_ids):
    """Build a SavedConcreteFunction proto for `concrete_function`.

    Args:
      concrete_function: The concrete function to serialize. Its `name`,
        `captured_inputs`, `structured_input_signature` and
        `structured_outputs` attributes are read.
      node_ids: Mapping from each captured input to the value recorded in the
        proto's `bound_inputs` (presumably object-graph node ids — confirm
        against the caller).

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` with the encoded input
      and output signatures and the bound inputs filled in.

    Raises:
      KeyError: If a captured input has no entry in `node_ids`. The failed
        lookup is chained as `__cause__`.
    """
    bound_inputs = []
    try:
        for capture in concrete_function.captured_inputs:
            bound_inputs.append(node_ids[capture])
    except KeyError as err:
        # Chain the failed lookup explicitly. The traceback then presents it as
        # the direct cause of this descriptive error, not as an unrelated
        # exception raised "during handling".
        raise KeyError(
            f"Failed to add concrete function '{concrete_function.name}' to object-"
            f"based SavedModel as it captures tensor {capture!r} which is unsupported"
            " or not reachable from root. "
            "One reason could be that a stateful object or a variable that the "
            "function depends on is not assigned to an attribute of the serialized "
            "trackable object (see SaveTest.test_captures_unreachable_variable)."
        ) from err
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        nested_structure_coder.encode_structure(
            concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        nested_structure_coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
Пример #2
0
def serialize_function(function, node_ids):
    """Build a SavedFunction proto describing `function`.

    The function spec is encoded first. Then each concrete function reported by
    `function._list_all_concrete_functions_for_serialization()` is appended as a
    `concrete_function` entry. An entry is skipped, with a logged warning, when
    any of its captured inputs has no entry in `node_ids`.
    """
    coder = nested_structure_coder.StructureCoder()
    proto = saved_object_graph_pb2.SavedFunction()

    proto.function_spec.CopyFrom(
        _serialize_function_spec(function.function_spec, coder))

    # pylint: disable=protected-access
    candidates = function._list_all_concrete_functions_for_serialization()
    # pylint: enable=protected-access
    for candidate in candidates:
        resolved_ids = []
        try:
            for captured in candidate.captured_inputs:
                resolved_ids.append(node_ids[captured])
        except KeyError:
            # TODO(andresp): Would it be better to throw an exception?
            logging.warning(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root.",
                candidate.name, captured)
            continue

        # Only the positional part of the input signature is encoded.
        positional_signature, _unused_kwargs = candidate.structured_input_signature
        entry = proto.concrete_function.add()
        entry.name = candidate.name
        entry.canonicalized_input_signature.CopyFrom(
            coder.encode_structure(positional_signature))
        output_signature = func_graph_module.convert_structure_to_signature(
            candidate.structured_outputs)
        entry.output_signature.CopyFrom(coder.encode_structure(output_signature))
        entry.bound_inputs.extend(resolved_ids)
    return proto
Пример #3
0
def _prune_receiver_tensors(wrapped_function, receiver_tensors, outputs, name):
    """Prune `wrapped_function` to the canonicalized receiver tensors.

    The canonicalized receiver tensors are the feeds passed to `prune`. The
    `input_signature` handed to `prune` is `(None, <signature of those feeds>)`.
    """
    canonical_inputs = _canonicalize_receiver_tensors(receiver_tensors)
    feed_signature = func_graph.convert_structure_to_signature(canonical_inputs)
    return wrapped_function.prune(
        canonical_inputs, outputs, name=name,
        input_signature=(None, feed_signature))
Пример #4
0
 def __init__(self, closure, type_spec):
   """Create a `RemoteValue` in the NOT_READY state.

   Args:
     closure: The closure associated with this value; stored as-is.
     type_spec: Structure converted with
       `func_graph.convert_structure_to_signature` into the spec used to trace
       functions that take this `RemoteValue` as input.
   """
   self._closure = closure
   # Signature form of `type_spec`, used when tracing functions that accept
   # this `RemoteValue` as an argument.
   traced_spec = func_graph.convert_structure_to_signature(type_spec)
   self._type_spec = traced_spec
   # Both the value and the error start unset.
   self._value = None
   self._error = None
   self._status_available_event = threading.Event()
   self._status = _RemoteValueStatus.NOT_READY
Пример #5
0
def serialize_concrete_function(concrete_function, node_ids, coder):
    """Build a SavedConcreteFunction proto using the supplied structure coder.

    Args:
      concrete_function: The concrete function to serialize. Its `name`,
        `captured_inputs`, `structured_input_signature` and
        `structured_outputs` attributes are read.
      node_ids: Mapping from each captured input to the value stored in the
        proto's `bound_inputs` (presumably object-graph node ids — confirm).
      coder: Object whose `encode_structure` encodes the input and output
        signatures.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction`.

    Raises:
      KeyError: If a captured input is missing from `node_ids`. The failed
        lookup is chained as `__cause__`.
    """
    bound_inputs = []
    try:
        for capture in concrete_function.captured_inputs:
            bound_inputs.append(node_ids[capture])
    except KeyError as err:
        # Chain the failed lookup so it is reported as the direct cause.
        raise KeyError(
            "Concrete function %s not added to object based saved model as it "
            "captures tensor %s which is unsupported or not reachable from root."
            % (concrete_function.name, capture)) from err
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        coder.encode_structure(concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
Пример #6
0
def serialize_concrete_function(concrete_function, node_ids):
    """Serialize `concrete_function` into a SavedConcreteFunction proto.

    Returns:
      The populated proto, or None (after logging a warning) when one of the
      function's captured inputs has no entry in `node_ids`.
    """
    resolved_ids = []
    try:
        for captured in concrete_function.captured_inputs:
            resolved_ids.append(node_ids[captured])
    except KeyError:
        # TODO(andresp): Would it be better to throw an exception?
        logging.warning(
            "Concrete function %s not added to object based saved model as it "
            "captures tensor %s which is unsupported or not reachable from root.",
            concrete_function.name, captured)
        return None

    encoder = nested_structure_coder.StructureCoder()
    proto = saved_object_graph_pb2.SavedConcreteFunction()
    proto.name = concrete_function.name
    encoded_inputs = encoder.encode_structure(
        concrete_function.structured_input_signature)
    proto.canonicalized_input_signature.CopyFrom(encoded_inputs)
    output_structure = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    proto.output_signature.CopyFrom(encoder.encode_structure(output_structure))
    proto.bound_inputs.extend(resolved_ids)
    return proto
Пример #7
0
def serialize_concrete_function(concrete_function, node_ids, coder):
    """Serialize `concrete_function` with the provided structure `coder`.

    Returns:
      A SavedConcreteFunction proto, or None (after logging a warning) if any
      captured input is absent from `node_ids`.
    """
    captured_ids = []
    try:
        for tensor in concrete_function.captured_inputs:
            captured_ids.append(node_ids[tensor])
    except KeyError:
        # TODO(allenl): This warning hides a real issue in test_table in
        # save_test.py, where captured constants are not handled. Fix that and
        # then turn this into an exception.
        logging.warning(
            "Concrete function %s not added to object based saved model as it "
            "captures tensor %s which is unsupported or not reachable from root.",
            concrete_function.name, tensor)
        return None

    result = saved_object_graph_pb2.SavedConcreteFunction()
    output_spec = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    input_encoding = coder.encode_structure(
        concrete_function.structured_input_signature)
    result.canonicalized_input_signature.CopyFrom(input_encoding)
    result.output_signature.CopyFrom(coder.encode_structure(output_spec))
    result.bound_inputs.extend(captured_ids)
    return result
Пример #8
0
def serialize_concrete_function(concrete_function, node_ids):
  """Convert `concrete_function` into a SavedConcreteFunction proto.

  If a captured input of the function is not found in `node_ids`, a warning is
  logged and None is returned instead of a proto.
  """
  node_refs = []
  try:
    for tensor in concrete_function.captured_inputs:
      node_refs.append(node_ids[tensor])
  except KeyError:
    # TODO(andresp): Would it be better to throw an exception?
    logging.warning(
        "Concrete function %s not added to object based saved model as it "
        "captures tensor %s which is unsupported or not reachable from root.",
        concrete_function.name, tensor)
    return None

  structure_coder = nested_structure_coder.StructureCoder()
  saved_fn = saved_object_graph_pb2.SavedConcreteFunction()
  saved_fn.name = concrete_function.name
  encoded_inputs = structure_coder.encode_structure(
      concrete_function.structured_input_signature)
  saved_fn.canonicalized_input_signature.CopyFrom(encoded_inputs)
  output_structure = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  saved_fn.output_signature.CopyFrom(
      structure_coder.encode_structure(output_structure))
  saved_fn.bound_inputs.extend(node_refs)
  return saved_fn
def serialize_concrete_function(concrete_function, node_ids, coder):
  """Build a SavedConcreteFunction proto using the supplied structure coder.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `captured_inputs`, `structured_input_signature` and `structured_outputs`
      attributes are read.
    node_ids: Mapping from each captured input to the value stored in the
      proto's `bound_inputs` (presumably object-graph node ids — confirm).
    coder: Object whose `encode_structure` encodes the input and output
      signatures.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction`.

  Raises:
    KeyError: If a captured input is missing from `node_ids`. The failed lookup
      is chained as `__cause__`.
  """
  bound_inputs = []
  try:
    for capture in concrete_function.captured_inputs:
      bound_inputs.append(node_ids[capture])
  except KeyError as err:
    # Chain the failed lookup so it is reported as the direct cause of this
    # more descriptive error.
    raise KeyError(
        "Failed to add concrete function %s to object based saved model as it "
        "captures tensor %s which is unsupported or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "checkpointable object "
        "(see SaveTest.test_captures_unreachable_variable)."
        % (concrete_function.name, capture)) from err
  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      coder.encode_structure(concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto
Пример #10
0
 def _set_type_spec(self, type_spec):
   """Store the signature form of `type_spec` as this object's type spec."""
   converted_spec = func_graph.convert_structure_to_signature(type_spec)
   self._type_spec = converted_spec
Пример #11
0
    def _extract_signatures(self, wrapped, meta_graph_def):
        """Creates ConcreteFunctions for signatures in `meta_graph_def`."""
        signature_functions = {}
        for signature_key, signature_def in meta_graph_def.signature_def.items(
        ):
            if signature_def.inputs:
                input_items = sorted(signature_def.inputs.items(),
                                     key=lambda item: item[1].name)
                original_input_names, input_specs = zip(*input_items)
            else:
                original_input_names = []
                input_specs = []
            # TODO(allenl): Support optional arguments
            feeds = [
                wrap_function._get_element_from_tensor_info(
                    input_spec, wrapped.graph)  # pylint: disable=protected-access
                for input_spec in input_specs
            ]
            input_names = []
            input_tensors = []
            for original_input_name, feed in zip(original_input_names, feeds):
                if isinstance(feed, sparse_tensor.SparseTensor):
                    # We have to give explicit name for SparseTensor arguments, because
                    # these are not present in the TensorInfo.
                    indices_name = "%s_indices" % original_input_name
                    values_name = "%s_values" % original_input_name
                    dense_shape_name = "%s_dense_shape" % original_input_name
                    input_names.extend(
                        [indices_name, values_name, dense_shape_name])
                    input_tensors.extend(
                        [feed.indices, feed.values, feed.dense_shape])
                elif isinstance(feed, composite_tensor.CompositeTensor):
                    component_tensors = nest.flatten(feed,
                                                     expand_composites=True)
                    input_names.extend("%s_component_%d" %
                                       (original_input_name, n)
                                       for n in range(len(component_tensors)))
                    input_tensors.extend(component_tensors)
                else:
                    input_names.append(original_input_name)
                    input_tensors.append(feed)
            fetches = {
                name: out
                for name, out in signature_def.outputs.items()
            }
            try:
                signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
            except lift_to_graph.UnliftableError as ex:
                # Mutate the exception to add a bit more detail.
                args = ex.args
                if not args:
                    message = ""
                else:
                    message = args[0]
                message = ((
                    "A SavedModel signature needs an input for each placeholder the "
                    "signature's outputs use. An output for signature '{}' depends on "
                    "a placeholder which is not an input (i.e. the placeholder is not "
                    "fed a value).\n\n").format(signature_key) + message)
                ex.args = (message, ) + args[1:]
                raise
            # pylint: disable=protected-access
            signature_fn._arg_keywords = input_names
            signature_fn._func_graph.structured_input_signature = (
                (),
                func_graph.convert_structure_to_signature(
                    dict(zip(input_names, input_tensors))))

            if len(input_names) == 1:
                # Allowing positional arguments does not create any ambiguity if there's
                # only one.
                signature_fn._num_positional_args = 1
            else:
                signature_fn._num_positional_args = 0
            # pylint: enable=protected-access
            signature_functions[signature_key] = signature_fn
        return signature_functions