예제 #1
0
def serialize_polymorphic_function(polymorphic_function, node_ids):
    """Build a `SavedPolymorphicFunction` proto for `polymorphic_function`.

    The function spec (as a tuple) and each concrete function's signature are
    encoded with a single `nested_structure_coder.StructureCoder`. A concrete
    function is only recorded when every one of its `captured_inputs` can be
    looked up in `node_ids`; otherwise it is skipped and a warning is logged.

    Args:
      polymorphic_function: the function to serialize. Its `function_spec` is
        encoded, and `list_all_concrete_functions` is expected to yield
        `(signature, concrete_function)` pairs for it.
      node_ids: mapping indexed by captured input; the looked-up values are
        stored in each entry's `bound_inputs`.

    Returns:
      The populated `SavedPolymorphicFunction` proto.
    """
    structure_coder = nested_structure_coder.StructureCoder()
    saved_function = saved_object_graph_pb2.SavedPolymorphicFunction()

    spec_tuple = polymorphic_function.function_spec.as_tuple()
    saved_function.function_spec_tuple.CopyFrom(
        structure_coder.encode_structure(spec_tuple))

    for signature, concrete_function in list_all_concrete_functions(
            polymorphic_function):
        capture_node_ids = []
        try:
            for captured in concrete_function.captured_inputs:
                capture_node_ids.append(node_ids[captured])
        except KeyError:
            # A capture with no entry in `node_ids` cannot be bound on load,
            # so this concrete function is left out of the proto.
            # TODO(andresp): Would it be better to throw an exception?
            logging.warning(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root.",
                concrete_function.name, captured)
        else:
            # Every capture resolved: record the function, its encoded
            # signature and the node ids of its captures.
            entry = saved_function.monomorphic_function.add()
            entry.concrete_function = concrete_function.name
            entry.canonicalized_input.CopyFrom(
                structure_coder.encode_structure(signature))
            entry.bound_inputs.extend(capture_node_ids)
    return saved_function
예제 #2
0
def _serialize_polymorphic_function(polymorphic_function):
    """Return a `SavedPolymorphicFunction` listing the concrete functions by name.

    Each concrete function produced by `list_all_concrete_functions` becomes
    one `SavedMonomorphicFunction` entry, in iteration order, holding only the
    concrete function's `name`.

    Args:
      polymorphic_function: the function whose concrete functions are listed.

    Returns:
      A `SavedPolymorphicFunction` proto with one `monomorphic_function` entry
      per concrete function.
    """
    monomorphic_entries = [
        saved_object_graph_pb2.SavedMonomorphicFunction(
            concrete_function=concrete.name)
        for concrete in list_all_concrete_functions(polymorphic_function)
    ]
    return saved_object_graph_pb2.SavedPolymorphicFunction(
        monomorphic_function=monomorphic_entries)
예제 #3
0
def _serialize_polymorphic_function(polymorphic_function, node_ids):
    """Build a `SavedPolymorphicFunction` proto for `polymorphic_function`.

    A concrete function is recorded only when every one of its
    `captured_inputs` can be looked up in `node_ids`; otherwise it is skipped
    and a warning naming the unresolved capture is logged.

    Args:
      polymorphic_function: the function whose concrete functions (as yielded
        by `list_all_concrete_functions`) are serialized.
      node_ids: mapping indexed by captured input; the looked-up values are
        stored in each entry's `bound_inputs`.

    Returns:
      The populated `SavedPolymorphicFunction` proto.
    """
    saved_function = saved_object_graph_pb2.SavedPolymorphicFunction()
    for concrete_function in list_all_concrete_functions(polymorphic_function):
        capture_node_ids = []
        try:
            for captured in concrete_function.captured_inputs:
                capture_node_ids.append(node_ids[captured])
        except KeyError:
            # A capture with no entry in `node_ids` cannot be bound on load,
            # so this concrete function is left out of the proto.
            # TODO(andresp): Would it be better to throw an exception?
            logging.warning(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root.",
                concrete_function.name, captured)
        else:
            entry = saved_function.monomorphic_function.add()
            entry.concrete_function = concrete_function.name
            entry.bound_inputs.extend(capture_node_ids)
    return saved_function
예제 #4
0
def _serialize_polymorphic_function(function):
    """Represent a PolymorphicFunction in a SavedModel.

    Walks the signatures cached on `function`. Signatures containing any
    `defun_lib.UnknownArgument` are skipped. For each remaining signature the
    matching concrete function is fetched, `add_to_graph()` is called on it
    (adding it to the current graph), and its name is recorded.

    Args:
      function: A `PolymorphicFunction` to serialize.

    Returns:
      A `SavedPolymorphicFunction` protocol buffer message (not yet serialized
      to bytes) with one `SavedMonomorphicFunction` per kept signature.
    """
    saved_entries = []
    for input_signature in function._cached_input_signatures:  # pylint: disable=protected-access
        # Only signatures made entirely of known arguments can be turned into
        # a concrete function here.
        fully_known = all(
            not isinstance(arg, defun_lib.UnknownArgument)
            for arg in input_signature)
        if fully_known:
            concrete = function.get_concrete_function(*input_signature)
            concrete.add_to_graph()
            saved_entries.append(
                saved_object_graph_pb2.SavedMonomorphicFunction(
                    concrete_function=concrete.name))
    polymorphic_proto = saved_object_graph_pb2.SavedPolymorphicFunction(
        monomorphic_function=saved_entries)
    return polymorphic_proto