def test_to_proto(self):
  """Saves and restores through the tensors/op named by `saver.to_proto()`.

  The proto is produced while tracing a wrapped function. The wrapped graph
  is then pruned into one callable that saves and one that restores, both
  fed through the filename tensor named in the proto. The checkpoint that
  was written is finally restored into a second variable by a fresh saver.
  """
  first_var = resource_variable_ops.ResourceVariable(2.)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(first_var, "x"))
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  # The lambda's return value is discarded by the wrapper's caller, so the
  # proto is smuggled out through this list instead.
  captured_protos = []
  wrapped_fn = wrap_function.wrap_function(
      lambda: captured_protos.append(saver.to_proto()), signature=())
  self.assertEqual(1, len(captured_protos))
  saver_proto = captured_protos[0]

  # Resolve the names recorded in the proto against the wrapped graph.
  graph = wrapped_fn.graph
  filename_tensor = graph.get_tensor_by_name(saver_proto.filename_tensor_name)
  save_tensor = graph.get_tensor_by_name(saver_proto.save_tensor_name)
  restore_op = graph.get_operation_by_name(saver_proto.restore_op_name)

  save_fn = wrapped_fn.prune(feeds=filename_tensor, fetches=save_tensor)
  restore_fn = wrapped_fn.prune(feeds=filename_tensor, fetches=restore_op)

  # Save the value 2., clobber it, and check that restoring brings it back.
  written_path = save_fn(constant_op.constant(ckpt_prefix))
  first_var.assign(1.)
  restore_fn(constant_op.constant(written_path))
  self.assertEqual(2., self.evaluate(first_var))

  # A different variable saved under the same key "x" picks up the value too.
  second_var = resource_variable_ops.ResourceVariable(3.)
  fresh_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(second_var, "x"))
  fresh_saver.restore(written_path)
  self.assertEqual(2., self.evaluate(second_var))
def test_resource_variable_use_localhost(self):
  """Round-trips one variable through a checkpoint with `self.local_options`.

  Besides checking the restored values, in graph mode every SaveV2 and
  RestoreV2 op in the default graph must report LOCALHOST as its device.
  """
  source_var = resource_variable_ops.ResourceVariable(2.)
  self.evaluate(source_var.initializer)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(source_var, "x"))
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  save_result = saver.save(
      constant_op.constant(ckpt_prefix), self.local_options)
  self.evaluate(save_result)
  written_files = gfile.Glob(ckpt_prefix + "*")
  self.assertEqual(2, len(written_files))

  # Overwrite the variable so the restore has something to undo.
  self.evaluate(source_var.assign(1.))
  self.evaluate(saver.restore(ckpt_prefix, self.local_options))
  self.assertEqual(2., self.evaluate(source_var))

  # A second variable registered under the same key reads the saved value.
  target_var = resource_variable_ops.ResourceVariable(3.)
  self.evaluate(target_var.initializer)
  target_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(target_var, "x"))
  self.evaluate(target_saver.restore(ckpt_prefix, self.local_options))
  self.assertEqual(2., self.evaluate(target_var))

  # The device check below only applies in graph mode, where the save and
  # restore ops exist as nodes of the default graph.
  if context.executing_eagerly():
    return
  checkpoint_io_ops = [
      graph_op for graph_op in ops.get_default_graph().get_operations()
      if graph_op.type in ("SaveV2", "RestoreV2")
  ]
  for graph_op in checkpoint_io_ops:
    self.assertEqual(LOCALHOST, graph_op.device)
def test_checkpoint_multi_device_using_localhost(self):
  """Saves/restores variables on three CPU devices with `self.local_options`.

  In graph mode the SaveV2, RestoreV2 and MergeV2Checkpoints ops in the
  default graph must all report LOCALHOST as their device.
  """
  # (device, initial value) for each variable, in creation order.
  placements = (("cpu:0", 0.), ("cpu:1", 1.), ("cpu:2", 2.))
  variables = []
  for device_name, initial_value in placements:
    with ops.device(device_name):
      variables.append(resource_variable_ops.ResourceVariable(initial_value))
  self.evaluate([variable.initializer for variable in variables])

  # Collect the saveables of every variable under keys "v0", "v1", "v2".
  saveables = []
  for checkpoint_key, variable in zip(("v0", "v1", "v2"), variables):
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, checkpoint_key))
  saver = functional_saver.MultiDeviceSaver(saveables)

  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(
      saver.save(constant_op.constant(ckpt_prefix), self.local_options))
  written_files = gfile.Glob(ckpt_prefix + "*")
  self.assertEqual(2, len(written_files))

  # Clobber every variable, then restore and expect the initial values back.
  for variable in variables:
    self.evaluate(variable.assign(-1.))
  self.evaluate(
      saver.restore(constant_op.constant(ckpt_prefix), self.local_options))
  for (_, expected_value), variable in zip(placements, variables):
    self.assertEqual(expected_value, self.evaluate(variable))

  # The device check below only applies in graph mode, where the checkpoint
  # ops exist as nodes of the default graph.
  if context.executing_eagerly():
    return
  checkpoint_io_ops = [
      graph_op for graph_op in ops.get_default_graph().get_operations()
      if graph_op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints")
  ]
  for graph_op in checkpoint_io_ops:
    self.assertEqual(LOCALHOST, graph_op.device)
