def test_to_proto(self):
  """Round-trips a MultiDeviceSaver through its SaverDef proto."""
  original = resource_variable_ops.ResourceVariable(2.)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(original, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")

  # Trace to_proto() inside a wrapped graph so the SaverDef's tensor/op
  # names refer to that graph.
  protos = []
  wrapped = wrap_function.wrap_function(
      lambda: protos.append(saver.to_proto()), signature=())
  self.assertEqual(1, len(protos))
  saver_def = protos[0]

  # Build callable save/restore functions by pruning the wrapped graph down
  # to the endpoints named in the SaverDef.
  filename_tensor = wrapped.graph.get_tensor_by_name(
      saver_def.filename_tensor_name)
  save_fn = wrapped.prune(
      feeds=filename_tensor,
      fetches=wrapped.graph.get_tensor_by_name(saver_def.save_tensor_name))
  restore_fn = wrapped.prune(
      feeds=filename_tensor,
      fetches=wrapped.graph.get_operation_by_name(saver_def.restore_op_name))

  save_path = save_fn(constant_op.constant(prefix))
  original.assign(1.)
  restore_fn(constant_op.constant(save_path))
  self.assertEqual(2., self.evaluate(original))

  # A fresh saver over a different variable registered under the same
  # checkpoint key "x" can restore from the same file.
  replacement = resource_variable_ops.ResourceVariable(3.)
  second_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(replacement, "x"))
  second_saver.restore(save_path)
  self.assertEqual(2., self.evaluate(replacement))
def test_callbacks_run(self):
  """SaveableHook callbacks fire once per save and once per restore."""
  # A dict rather than bare ints: the hook methods mutate the shared
  # counters through the closure, which plain ints would shadow.
  counters = {
      "save": 0,
      "restore": 0,
  }

  class CountingHook(saveable_hook.SaveableHook):

    def before_save(self):
      counters["save"] += 1

    def after_restore(self):
      counters["restore"] += 1

  saver = functional_saver.MultiDeviceSaver([CountingHook(name="dummy")])
  prefix = os.path.join(self.get_temp_dir(), "ckpt")

  self.evaluate(saver.save(constant_op.constant(prefix)))
  self.assertEqual({"save": 1, "restore": 0}, counters)
  self.evaluate(saver.restore(prefix))
  self.assertEqual({"save": 1, "restore": 1}, counters)
def test_checkpoint_multi_device_using_localhost(self):
  """Saving with localhost IO options pins checkpoint ops to localhost."""
  with ops.device("cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("cpu:1"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("cpu:2"):
    v2 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate([v0.initializer, v1.initializer, v2.initializer])

  saveables = []
  for name, variable in (("v0", v0), ("v1", v1), ("v2", v2)):
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, name))
  saver = functional_saver.MultiDeviceSaver(saveables)

  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
  # Exactly two files are expected on disk for this checkpoint.
  self.assertEqual(2, len(gfile.Glob(prefix + "*")))

  # Clobber every variable, then restore and verify the original values.
  for variable in (v0, v1, v2):
    self.evaluate(variable.assign(-1.))
  self.evaluate(
      saver.restore(constant_op.constant(prefix), self.local_options))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))

  # In graph mode, verify that the save and restore ops were set to run on
  # localhost.
  if not context.executing_eagerly():
    for op in ops.get_default_graph().get_operations():
      if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
        self.assertEqual(LOCALHOST, op.device)
def test_checkpoint_is_sharded_by_task(self):
  """Checkpoints written across a cluster are sharded per worker task."""
  servers = [server_lib.Server.create_local_server() for _ in range(3)]
  cluster_spec = server_lib.ClusterSpec(
      {"worker": [s.target[len("grpc://"):] for s in servers]})
  remote.connect_to_cluster(cluster_spec)

  with ops.device("/job:worker/task:0/cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("/job:worker/task:1/cpu:0"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("/job:worker/task:2/cpu:0"):
    v2 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate([v0.initializer, v1.initializer, v2.initializer])

  saveables = []
  for name, variable in (("v0", v0), ("v1", v1), ("v2", v2)):
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, name))
  saver = functional_saver.MultiDeviceSaver(saveables)

  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  # Exactly four files are expected on disk for this sharded checkpoint.
  self.assertEqual(4, len(gfile.Glob(prefix + "*")))

  # Clobber every variable, then restore and verify the original values.
  for variable in (v0, v1, v2):
    self.evaluate(variable.assign(-1.))
  self.evaluate(saver.restore(constant_op.constant(prefix)))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))
def test_checkpoint_is_sharded_by_device(self):
  """Checkpoints for variables on different devices are sharded per device."""
  with ops.device("cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("cpu:1"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("cpu:2"):
    v2 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate([v0.initializer, v1.initializer, v2.initializer])

  saveables = []
  for name, variable in (("v0", v0), ("v1", v1), ("v2", v2)):
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, name))
  saver = functional_saver.MultiDeviceSaver(saveables)

  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  # Exactly four files are expected on disk for this sharded checkpoint.
  self.assertEqual(4, len(gfile.Glob(prefix + "*")))

  # Clobber every variable, then restore and verify the original values.
  for variable in (v0, v1, v2):
    self.evaluate(variable.assign(-1.))
  self.evaluate(saver.restore(constant_op.constant(prefix)))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
                         namespace_whitelist):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      asset_dependencies = []
      # Collect the asset initializers for every asset this initializer
      # function captures, so the asset paths are fed before it runs.
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      # Run the initializer only after its asset dependencies, with its
      # captured resources remapped into the exported graph.
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # Single no-op gated on every initializer, used as the graph's init op.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = functional_saver.MultiDeviceSaver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph))

  # Re-enter the exported graph to add signatures, traced concrete functions,
  # and the SaverDef before serializing it.
  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Reject the export if the graph uses ops outside the whitelisted
  # namespaces.
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph