Exemplo n.º 1
0
def serialize_concrete_function(concrete_function, node_ids, coder):
    """Build a SavedConcreteFunction.

    Args:
      concrete_function: The function being serialized. Its `captured_inputs`,
        `name`, `structured_input_signature` and `structured_outputs`
        attributes are read.
      node_ids: Mapping-like object indexed by each captured input; the value
        found for each capture is recorded as a bound input.
      coder: Object whose `encode_structure` method encodes the input and
        output signatures into protos.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` holding the encoded
      input/output signatures and the bound-input ids in capture order.

    Raises:
      KeyError: If a captured input has no entry in `node_ids`. The failed
        lookup is attached as `__cause__`.
    """
    bound_inputs = []
    for capture in concrete_function.captured_inputs:
        # Guard only the lookup. A KeyError raised while producing
        # `captured_inputs` itself must not be reported as an unreachable
        # capture, and `capture` is guaranteed to be bound here.
        try:
            node_id = node_ids[capture]
        except KeyError as err:
            raise KeyError(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root."
                % (concrete_function.name, capture)) from err
        bound_inputs.append(node_id)
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    # Outputs are first converted to a signature structure so they can be
    # encoded by the same coder as the input signature.
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        coder.encode_structure(concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
Exemplo n.º 2
0
def serialize_concrete_function(concrete_function, node_ids, coder):
    """Build a SavedConcreteFunction.

    Args:
      concrete_function: The function being serialized. Its `captured_inputs`,
        `name`, `structured_input_signature` and `structured_outputs`
        attributes are read.
      node_ids: Mapping-like object indexed by each captured input; the value
        found for each capture is recorded as a bound input.
      coder: Object whose `encode_structure` method encodes the input and
        output signatures into protos.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` holding the encoded
      input/output signatures and the bound-input ids in capture order.

    Raises:
      KeyError: If a captured input has no entry in `node_ids`. The message
        suggests the likely cause, and the failed lookup is attached as
        `__cause__`.
    """
    bound_inputs = []
    for capture in concrete_function.captured_inputs:
        # Guard only the lookup so unrelated KeyErrors (e.g. from producing
        # `captured_inputs`) propagate unchanged and `capture` is always bound.
        try:
            node_id = node_ids[capture]
        except KeyError as err:
            raise KeyError(
                "Failed to add concrete function %s to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root. "
                "One reason could be that a stateful object or a variable that the "
                "function depends on is not assigned to an attribute of the serialized "
                "trackable object "
                "(see SaveTest.test_captures_unreachable_variable)." %
                (concrete_function.name, capture)) from err
        bound_inputs.append(node_id)
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    # Outputs are first converted to a signature structure so they can be
    # encoded by the same coder as the input signature.
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        coder.encode_structure(concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
Exemplo n.º 3
0
def serialize_concrete_function(concrete_function, node_ids):
    """Build a SavedConcreteFunction.

    Args:
      concrete_function: The function being serialized. Its `captured_inputs`,
        `name`, `structured_input_signature` and `structured_outputs`
        attributes are read.
      node_ids: Mapping-like object indexed by each captured input; the value
        found for each capture is recorded as a bound input.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` with its name, encoded
      input/output signatures and bound-input ids (in capture order), or None
      if some captured input has no entry in `node_ids`. In that case a
      warning is logged and the function is skipped.
    """
    bound_inputs = []
    for capture in concrete_function.captured_inputs:
        # Guard only the lookup so a KeyError from producing `captured_inputs`
        # is not misreported, and `capture` is always bound for the warning.
        try:
            node_id = node_ids[capture]
        except KeyError:
            # TODO(andresp): Would it be better to throw an exception?
            logging.warning(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root.",
                concrete_function.name, capture)
            return None
        bound_inputs.append(node_id)
    # This variant builds its own coder rather than receiving one.
    coder = nested_structure_coder.StructureCoder()
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    concrete_function_proto.name = concrete_function.name
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        coder.encode_structure(concrete_function.structured_input_signature))
    # Outputs are converted to a signature structure before encoding.
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.output_signature.CopyFrom(
        coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto
Exemplo n.º 4
0
def serialize_concrete_function(concrete_function, node_ids, coder):
    """Build a SavedConcreteFunction.

    Args:
      concrete_function: The function being serialized. Its `captured_inputs`,
        `name`, `structured_input_signature` and `structured_outputs`
        attributes are read.
      node_ids: Mapping-like object indexed by each captured input; the value
        found for each capture is recorded as a bound input.
      coder: Object whose `encode_structure` method encodes the input and
        output signatures into protos.

    Returns:
      A `saved_object_graph_pb2.SavedConcreteFunction` with the encoded
      input/output signatures and bound-input ids (in capture order), or None
      if some captured input has no entry in `node_ids`. In that case a
      warning is logged and the function is skipped.
    """
    bound_inputs = []
    for capture in concrete_function.captured_inputs:
        # Guard only the lookup so a KeyError from producing `captured_inputs`
        # is not misreported, and `capture` is always bound for the warning.
        try:
            node_id = node_ids[capture]
        except KeyError:
            # TODO(allenl): This warning shadows a real issue in test_table in
            # save_test.py, where we don't handle captured constants. Fix that and
            # then make this an exception.
            logging.warning(
                "Concrete function %s not added to object based saved model as it "
                "captures tensor %s which is unsupported or not reachable from root.",
                concrete_function.name, capture)
            return None
        bound_inputs.append(node_id)
    concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
    # Outputs are converted to a signature structure so the same coder can
    # encode them alongside the input signature.
    structured_outputs = func_graph_module.convert_structure_to_signature(
        concrete_function.structured_outputs)
    concrete_function_proto.canonicalized_input_signature.CopyFrom(
        coder.encode_structure(concrete_function.structured_input_signature))
    concrete_function_proto.output_signature.CopyFrom(
        coder.encode_structure(structured_outputs))
    concrete_function_proto.bound_inputs.extend(bound_inputs)
    return concrete_function_proto