def test_checkpoint_is_sharded_by_task(self):
  """Saves variables placed on three worker tasks and restores them.

  Three local servers form a "worker" job; one variable lives on each task.
  After saving without options, four files are expected to match the
  checkpoint prefix (presumably one shard per task plus an index file --
  TODO confirm against the saver implementation).
  """
  local_servers = [server_lib.Server.create_local_server() for _ in range(3)]
  # The cluster spec wants bare host:port strings, so drop the scheme prefix.
  scheme_length = len("grpc://")
  worker_addresses = [
      server.target[scheme_length:] for server in local_servers
  ]
  cluster_spec = server_lib.ClusterSpec({"worker": worker_addresses})
  remote.connect_to_cluster(cluster_spec)

  # (device, initial value) for each variable, in creation order.
  placements = (
      ("/job:worker/task:0/cpu:0", 0.),
      ("/job:worker/task:1/cpu:0", 1.),
      ("/job:worker/task:2/cpu:0", 2.),
  )
  variables = []
  for device_name, initial_value in placements:
    with ops.device(device_name):
      variables.append(resource_variable_ops.ResourceVariable(initial_value))
  self.evaluate([variable.initializer for variable in variables])

  # Collect the saveables of every variable under keys "v0", "v1", "v2".
  saveables = []
  for checkpoint_key, variable in zip(("v0", "v1", "v2"), variables):
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, checkpoint_key))
  saver = functional_saver.MultiDeviceSaver(saveables)

  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(ckpt_prefix)))
  written_files = gfile.Glob(ckpt_prefix + "*")
  self.assertEqual(4, len(written_files))

  # Clobber every variable, then restore and expect the initial values back.
  for variable in variables:
    self.evaluate(variable.assign(-1.))
  self.evaluate(saver.restore(constant_op.constant(ckpt_prefix)))
  for (_, expected_value), variable in zip(placements, variables):
    self.assertEqual(expected_value, self.evaluate(variable))
def test_resource_variable(self):
  """Round-trips a single resource variable through a checkpoint.

  The saved value must come back both into the original variable (after it
  has been overwritten) and into a second variable saved under the same key.
  """
  source_var = resource_variable_ops.ResourceVariable(2.)
  self.evaluate(source_var.initializer)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(source_var, "x"))
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  save_result = saver.save(constant_op.constant(ckpt_prefix))
  self.evaluate(save_result)
  written_files = gfile.Glob(ckpt_prefix + "*")
  self.assertEqual(2, len(written_files))

  # Overwrite the variable so the restore has something to undo.
  self.evaluate(source_var.assign(1.))
  self.evaluate(saver.restore(ckpt_prefix))
  self.assertEqual(2., self.evaluate(source_var))

  # A second variable registered under the same key reads the saved value.
  target_var = resource_variable_ops.ResourceVariable(3.)
  self.evaluate(target_var.initializer)
  target_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(target_var, "x"))
  self.evaluate(target_saver.restore(ckpt_prefix))
  self.assertEqual(2., self.evaluate(target_var))
def test_callbacks_run(self):
  """Checks that save and restore each trigger the matching hook callback.

  After a save only `before_save` must have been counted; after the
  following restore `after_restore` must have been counted once as well.
  """
  # The counters live in a dict so the hook methods can update them in
  # place; assigning to a plain int inside a method would shadow it instead.
  invocations = {"save": 0, "restore": 0}

  class CountingHook(saveable_hook.SaveableHook):
    """Hook that records how often each callback is invoked."""

    def before_save(self):
      invocations["save"] += 1

    def after_restore(self):
      invocations["restore"] += 1

  saver = functional_saver.MultiDeviceSaver([CountingHook(name="dummy")])
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

  self.evaluate(saver.save(constant_op.constant(ckpt_prefix)))
  self.assertEqual({"save": 1, "restore": 0}, invocations)

  self.evaluate(saver.restore(ckpt_prefix))
  self.assertEqual({"save": 1, "restore": 1}, invocations)