def testNestedLists(self):
  """Nested lists track checkpointable elements and block saves once dirtied.

  Checkpointables appended at any nesting depth become reachable from the
  root, plain values such as ints are not tracked, a `NoDependency` inner
  list does not prevent saving, and replacing an element of a tracked
  inner list makes the root unsaveable.
  """
  # `assertRaisesRegexp` was removed from unittest in Python 3.12; six gives
  # a spelling that works on both Python 2 and Python 3.
  import six
  a = tracking.Checkpointable()
  a.l = []
  b = tracking.Checkpointable()
  a.l.append([b])
  c = tracking.Checkpointable()
  a.l[0].append(c)
  a_deps = util.list_objects(a)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  # Mixing a non-checkpointable value into the inner list is allowed; it is
  # simply not reported as a dependency.
  a.l[0].append(1)
  d = tracking.Checkpointable()
  a.l[0].append(d)
  a_deps = util.list_objects(a)
  self.assertIn(d, a_deps)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  self.assertNotIn(1, a_deps)
  e = tracking.Checkpointable()
  f = tracking.Checkpointable()
  a.l1 = [[], [e]]
  a.l1[0].append(f)
  a_deps = util.list_objects(a)
  self.assertIn(e, a_deps)
  self.assertIn(f, a_deps)
  checkpoint = util.Checkpoint(a=a)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Mutating an untracked (NoDependency) inner list does not affect saving.
  a.l[0].append(data_structures.NoDependency([]))
  a.l[0][-1].append(5)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Dirtying the inner list means the root object is unsaveable.
  a.l[0][1] = 2
  with six.assertRaisesRegex(self, ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def testNestedLists(self):
  """Nested lists track checkpointable elements and block saves once dirtied.

  Checkpointables appended at any nesting depth become reachable from the
  root, plain values such as ints are not tracked, a `NoDependency` inner
  list does not prevent saving, and replacing an element of a tracked
  inner list makes the root unsaveable.
  """
  # `assertRaisesRegexp` was removed from unittest in Python 3.12; six gives
  # a spelling that works on both Python 2 and Python 3.
  import six
  a = tracking.Checkpointable()
  a.l = []
  b = tracking.Checkpointable()
  a.l.append([b])
  c = tracking.Checkpointable()
  a.l[0].append(c)
  a_deps = util.list_objects(a)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  # Mixing a non-checkpointable value into the inner list is allowed; it is
  # simply not reported as a dependency.
  a.l[0].append(1)
  d = tracking.Checkpointable()
  a.l[0].append(d)
  a_deps = util.list_objects(a)
  self.assertIn(d, a_deps)
  self.assertIn(b, a_deps)
  self.assertIn(c, a_deps)
  self.assertNotIn(1, a_deps)
  e = tracking.Checkpointable()
  f = tracking.Checkpointable()
  a.l1 = [[], [e]]
  a.l1[0].append(f)
  a_deps = util.list_objects(a)
  self.assertIn(e, a_deps)
  self.assertIn(f, a_deps)
  checkpoint = util.Checkpoint(a=a)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Mutating an untracked (NoDependency) inner list does not affect saving.
  a.l[0].append(data_structures.NoDependency([]))
  a.l[0][-1].append(5)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Dirtying the inner list means the root object is unsaveable.
  a.l[0][1] = 2
  with six.assertRaisesRegex(self, ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def get_non_optimizer_objects(m, g):
  """Gather the checkpointable objects of `m` that are not optimizer-related.

  Lists every object reachable from `m`. For each `optimizers.TFOptimizer`
  found among them, it removes the objects reachable from that wrapper and
  the variables of its wrapped optimizer.

  Args:
    m: The root checkpointable object to inspect.
    g: The graph made default while gathering. `optimizer.variables()`
      returns the optimizer variables defined in the default graph.

  Returns:
    A set of the objects reachable from `m`, minus any `TFOptimizer`
    wrappers, the objects reachable from them, and their optimizers'
    variables.
  """
  # Set default graph because optimizer.variables() returns optimizer
  # variables defined in the default graph.
  with g.as_default():
    all_objects = set(checkpointable_utils.list_objects(m))
    optimizer_and_variables = set()
    for obj in all_objects:
      if isinstance(obj, optimizers.TFOptimizer):
        optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
        # set.update accepts any iterable; no intermediate set is needed.
        optimizer_and_variables.update(obj.optimizer.variables())
    return all_objects - optimizer_and_variables
def testAddVariableOverwrite(self):
  """Re-adding a variable under an existing name requires overwrite=True.

  With `overwrite=True` the new variable replaces the old dependency. With
  `overwrite=False` a ValueError mentioning the existing dependency is
  raised.
  """
  # `assertRaisesRegexp` was removed from unittest in Python 3.12; six gives
  # a spelling that works on both Python 2 and Python 3.
  import six
  root = base.Checkpointable()
  a = root._add_variable_with_custom_getter(
      name="v", shape=[], getter=variable_scope.get_variable)
  self.assertEqual([root, a], util.list_objects(root))
  # A fresh graph lets get_variable create another "v" without a name clash.
  with ops.Graph().as_default():
    b = root._add_variable_with_custom_getter(
        name="v", shape=[], overwrite=True,
        getter=variable_scope.get_variable)
    self.assertEqual([root, b], util.list_objects(root))
  with ops.Graph().as_default():
    with six.assertRaisesRegex(
        self, ValueError, "already declared as a dependency"):
      root._add_variable_with_custom_getter(
          name="v", shape=[], overwrite=False,
          getter=variable_scope.get_variable)
def testAddVariableOverwrite(self):
  """Re-adding a variable under an existing name requires overwrite=True.

  With `overwrite=True` the new variable replaces the old dependency. With
  `overwrite=False` a ValueError mentioning the existing dependency is
  raised.
  """
  # `assertRaisesRegexp` was removed from unittest in Python 3.12; six gives
  # a spelling that works on both Python 2 and Python 3.
  import six
  root = base.CheckpointableBase()
  a = root._add_variable_with_custom_getter(
      name="v", shape=[], getter=variable_scope.get_variable)
  self.assertEqual([root, a], util.list_objects(root))
  # A fresh graph lets get_variable create another "v" without a name clash.
  with ops.Graph().as_default():
    b = root._add_variable_with_custom_getter(
        name="v", shape=[], overwrite=True,
        getter=variable_scope.get_variable)
    self.assertEqual([root, b], util.list_objects(root))
  with ops.Graph().as_default():
    with six.assertRaisesRegex(
        self, ValueError, "already declared as a dependency"):
      root._add_variable_with_custom_getter(
          name="v", shape=[], overwrite=False,
          getter=variable_scope.get_variable)
def testDictWrapperNoDependency(self):
  """A dict wrapped in NoDependency is untracked and does not block saving."""
  holder = tracking.Checkpointable()
  holder.d = data_structures.NoDependency({})
  holder.d[1] = [3]
  # The untracked dict adds nothing: only the holder itself is reachable.
  self.assertEqual([holder], util.list_objects(holder))
  outer_model = training.Model()
  outer_model.sub = holder
  weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  outer_model.save_weights(weights_prefix)
  outer_model.load_weights(weights_prefix)
def testNonStringKeyNotCheckpointableValue(self):
  """A non-string key is fine when its value is wrapped in NoDependency."""
  holder = tracking.Checkpointable()
  holder.d = {}
  holder.d["a"] = [3]
  holder.d[1] = data_structures.NoDependency([3])
  # The untracked value under key 1 is absent from the reachable objects.
  expected_objects = [holder, holder.d, holder.d["a"]]
  self.assertEqual(expected_objects, util.list_objects(holder))
  outer_model = training.Model()
  outer_model.sub = holder
  weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  outer_model.save_weights(weights_prefix)
  outer_model.load_weights(weights_prefix)
def test_checkpointable_save_restore(self):
  """Round-trip a template (with a nested manual scope) plus Adam state.

  Values saved from template "s1" are restored into a fresh template "s2"
  whose variables are created after `restore` is called.
  """

  def _templated():
    first_var = variable_scope.get_variable(
        "v", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    second_var = variable_scope.get_variable(
        "v2", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    scope = _ManualScope()
    return first_var, first_var + 1., second_var, scope, scope()

  # --- Build, populate and save the source template. ---
  saving_template = template.make_template("s1", _templated)
  (saved_v1, _, saved_v2,
   manual_scope, manual_scope_var) = saving_template()
  six.assertCountEqual(
      self,
      [saved_v1, saved_v2, manual_scope, manual_scope_var, saving_template],
      checkpointable_utils.list_objects(saving_template))
  (scope_dep,) = manual_scope._checkpoint_dependencies
  self.assertEqual("in_manual_scope", scope_dep.name)
  self.assertIs(manual_scope_var, scope_dep.ref)
  saving_optimizer = adam.AdamOptimizer(0.0)
  saving_root = checkpointable_utils.Checkpoint(
      my_template=saving_template, optimizer=saving_optimizer)
  saving_optimizer.minimize(saved_v1.read_value)
  self.evaluate([var.initializer for var in saving_template.variables])
  self.evaluate([var.initializer for var in saving_optimizer.variables()])
  self.evaluate(saved_v1.assign([12.]))
  self.evaluate(saved_v2.assign([14.]))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_path = saving_root.save(prefix)

  # --- Restore into a fresh template, creating variables after restore(). ---
  loading_template = template.make_template("s2", _templated)
  loading_optimizer = adam.AdamOptimizer(0.0)
  loading_root = checkpointable_utils.Checkpoint(
      my_template=loading_template, optimizer=loading_optimizer)
  restore_status = loading_root.restore(save_path)
  restored_v1, restored_v1_plus_one, restored_v2, _, _ = loading_template()
  loading_optimizer.minimize(restored_v1.read_value)
  template_deps = loading_template._checkpoint_dependencies
  self.assertEqual(3, len(template_deps))
  for position, expected_name in enumerate(("v", "v2", "ManualScope")):
    self.assertEqual(expected_name, template_deps[position].name)
  restore_status.assert_consumed().run_restore_ops()
  for expected_value, tensor in (([12.], restored_v1),
                                 ([13.], restored_v1_plus_one),
                                 ([14.], restored_v2)):
    self.assertAllEqual(expected_value, self.evaluate(tensor))
def testListBasic(self):
  """Checkpointables stored in a list attribute are tracked through it."""
  parent = tracking.Checkpointable()
  first_child = tracking.Checkpointable()
  parent.l = [first_child]
  second_child = tracking.Checkpointable()
  parent.l.append(second_child)
  reachable = util.list_objects(parent)
  for child in (first_child, second_child):
    self.assertIn(child, reachable)
  # The list itself is the parent's single direct dependency, named "l".
  (list_dependency,) = parent._checkpoint_dependencies
  self.assertEqual("l", list_dependency.name)
  for child in (first_child, second_child):
    self.assertIn(child, list_dependency.ref)
def testShallowCopyCheckpointable(self):
  """copy.copy gives a new object that shares sub-objects and tracks them."""
  source = tracking.Checkpointable()
  shared_child = tracking.Checkpointable()
  source.a = [[1.]]
  source.b = {"a": shared_child}
  clone = copy.copy(source)
  # A shallow copy shares the dict's values with the original.
  self.assertIs(shared_child, clone.b["a"])
  self.assertIsNot(source, clone)
  self.assertEqual([[1.]], clone.a)
  reachable = util.list_objects(clone)
  for tracked in (clone.a, clone.b, clone.b["a"]):
    self.assertIn(tracked, reachable)
def testListBasic(self):
  """Checkpointables stored in a list attribute are tracked through it."""
  parent = tracking.Checkpointable()
  first_child = tracking.Checkpointable()
  parent.l = [first_child]
  second_child = tracking.Checkpointable()
  parent.l.append(second_child)
  reachable = util.list_objects(parent)
  for child in (first_child, second_child):
    self.assertIn(child, reachable)
  # The list itself is the parent's single direct dependency, named "l".
  (list_dependency,) = parent._checkpoint_dependencies
  self.assertEqual("l", list_dependency.name)
  for child in (first_child, second_child):
    self.assertIn(child, list_dependency.ref)
def test_checkpointable_save_restore(self):
  """Round-trip a template (with a nested manual scope) plus Adam state.

  Values saved from template "s1" are restored into a fresh template "s2"
  whose variables are created after `restore` is called.
  """

  def _templated():
    first_var = variable_scope.get_variable(
        "v", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    second_var = variable_scope.get_variable(
        "v2", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    scope = _ManualScope()
    return first_var, first_var + 1., second_var, scope, scope()

  # --- Build, populate and save the source template. ---
  saving_template = template.make_template("s1", _templated)
  (saved_v1, _, saved_v2,
   manual_scope, manual_scope_var) = saving_template()
  six.assertCountEqual(
      self,
      [saved_v1, saved_v2, manual_scope, manual_scope_var, saving_template],
      checkpointable_utils.list_objects(saving_template))
  (scope_dep,) = manual_scope._checkpoint_dependencies
  self.assertEqual("in_manual_scope", scope_dep.name)
  self.assertIs(manual_scope_var, scope_dep.ref)
  saving_optimizer = adam.AdamOptimizer(0.0)
  saving_root = checkpointable_utils.Checkpoint(
      my_template=saving_template, optimizer=saving_optimizer)
  saving_optimizer.minimize(saved_v1.read_value)
  self.evaluate([var.initializer for var in saving_template.variables])
  self.evaluate([var.initializer for var in saving_optimizer.variables()])
  self.evaluate(saved_v1.assign([12.]))
  self.evaluate(saved_v2.assign([14.]))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_path = saving_root.save(prefix)

  # --- Restore into a fresh template, creating variables after restore(). ---
  loading_template = template.make_template("s2", _templated)
  loading_optimizer = adam.AdamOptimizer(0.0)
  loading_root = checkpointable_utils.Checkpoint(
      my_template=loading_template, optimizer=loading_optimizer)
  restore_status = loading_root.restore(save_path)
  restored_v1, restored_v1_plus_one, restored_v2, _, _ = loading_template()
  loading_optimizer.minimize(restored_v1.read_value)
  template_deps = loading_template._checkpoint_dependencies
  self.assertEqual(3, len(template_deps))
  for position, expected_name in enumerate(("v", "v2", "ManualScope")):
    self.assertEqual(expected_name, template_deps[position].name)
  restore_status.assert_consumed().run_restore_ops()
  for expected_value, tensor in (([12.], restored_v1),
                                 ([13.], restored_v1_plus_one),
                                 ([14.], restored_v2)):
    self.assertAllEqual(expected_value, self.evaluate(tensor))
def testDeepCopyCheckpointable(self):
  """copy.deepcopy gives independent sub-objects that are still tracked."""
  source = tracking.Checkpointable()
  source_child = tracking.Checkpointable()
  source.a = [[1.]]
  source.b = {"a": source_child}
  clone = copy.deepcopy(source)
  self.assertIsNot(source, clone)
  self.assertIsNot(source_child, clone.b["a"])
  self.assertEqual([[1.]], clone.a)
  self.assertIsInstance(clone.b["a"], tracking.Checkpointable)
  reachable = util.list_objects(clone)
  for tracked in (clone.a, clone.b, clone.b["a"]):
    self.assertIn(tracked, reachable)
  # The original child object is not reachable from the copy.
  self.assertNotIn(source_child, reachable)
def test_checkpointable_dependencies(self):
  """Every variable of a fitted SimpleRNN model is reachable for checkpoints."""
  layer_cls = keras.layers.SimpleRNN
  inputs = np.random.random((2, 2, 2))
  targets = np.random.random((2, 2))
  seq_model = keras.models.Sequential()
  seq_model.add(layer_cls(2))
  seq_model.compile(
      optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001), loss='mse')
  seq_model.fit(inputs, targets, epochs=1, batch_size=1)
  # Check that the model variables are present in the checkpointable list
  # of objects.
  tracked = set(checkpointable_util.list_objects(seq_model))
  for variable in seq_model.variables:
    self.assertIn(variable, tracked)
def test_checkpointable_dependencies(self):
  """Every variable of a fitted SimpleRNN model is reachable for checkpoints."""
  rnn = keras.layers.SimpleRNN
  # `test_session` is deprecated; `cached_session` is its documented
  # replacement.
  with self.cached_session():
    x = np.random.random((2, 2, 2))
    y = np.random.random((2, 2))
    model = keras.models.Sequential()
    model.add(rnn(2))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, epochs=1, batch_size=1)
    # check whether the model variables are present in the
    # checkpointable list of objects
    checkpointed_objects = set(checkpointable_util.list_objects(model))
    for v in model.variables:
      self.assertIn(v, checkpointed_objects)
def test_timedistributed_dense(self):
  """TimeDistributed(Dense) trains, reports its config, and is tracked."""
  seq_model = keras.models.Sequential()
  seq_model.add(keras.layers.TimeDistributed(keras.layers.Dense(2),
                                             input_shape=(3, 4)))
  seq_model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
  features = np.random.random((10, 3, 4))
  labels = np.random.random((10, 3, 2))
  seq_model.fit(features, labels, epochs=1, batch_size=10)
  # test config
  seq_model.get_config()
  tracked = set(checkpointable_util.list_objects(seq_model))
  for variable in seq_model.variables:
    self.assertIn(variable, tracked)
def test_timedistributed_dense(self):
  """TimeDistributed(Dense) trains, reports its config, and is tracked."""
  seq_model = keras.models.Sequential()
  seq_model.add(keras.layers.TimeDistributed(keras.layers.Dense(2),
                                             input_shape=(3, 4)))
  seq_model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
  features = np.random.random((10, 3, 4))
  labels = np.random.random((10, 3, 2))
  seq_model.fit(features, labels, epochs=1, batch_size=10)
  # test config
  seq_model.get_config()
  tracked = set(checkpointable_util.list_objects(seq_model))
  for variable in seq_model.variables:
    self.assertIn(variable, tracked)
def testNoDependency(self):
  """NoDependency and no_automatic_dependency_tracking suppress tracking."""
  root = tracking.Checkpointable()
  tracked_child = tracking.Checkpointable()
  root.hasdep = tracked_child
  untracked_child = tracking.Checkpointable()
  root.nodep = data_structures.NoDependency(untracked_child)
  # Only the plain attribute assignment registers a dependency.
  self.assertEqual(1, len(root._checkpoint_dependencies))
  self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
  self.assertIs(root.hasdep, tracked_child)
  # Reading the attribute back yields the raw object, not the wrapper.
  self.assertIs(root.nodep, untracked_child)

  class NoDependencyModel(training.Model):

    @base.no_automatic_dependency_tracking
    def __init__(self):
      super(NoDependencyModel, self).__init__()
      self.a = []
      self.b = tracking.Checkpointable()

  untracked_model = NoDependencyModel()
  # Attributes set in the decorated __init__ are not tracked.
  self.assertEqual([untracked_model], util.list_objects(untracked_model))
def testNonAppendNotCheckpointable(self):
  """Overwriting or deleting untracked dict values keeps the dict saveable."""
  # Non-append mutations (deleting or overwriting values) are OK when the
  # values aren't tracked.
  holder = tracking.Checkpointable()
  holder.d = {}
  holder.d["a"] = [3]
  holder.d[1] = 3
  holder.d[1] = 2
  self.assertEqual(2, holder.d[1])
  del holder.d[1]
  holder.d[2] = data_structures.NoDependency(tracking.Checkpointable())
  replacement = tracking.Checkpointable()
  holder.d[2] = data_structures.NoDependency(replacement)
  self.assertIs(replacement, holder.d[2])
  expected_objects = [holder, holder.d, holder.d["a"]]
  self.assertEqual(expected_objects, util.list_objects(holder))
  outer_model = training.Model()
  outer_model.sub = holder
  weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  outer_model.save_weights(weights_prefix)
  outer_model.load_weights(weights_prefix)
def testNoDependency(self):
  """NoDependency and no_automatic_dependency_tracking suppress tracking."""
  root = tracking.Checkpointable()
  tracked_child = tracking.Checkpointable()
  root.hasdep = tracked_child
  untracked_child = tracking.Checkpointable()
  root.nodep = data_structures.NoDependency(untracked_child)
  # Only the plain attribute assignment registers a dependency.
  self.assertEqual(1, len(root._checkpoint_dependencies))
  self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
  self.assertIs(root.hasdep, tracked_child)
  # Reading the attribute back yields the raw object, not the wrapper.
  self.assertIs(root.nodep, untracked_child)

  class NoDependencyModel(training.Model):

    @base.no_automatic_dependency_tracking
    def __init__(self):
      super(NoDependencyModel, self).__init__()
      self.a = []
      self.b = tracking.Checkpointable()

  untracked_model = NoDependencyModel()
  # Attributes set in the decorated __init__ are not tracked.
  self.assertEqual([untracked_model], util.list_objects(untracked_model))
def testDictionariesBasic(self):
  """Dict attributes track their values and expose nested Keras layers."""
  outer = training.Model()
  inner_b = training.Model()
  outer.attribute = {"b": inner_b}
  inner_c = training.Model()
  outer.attribute["c"] = []
  outer.attribute["c"].append(inner_c)
  reachable = util.list_objects(outer)
  self.assertIn(inner_b, reachable)
  self.assertIn(inner_c, reachable)
  self.assertIs(inner_b, outer.attribute["b"])
  # The dict's dependencies are named after its keys.
  dependency_names = [
      dep.name for dep in outer.attribute._checkpoint_dependencies]
  six.assertCountEqual(self, ["b", "c"], dependency_names)
  self.assertEqual([inner_b, inner_c], outer.layers)
  self.assertEqual([inner_b, inner_c], outer.attribute.layers)
  self.assertEqual([inner_c], outer.attribute["c"].layers)
  checkpoint = util.Checkpoint(a=outer)
  save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  checkpoint.restore(save_path).assert_consumed()
def testDictDeepCopy(self):
  """Deep-copying a tracked dict copies contents and inherits dirtiness."""
  root = tracking.Checkpointable()
  source_dict = {"a": [1.]}
  root.a = source_dict
  duplicate = copy.deepcopy(root.a)
  self.assertAllEqual([1.], duplicate["a"])
  self.assertIsNot(root.a, duplicate)
  self.assertIsNot(root.a["a"], duplicate["a"])
  # Dirtiness should be inherited: after an initial successful listing,
  # mutating the original dict object makes both the wrapper and any deep
  # copy of it fail to list.
  util.list_objects(root.a)
  source_dict["b"] = []
  for make_target in (lambda: root.a, lambda: copy.deepcopy(root.a)):
    with self.assertRaises(ValueError):
      util.list_objects(make_target())
def testListDeepCopy(self):
  """Deep-copying a tracked list copies contents and inherits dirtiness."""
  root = tracking.Checkpointable()
  source_list = [[1.]]
  root.a = source_list
  duplicate = copy.deepcopy(root.a)
  self.assertAllEqual([[1.]], duplicate)
  self.assertIsNot(root.a, duplicate)
  self.assertIsNot(root.a[0], duplicate[0])
  # Dirtiness should be inherited: after an initial successful listing,
  # mutating the original list object makes both the wrapper and any deep
  # copy of it fail to list.
  util.list_objects(root.a)
  source_list.append(1.)
  for make_target in (lambda: root.a, lambda: copy.deepcopy(root.a)):
    with self.assertRaises(ValueError):
      util.list_objects(make_target())
def _make_graph_def(root, signature_functions, object_saver):
  """Generates and exports call ops for `signature_functions`.

  Args:
    root: The checkpointable object whose reachable objects are exported.
    signature_functions: Passed, together with the exported graph's resource
      map, to `_generate_signatures`.
    object_saver: Object whose `freeze` method builds the saver used by the
      exported graph (presumably a CheckpointableSaver; verify at callers).

  Returns:
    A `(graph_def, signatures, saver_def)` tuple: the exported graph's
    GraphDef (with shapes), the signatures from `_generate_signatures`, and
    the saver's proto.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(root)
  exported_graph = ops.Graph()
  with exported_graph.as_default():
    object_map, resource_map = _map_resources(accessible_objects)
  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    saver_def = saver.to_proto()
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector.
  ops.dismantle_graph(exported_graph)
  return graph_def, signatures, saver_def
def _make_graph_def(root, signature_functions, object_saver):
  """Generates and exports call ops for `signature_functions`.

  Args:
    root: The checkpointable object whose reachable objects are exported.
    signature_functions: Passed, together with the exported graph's resource
      map, to `_generate_signatures`.
    object_saver: Object whose `freeze` method builds the saver used by the
      exported graph (presumably a CheckpointableSaver; verify at callers).

  Returns:
    A `(graph_def, signatures, saver_def)` tuple: the exported graph's
    GraphDef (with shapes), the signatures from `_generate_signatures`, and
    the saver's proto.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(root)
  exported_graph = ops.Graph()
  with exported_graph.as_default():
    object_map, resource_map = _map_resources(accessible_objects)
  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    saver_def = saver.to_proto()
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector.
  ops.dismantle_graph(exported_graph)
  return graph_def, signatures, saver_def
def test_bidirectional(self):
  """Bidirectional(SimpleRNN) works and is tracked for every merge mode."""
  cell_layer = keras.layers.SimpleRNN
  samples = 2
  dim = 2
  timesteps = 2
  output_dim = 2
  merge_modes = ('sum', 'concat', 'ave', 'mul')
  with self.cached_session():
    for merge_mode in merge_modes:
      inputs = np.random.random((samples, timesteps, dim))
      # 'concat' stacks both directions, doubling the output width.
      if merge_mode == 'concat':
        out_width = 2 * output_dim
      else:
        out_width = output_dim
      targets = np.random.random((samples, out_width))
      # test with Sequential model
      seq_model = keras.models.Sequential()
      wrapper = keras.layers.Bidirectional(
          cell_layer(output_dim), merge_mode=merge_mode,
          input_shape=(timesteps, dim))
      seq_model.add(wrapper)
      seq_model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
      seq_model.fit(inputs, targets, epochs=1, batch_size=1)
      # Check that the model variables are present in the checkpointable
      # list of objects.
      tracked = set(checkpointable_util.list_objects(seq_model))
      for variable in seq_model.variables:
        self.assertIn(variable, tracked)
      # test compute output shape
      last_layer = seq_model.layers[-1]
      expected_shape = last_layer.output.get_shape()
      computed_shape = last_layer.compute_output_shape(
          (None, timesteps, dim))
      self.assertListEqual(computed_shape.as_list(), expected_shape.as_list())
      # test config
      seq_model.get_config()
      seq_model = keras.models.model_from_json(seq_model.to_json())
      seq_model.summary()
def test_bidirectional(self):
  """Bidirectional(SimpleRNN) works and is tracked for every merge mode."""
  cell_layer = keras.layers.SimpleRNN
  samples = 2
  dim = 2
  timesteps = 2
  output_dim = 2
  merge_modes = ('sum', 'concat', 'ave', 'mul')
  with self.cached_session():
    for merge_mode in merge_modes:
      inputs = np.random.random((samples, timesteps, dim))
      # 'concat' stacks both directions, doubling the output width.
      if merge_mode == 'concat':
        out_width = 2 * output_dim
      else:
        out_width = output_dim
      targets = np.random.random((samples, out_width))
      # test with Sequential model
      seq_model = keras.models.Sequential()
      wrapper = keras.layers.Bidirectional(
          cell_layer(output_dim), merge_mode=merge_mode,
          input_shape=(timesteps, dim))
      seq_model.add(wrapper)
      seq_model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
      seq_model.fit(inputs, targets, epochs=1, batch_size=1)
      # Check that the model variables are present in the checkpointable
      # list of objects.
      tracked = set(checkpointable_util.list_objects(seq_model))
      for variable in seq_model.variables:
        self.assertIn(variable, tracked)
      # test compute output shape
      last_layer = seq_model.layers[-1]
      expected_shape = last_layer.output.get_shape()
      computed_shape = last_layer.compute_output_shape(
          (None, timesteps, dim))
      self.assertListEqual(computed_shape.as_list(), expected_shape.as_list())
      # test config
      seq_model.get_config()
      seq_model = keras.models.model_from_json(seq_model.to_json())
      seq_model.summary()
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    obj: The checkpointable object being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    object_saver: A CheckpointableSaver to add to the MetaGraph.

  Returns:
    An _AssetInfo, which contains information to help create the SavedModel.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(obj)
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = _map_resources(accessible_objects)
    for resource_initializer_function in resource_initializer_functions:
      # Each resource initializer runs only after the initializers of any
      # assets it captures.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    # A single no-op that depends on every resource initializer.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)

  # We must resolve the concrete function to add to MetaGraph while in eager
  # mode.
  concrete_functions = []
  for accessible_object in accessible_objects:
    for polymorphic_function in (
        function_serialization.list_all_polymorphic_functions(
            accessible_object).values()):
      concrete_functions.extend(
          function_serialization.list_all_concrete_functions(
              polymorphic_function))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    # add_to_graph() with no argument targets the default graph, which is
    # the exported graph here.
    for concrete_function in concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector.
  ops.dismantle_graph(exported_graph)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info
def _fill_meta_graph_def(meta_graph_def, obj, signature_functions,
                         object_saver):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    obj: The checkpointable object being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    object_saver: A CheckpointableSaver to add to the MetaGraph.

  Returns:
    asset_filename_map, a dictionary mapping from asset base names to
    user-specified full asset paths, which should be copied to the
    SavedModel's assets/ directory.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = util.list_objects(obj)
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = _map_resources(accessible_objects)
    for resource_initializer_function in resource_initializer_functions:
      # Each resource initializer runs only after the initializers of any
      # assets it captures.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    # A single no-op that depends on every resource initializer.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector.
  ops.dismantle_graph(exported_graph)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info.asset_filename_map