def serialize_polymorphic_function(polymorphic_function, node_ids):
  """Builds a `SavedPolymorphicFunction` proto for `polymorphic_function`.

  The function spec is encoded into `function_spec_tuple`. Each
  `(signature, concrete_function)` pair produced by
  `list_all_concrete_functions` then gets one `monomorphic_function` entry.
  The entry holds the concrete function's name, its encoded input signature,
  and the ids of its captured inputs.

  Args:
    polymorphic_function: The polymorphic function to describe; must expose
      `function_spec.as_tuple()`.
    node_ids: Mapping from each captured input to the id stored in
      `bound_inputs` (presumably the object-graph node id -- confirm with
      the caller).

  Returns:
    The populated `saved_object_graph_pb2.SavedPolymorphicFunction`. A
    concrete function that captures an input missing from `node_ids` is
    left out, and a warning is logged instead.
  """
  structure_coder = nested_structure_coder.StructureCoder()
  saved_function = saved_object_graph_pb2.SavedPolymorphicFunction()
  spec_tuple = polymorphic_function.function_spec.as_tuple()
  saved_function.function_spec_tuple.CopyFrom(
      structure_coder.encode_structure(spec_tuple))

  all_concrete = list_all_concrete_functions(polymorphic_function)
  for signature, concrete_function in all_concrete:
    capture_ids = []
    try:
      for captured in concrete_function.captured_inputs:
        capture_ids.append(node_ids[captured])
    except KeyError:
      # TODO(andresp): Would it better to throw an exception?
      logging.warning(
          "Concrete function %s not added to object based saved model as it "
          "captures tensor %s which is unsupported or not reachable from root.",
          concrete_function.name, captured)
    else:
      # Every capture resolved to an id, so record this concrete function.
      encoded_signature = structure_coder.encode_structure(signature)
      entry = saved_function.monomorphic_function.add()
      entry.concrete_function = concrete_function.name
      entry.canonicalized_input.CopyFrom(encoded_signature)
      entry.bound_inputs.extend(capture_ids)
  return saved_function
def _serialize_polymorphic_function(polymorphic_function):
  """Returns a `SavedPolymorphicFunction` naming each concrete function.

  Args:
    polymorphic_function: Function whose concrete functions are enumerated
      with `list_all_concrete_functions`.

  Returns:
    A `saved_object_graph_pb2.SavedPolymorphicFunction` with one
    `SavedMonomorphicFunction` per concrete function, in enumeration order.
    Each entry records only the concrete function's name.
  """
  entries = [
      saved_object_graph_pb2.SavedMonomorphicFunction(
          concrete_function=concrete.name)
      for concrete in list_all_concrete_functions(polymorphic_function)
  ]
  return saved_object_graph_pb2.SavedPolymorphicFunction(
      monomorphic_function=entries)
def _serialize_polymorphic_function(polymorphic_function, node_ids):
  """Builds a `SavedPolymorphicFunction` proto for `polymorphic_function`.

  Each concrete function from `list_all_concrete_functions` becomes one
  `monomorphic_function` entry. The entry carries the concrete function's
  name and the ids of its captured inputs.

  Args:
    polymorphic_function: Function whose concrete functions are enumerated
      with `list_all_concrete_functions`.
    node_ids: Mapping from each captured input to the id stored in
      `bound_inputs` (presumably the object-graph node id -- confirm with
      the caller).

  Returns:
    The populated `saved_object_graph_pb2.SavedPolymorphicFunction`. A
    concrete function that captures an input missing from `node_ids` is
    left out, and a warning is logged instead.
  """
  saved_function = saved_object_graph_pb2.SavedPolymorphicFunction()
  for concrete_function in list_all_concrete_functions(polymorphic_function):
    capture_ids = []
    try:
      for captured in concrete_function.captured_inputs:
        capture_ids.append(node_ids[captured])
    except KeyError:
      # TODO(andresp): Would it better to throw an exception?
      logging.warning(
          "Concrete function %s not added to object based saved model as it "
          "captures tensor %s which is unsupported or not reachable from root.",
          concrete_function.name, captured)
    else:
      # Every capture resolved to an id, so record this concrete function.
      entry = saved_function.monomorphic_function.add()
      entry.concrete_function = concrete_function.name
      entry.bound_inputs.extend(capture_ids)
  return saved_function
def _serialize_polymorphic_function(function):
  """Represents a PolymorphicFunction in a SavedModel.

  For every cached input signature of `function` whose arguments are all
  known, the matching concrete function is fetched and added to the current
  graph, then recorded by name. Signatures that contain a
  `defun_lib.UnknownArgument` are skipped.

  Args:
    function: A `PolymorphicFunction` to serialize.

  Returns:
    An unserialized `SavedPolymorphicFunction` protocol buffer object.
  """

  def _all_arguments_known(signature):
    # True when no argument in `signature` is an `UnknownArgument`.
    return all(
        not isinstance(arg, defun_lib.UnknownArgument) for arg in signature)

  entries = []
  # `filter` is lazy: each signature is checked just before its concrete
  # function is fetched, the same interleaving as a loop with `continue`.
  known_signatures = filter(
      _all_arguments_known,
      function._cached_input_signatures)  # pylint: disable=protected-access
  for signature in known_signatures:
    concrete = function.get_concrete_function(*signature)
    concrete.add_to_graph()
    entries.append(
        saved_object_graph_pb2.SavedMonomorphicFunction(
            concrete_function=concrete.name))
  return saved_object_graph_pb2.SavedPolymorphicFunction(
      monomorphic_function=entries)