def test_soft_matching(self):
  """Checks that one signature-constrained function serves many input lengths.

  `func` is declared with an `input_signature` of a rank-1 int32 tensor of
  unknown length. The test checks that calls with vectors of different
  lengths work before and after `self.cycle` (presumably a save/load round
  trip), that only one concrete function exists before export, and that a
  scalar input is rejected by the imported function.
  """

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
  def func(x):
    return 2 * x

  root = tracking.Checkpointable()
  root.f = func

  self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
  self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())
  # Both calls above must have been served by a single concrete function.
  self.assertEqual(
      1, len(function_serialization.list_all_concrete_functions(root.f)))

  imported = self.cycle(root)

  # We cannot call the function with a constant of shape ().
  # Only the call itself goes inside assertRaises. Wrapping it in an
  # assertEqual against a wrong value (2 * 2 != 7) would raise
  # AssertionError even if the call succeeded, so the check could never fail.
  with self.assertRaises(AssertionError):
    imported.f(constant_op.constant(2))

  # TODO(vbardiovsky): When classes are revived with input_signatures, we
  # should also check that the calls below are not generating any more
  # concrete functions.
  self.assertAllEqual([2, 4, 6, 8],
                      imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
  self.assertAllEqual([2, 4, 6],
                      imported.f(constant_op.constant([1, 2, 3])).numpy())
def __init__(self, root):
  """Records the objects found from `root` and adds their functions as nodes.

  Args:
    root: The object handed to `util.find_objects` to discover nodes.
  """
  found_objects, ids_by_node, slots = util.find_objects(root)
  self.nodes = found_objects
  self.node_ids = ids_by_node
  self.slot_variables = slots
  self.polymorphic_functions = util.ObjectIdentityDictionary()

  # Also add polymorphic functions as nodes. `self.nodes` grows while it is
  # being walked, so functions appended here are themselves visited later.
  # The explicit index makes that worklist behaviour visible.
  cursor = 0
  while cursor < len(self.nodes):
    node = self.nodes[cursor]
    cursor += 1
    node_functions = self._list_polymorphic_functions(node)
    self.polymorphic_functions[node] = node_functions
    for fn in node_functions.values():
      if fn in self.node_ids:
        continue
      self.node_ids[fn] = len(self.nodes)
      self.nodes.append(fn)
      # Force listing the concrete functions for the side effects:
      # - populate the cache for polymorphic functions that have an
      #   input_signature and have not been called.
      # - force side effects of creation of concrete functions, e.g. create
      #   variables on first run.
      function_serialization.list_all_concrete_functions(fn)
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
  """Generates a MetaGraph which calls `signature_functions`.

  Builds a fresh graph containing:
    * the resource initializers, plus an init op registered both in the
      main_op collection and as the init_op signature;
    * the save/restore ops of the frozen object-based saver;
    * the concrete functions of every polymorphic function on the objects
      listed by `util.list_objects(obj)`;
    * the requested signatures.
  The graph is then serialized into `meta_graph_def`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    obj: The checkpointable object being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    object_saver: A CheckpointableSaver to add to the MetaGraph.

  Returns:
    An _AssetInfo, which contains information to help creating the SavedModel.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(obj)
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  try:
    resource_initializer_ops = []
    with exported_graph.as_default():
      object_map, resource_map, asset_info = _map_resources(
          accessible_objects)
      for resource_initializer_function in resource_initializer_functions:
        # Make each resource initializer depend on the initializers of any
        # assets it captures.
        asset_dependencies = []
        for capture in resource_initializer_function.graph.external_captures:
          asset_initializer = asset_info.asset_initializers_by_resource.get(
              capture, None)
          if asset_initializer is not None:
            asset_dependencies.append(asset_initializer)
        with ops.control_dependencies(asset_dependencies):
          resource_initializer_ops.append(
              _call_function_with_mapped_captures(
                  resource_initializer_function, [], resource_map))
      with ops.control_dependencies(resource_initializer_ops):
        init_op = control_flow_ops.no_op()
      # Add the same op to the main_op collection and to the init_op
      # signature. The collection is for compatibility with older loader
      # APIs; only one will be executed.
      meta_graph_def.collection_def[
          constants.MAIN_OP_KEY].node_list.value.append(init_op.name)
      meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
          signature_def_utils.op_signature_def(
              init_op, constants.INIT_OP_SIGNATURE_KEY))

    # Saving an object-based checkpoint again gathers variables. We need to do
    # the gathering from the eager context so Optimizers save the right set of
    # variables, but want any operations associated with the save/restore to
    # be in the exported graph (thus the `to_graph` argument).
    saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)

    # We must resolve the concrete function to add to MetaGraph while in eager
    # mode.
    concrete_functions = []
    for accessible_object in accessible_objects:
      for function in function_serialization.list_all_polymorphic_functions(
          accessible_object).values():
        concrete_functions.extend(
            function_serialization.list_all_concrete_functions(function))

    with exported_graph.as_default():
      signatures = _generate_signatures(signature_functions, resource_map)
      for concrete_function in concrete_functions:
        concrete_function.add_to_graph()
      saver_def = saver.to_proto()
      meta_graph_def.saver_def.CopyFrom(saver_def)
    graph_def = exported_graph.as_graph_def(add_shapes=True)
  finally:
    # Clean reference cycles so repeated export()s don't make work for the
    # garbage collector. Done in `finally` so a failed export is cleaned up
    # too, not only a successful one.
    ops.dismantle_graph(exported_graph)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)

  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info