def test_checkpoint_comparison(self):
    """Checks SaveableState/TrackableState checkpoint equivalence and that the
    SaveableCompatibilityConverter can read the same checkpoint."""
    saveable_state = SaveableState(5.)
    trackable_state = TrackableState(10.)
    # First test that SaveableState and TrackableState are equivalent by
    # saving a checkpoint with both objects and swapping values.
    self.assertEqual(5, self.evaluate(saveable_state.read()))
    self.assertEqual(10, self.evaluate(trackable_state.read()))
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(a=saveable_state, b=trackable_state).write(ckpt_path)
    # Read back with the names crossed: each object must accept the value the
    # other one wrote, proving the two serialization paths are interchangeable.
    status = checkpoint.Checkpoint(
        b=saveable_state, a=trackable_state).read(ckpt_path)
    status.assert_consumed()
    self.assertEqual(10, self.evaluate(saveable_state.read()))
    self.assertEqual(5, self.evaluate(trackable_state.read()))
    # Test that the converted SaveableState is compatible with the checkpoint
    # saved above.
    to_convert = SaveableState(0.0)
    converted_saveable_state = (
        saveable_object_util.SaveableCompatibilityConverter(to_convert))
    # Under name "a" the converter should pick up the value saved for "a"...
    checkpoint.Checkpoint(a=converted_saveable_state).read(
        ckpt_path).assert_existing_objects_matched().expect_partial()
    self.assertEqual(5, self.evaluate(to_convert.read()))
    # ...and under "b" the value saved for "b".
    checkpoint.Checkpoint(b=converted_saveable_state).read(
        ckpt_path).assert_existing_objects_matched().expect_partial()
    self.assertEqual(10, self.evaluate(to_convert.read()))
def testRestoreOrInitialize(self): directory = self.get_temp_dir() # Create a checkpoint for initializing. init_prefix = os.path.join(directory, "init") init_v = variables.Variable(2.0) init_ckpt = util.Checkpoint(v=init_v) self.evaluate(init_v.initializer) init_path = init_ckpt.save(init_prefix) # Create the checkpoint manager. ckpt_dir = os.path.join(directory, "ckpt") v = variables.Variable(1.0) checkpoint = util.Checkpoint(v=v) manager = checkpoint_management.CheckpointManager( checkpoint, ckpt_dir, max_to_keep=None, init_fn=lambda: checkpoint.restore(init_path).run_restore_ops()) self.evaluate(v.initializer) # First call should call `init_fn`. self.assertIsNone(manager.restore_or_initialize()) self.assertEqual(2.0, self.evaluate(v)) # Save a checkpoint and second call should restore from the checkpoints. manager.save() self.assertIsNotNone(manager.restore_or_initialize())
def test_checkpointing(self):
    """Checkpoints a packed parallel-device variable and restores it both into
    parallel and single-device copies. (Currently skipped.)"""
    self.skipTest(
        "b/216201668: revisit parallel device and checkpointing.")
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    different_values = self.device.pack(
        [constant_op.constant(-1.),
         constant_op.constant(3.)])
    with self.device:
        v = variables.Variable(different_values)
        checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(prefix)
    with self.device:
        v.assign(constant_op.constant(0.))
        checkpoint.restore(save_path).assert_consumed()
    with self.device:
        outputs = self.device.unpack(v)
    self.assertAllClose([-1., 3.], outputs)
    # Restore-on-create: the variable is created after restore() was called,
    # so its value is filled in lazily.
    with self.device:
        restore_on_create = tracking.Checkpoint()
        restore_on_create.restore(save_path)
        restore_on_create.v = variables.Variable(0.)
        outputs = self.device.unpack(restore_on_create.v)
    self.assertAllClose([-1., 3.], outputs)
    # Changing the number of devices / restoring into a single-device copy is OK
    single_device = tracking.Checkpoint(v=variables.Variable(0.))
    status = single_device.restore(save_path)
    status.assert_existing_objects_matched()
    self.assertAllClose(-1., single_device.v)
    with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
        # There are parts of the variable that aren't restored into a
        # single-device copy.
        status.assert_consumed()
def test_delayed_restore(self):
    """A ShardedVariable created after `restore()` still receives the
    checkpointed shard values (deferred restoration)."""
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = autotrackable.AutoTrackable()
    # NOTE: shadows the `variables` module name locally; kept as in original.
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)
    # Restore is requested before `model2.s` exists; values are applied when
    # the sharded variable is attached afterwards.
    model2 = autotrackable.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0])
    ]
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [1])
    self.assertAllEqual(self.evaluate(model2.s.variables[2]), [2])
    self.assertAllEqual(self.evaluate(model2.s.variables[3]), [3])
def testSaveRestoreNumpyState(self):
    """Round-trips numpy arrays/scalars held by `_NumpyState` through two
    successive checkpoints."""
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "ckpt")
    save_state = _NumpyState()
    saver = util.Checkpoint(numpy=save_state)
    save_state.a = numpy.ones([2, 2])
    # Reassigning `b` replaces the previously tracked array with zeros.
    save_state.b = numpy.ones([2, 2])
    save_state.b = numpy.zeros([2, 2])
    save_state.c = numpy.int64(3)
    self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
    self.assertEqual(3, save_state.c)
    first_save_path = saver.save(prefix)
    # Mutate state, then take a second checkpoint with different contents.
    save_state.a[1, 1] = 2.
    save_state.c = numpy.int64(4)
    second_save_path = saver.save(prefix)

    load_state = _NumpyState()
    loader = util.Checkpoint(numpy=load_state)
    loader.restore(first_save_path).initialize_or_restore()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(3, load_state.c)
    # In-place mutation, then restoring again must overwrite it.
    load_state.a[0, 0] = 42.
    self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
    loader.restore(first_save_path).run_restore_ops()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    # The second checkpoint carries the post-mutation values.
    loader.restore(second_save_path).run_restore_ops()
    self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(4, load_state.c)
def testAssertConsumedWithUnusedPythonState(self):
    """An unused `get_config` on the saved object must not block
    `assert_consumed` when restoring into a plain Trackable."""
    # A trackable carrying a Keras-style `get_config` writes extra Python-state
    # metadata into the checkpoint.
    obj_with_config = base.Trackable()
    obj_with_config.get_config = lambda: {}
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = util.Checkpoint(obj=obj_with_config).save(prefix)
    # Restoring into a config-less object should still consume everything.
    util.Checkpoint(obj=base.Trackable()).restore(save_path).assert_consumed()
def testDeferredSlotRestoration(self):
    """Slot variables saved in a checkpoint are restored lazily, only once the
    optimizer and its slots are (re)created on the new object graph."""
    checkpoint_directory = self.get_temp_dir()
    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
        optimizer.minimize(root.var.read_value)
    else:
        train_op = optimizer.minimize(root.var)
        # Note that `optimizer` has not been added as a dependency of
        # `root`. Create a one-off grouping so that slot variables for `root.var`
        # get initialized too.
        self.evaluate(trackable_utils.gather_initializers(
            trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
        self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    # First checkpoint: no optimizer attached, so no slot variables saved.
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(
        optimizer.get_slot(name="m", var=root.var), 14.))
    # Second checkpoint: includes the "m" slot (value 14).
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The later restore() wins for `var`: value 12 from the no-slots ckpt.
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    slot_status.assert_existing_objects_matched()
    # The optimizer's hyper-accumulators haven't been created yet.
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(
            new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
        self.assertIs(
            new_root.optimizer.get_slot(name="m", var=new_root.var), None)
    if context.executing_eagerly():
        new_root.optimizer.minimize(new_root.var.read_value)
    else:
        train_op = new_root.optimizer.minimize(new_root.var)
        # The slot variable now exists; restore() didn't create it, but we should
        # now have a restore op for it.
        slot_status.run_restore_ops()
        self.assertEqual(14., self.evaluate(
            new_root.optimizer.get_slot(name="m", var=new_root.var)))
        self.evaluate(train_op)
    slot_status.assert_consumed()
def test_save_restore_different_partitions(self):
    """A ShardedVariable checkpoint can be restored into a copy with a
    different number of partitions (4 -> 1 and 1 -> 4)."""
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    # NOTE: shadows the `variables` module name locally; kept as in original.
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')
    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')
    # Restore from 4 partitions into 1.
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

    # Overwrite the single-shard values and re-save.
    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
    cp2.write(fname)
    # Restore 1 partition into 4.
    cp.restore(fname)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
    self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
    self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
    self.assertEqual(self.evaluate(cp.s.variables[3]), [20])
def testCheckpoint(self, delayed, restore_shards):
    """Saves a 2-shard variable under ParameterServerStrategy and restores it
    into `restore_shards` shards, with restore before or after variable
    creation depending on `delayed`."""
    if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
        self.skipTest(
            "TODO(b/202760274): Would raise an error that is to be "
            "investigated.")

    def make_variable(name, shape, dtype, initializer):
        # Custom getter: defers running the initializer via a callable.
        initial_value = functools.partial(initializer, shape, dtype=dtype)
        return variables.Variable(
            name=name, initial_value=initial_value, shape=shape, dtype=dtype)

    class Model(autotrackable.AutoTrackable):

        def build(self):
            self.w = self._add_variable_with_custom_getter(
                "w",
                shape=(4, ),
                initializer=init_ops_v2.Ones(),
                getter=make_variable)

    # Save with 2 fixed shards.
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    with strategy.scope():
        model1 = Model()
        model1.build()
    self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
    self.assertLen(model1.w.variables, 2)
    model1.w.assign([1., 2., 3., 4.])
    cp1 = tracking_util.Checkpoint(model=model1)
    cp1.write(ckpt_dir)

    # Restore with `restore_shards` fixed shards.
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        sharded_variable.FixedShardsPartitioner(restore_shards))
    with strategy.scope():
        model2 = Model()
        cp2 = tracking_util.Checkpoint(model=model2)
        if delayed:
            # Restore requested before the variable exists (deferred).
            cp2.restore(ckpt_dir)
            model2.build()
        else:
            model2.build()
            cp2.restore(ckpt_dir)
    self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
    self.assertLen(model2.w.variables, restore_shards)
    if restore_shards == 2:
        self.assertAllEqual(model2.w.variables[0], [1., 2.])
        self.assertAllEqual(model2.w.variables[1], [3., 4.])
    elif restore_shards == 4:
        self.assertAllEqual(model2.w.variables[0], [1.])
        self.assertAllEqual(model2.w.variables[1], [2.])
        self.assertAllEqual(model2.w.variables[2], [3.])
        self.assertAllEqual(model2.w.variables[3], [4.])
def test_forward_compatibility(self):
    """A legacy SaveableObject-based trackable can read checkpoints written by
    the tensor-based API, with and without forced checkpoint conversion."""

    class _MultiSpecSaveable(saveable_object.SaveableObject):
        # Legacy saveable writing two specs ("-a" and "-b") under one name.

        def __init__(self, obj, name):
            self.obj = obj
            specs = [
                saveable_object.SaveSpec(obj.a, "", name + "-a"),
                saveable_object.SaveSpec(obj.b, "", name + "-b")
            ]
            super(_MultiSpecSaveable, self).__init__(None, specs, name)

        def restore(self, restored_tensors, restored_shapes):
            del restored_shapes  # Unused.
            self.obj.a.assign(restored_tensors[0])
            self.obj.b.assign(restored_tensors[1])

    class DeprecatedTrackable(base.Trackable):
        # Old-style trackable using _gather_saveables_for_checkpoint.

        def __init__(self):
            self.a = variables.Variable(1.0)
            self.b = variables.Variable(2.0)

        def _gather_saveables_for_checkpoint(self):
            return {"foo": lambda name: _MultiSpecSaveable(self, name)}

    @saveable_compat.legacy_saveable_name("foo")
    class NewTrackable(base.Trackable):
        # New-style trackable; the decorator maps its tensors onto the legacy
        # "foo" saveable name so old readers find them.

        def __init__(self):
            self.a = variables.Variable(3.0)
            self.b = variables.Variable(4.0)

        def _serialize_to_tensors(self):
            return {"-a": self.a, "-b": self.b}

        def _restore_from_tensors(self, restored_tensors):
            self.a.assign(restored_tensors["-a"])
            self.b.assign(restored_tensors["-b"])

    new = NewTrackable()
    # Test with the checkpoint conversion flag disabled (normal compatibility).
    saveable_compat.force_checkpoint_conversion(False)
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(new).write(checkpoint_path)
    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))
    # Now test with the checkpoint conversion flag enabled (forward compat).
    # The deprecated object will try to load from the new checkpoint.
    saveable_compat.force_checkpoint_conversion()
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
    checkpoint.Checkpoint(new).write(checkpoint_path)
    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))
def test_checkpoint_restore_before_variable_creation(self):
    """TPUEmbedding checkpoints restore correctly even when `restore()` is
    called before the embedding tables have been built."""
    self.skip_if_oss()

    class TestModule(module.Module):

        def __init__(self, initializer, rows):
            self._initializer = initializer
            self._rows = rows
            table = tpu_embedding_v2_utils.TableConfig(
                vocabulary_size=self._rows,
                dim=4,
                initializer=self._initializer,
                combiner='sum',
                name='table')
            feature_config = (tpu_embedding_v2_utils.FeatureConfig(
                table=table, name='feature'),)
            optimizer = tpu_embedding_v2_utils.SGD()
            self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
                feature_config, optimizer)

        def create_embedding(self):
            # We aren't training so batch_size here doesn't matter.
            self.tpu_embedding.build(64)

    # Save a module whose table is initialized to ones.
    strategy = self._get_strategy()
    with strategy.scope():
        module1 = TestModule(init_ops_v2.Ones(),
                             strategy.num_replicas_in_sync * 2)
        module1.create_embedding()
        checkpoint = util.Checkpoint(test_module=module1)
        checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))
    # Reinitialize the tpu
    strategy = self._get_strategy()
    with strategy.scope():
        # Zeros initializer: any ones seen after restore came from the ckpt.
        module2 = TestModule(init_ops_v2.Zeros(),
                             strategy.num_replicas_in_sync * 2)
        checkpoint = util.Checkpoint(test_module=module2)
        # "save-1" is the path `Checkpoint.save` produced above (step suffix).
        checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))
    with strategy.scope():
        module2.create_embedding()

    def get_values(mid):
        # Reads the host copy of the "parameters" table variable.
        return mid._variables['table']['parameters'].variables[0].numpy()

    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))
    # Fetch the values from the TPU to check that they are the same.
    module2.tpu_embedding._retrieve_variables()
    self.assertAllClose(
        np.ones((strategy.num_replicas_in_sync * 2, 4)),
        get_values(module2.tpu_embedding))
def testMultipleGraphsNonSlotVariables(self):
    """One optimizer shared across two graphs: saving/restoring in the second
    graph must leave the first graph's variables and slots untouched."""
    with context.graph_mode():
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        optimizer = adam.AdamOptimizer(0.001)
        # Construct a model in one graph
        first_graph = ops.Graph()
        first_session = session_lib.Session(graph=first_graph)
        with first_graph.as_default(), first_session.as_default():
            first_variable = resource_variable_ops.ResourceVariable([1.])
            first_root_trackable = trackable_utils.Checkpoint(
                optimizer=optimizer, variable=first_variable)
            train_op = optimizer.minimize(first_variable.read_value)
            self.evaluate(trackable_utils.gather_initializers(
                first_root_trackable))
            self.evaluate(train_op)
            # Pin known values so we can detect accidental cross-graph writes.
            self.evaluate(first_variable.assign([1.]))
            self.evaluate(optimizer.get_slot(
                var=first_variable, name="m").assign([2.]))
            beta1_power, _ = optimizer._get_beta_accumulators()
            self.evaluate(beta1_power.assign(3.))
        # Save and load in a second graph
        second_graph = ops.Graph()
        with second_graph.as_default(), session_lib.Session(graph=second_graph):
            second_variable = resource_variable_ops.ResourceVariable([1.])
            second_root_trackable = trackable_utils.Checkpoint(
                optimizer=optimizer, variable=second_variable)
            train_op = optimizer.minimize(second_variable.read_value)
            second_root_trackable.restore(None).initialize_or_restore()
            self.evaluate(train_op)
            self.evaluate(second_variable.assign([4.]))
            self.evaluate(optimizer.get_slot(
                var=second_variable, name="m").assign([5.]))
            beta1_power, _ = optimizer._get_beta_accumulators()
            self.evaluate(beta1_power.assign(6.))
            save_path = second_root_trackable.save(checkpoint_prefix)
            # Overwrite, then restore: values should return to 4/5/6.
            self.evaluate(second_variable.assign([7.]))
            self.evaluate(optimizer.get_slot(
                var=second_variable, name="m").assign([8.]))
            beta1_power, _ = optimizer._get_beta_accumulators()
            self.assertAllEqual(6., self.evaluate(beta1_power))
            status = second_root_trackable.restore(save_path)
            status.assert_consumed().run_restore_ops()
            self.assertAllEqual([4.], self.evaluate(second_variable))
            self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
                var=second_variable, name="m")))
            beta1_power, _ = optimizer._get_beta_accumulators()
            self.assertAllEqual(6., self.evaluate(beta1_power))
        # Check that the first graph is unmolested
        with first_graph.as_default(), first_session.as_default():
            self.assertAllEqual([1.], self.evaluate(first_variable))
            self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
                var=first_variable, name="m")))
            beta1_power, _ = optimizer._get_beta_accumulators()
            self.assertAllEqual(3., self.evaluate(beta1_power))
def testDocstringExample(self):
    """Covers the save/restore round trip illustrated for `_NumpyState`."""
    state = _NumpyState()
    ckpt = util.Checkpoint(numpy_arrays=state)
    state.x = numpy.zeros([3, 4])
    save_path = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
    # Mutate after saving; restore must roll the array back to zeros.
    state.x[1, 1] = 4.
    ckpt.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), state.x)
    # A brand-new wrapper restores the same contents from the same path.
    fresh = util.Checkpoint(numpy_arrays=_NumpyState())
    fresh.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), fresh.numpy_arrays.x)
def test_trackable_save_restore(self):
    """Templates and their nested scopes/variables are tracked and can be
    checkpointed together with an optimizer."""

    def _templated():
        v = variable_scope.get_variable(
            "v", shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)
        v2 = variable_scope.get_variable(
            "v2", shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)
        manual = _ManualScope()
        return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    # The template's tracked object graph must cover exactly these objects.
    six.assertCountEqual(self, [
        id(obj) for obj in
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template]
    ], [id(obj) for obj in trackable_utils.list_objects(save_template)])
    self.assertDictEqual({"in_manual_scope": manual_scope_v},
                         manual_scope._trackable_children())
    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Restore into a second, freshly-built template ("s2").
    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._trackable_children()))
    self.assertEqual(set(["v", "v2", "ManualScope"]),
                     load_template._trackable_children().keys())
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
def testAssertConsumedFailsWithUsedPythonState(self):
    """`assert_consumed` must flag Python-state saveables that the restored
    object graph cannot consume."""
    source = base.Trackable()
    # Attach a Python-string saveable under the name "foo_attr" so the written
    # checkpoint contains state that a plain Trackable will not restore.
    saveables = {
        "foo_attr":
            functools.partial(
                base.PythonStringStateSaveable,
                state_callback=lambda: "",
                restore_callback=lambda x: None)
    }
    source._gather_saveables_for_checkpoint = lambda: saveables
    save_path = util.Checkpoint(obj=source).save(
        os.path.join(self.get_temp_dir(), "ckpt"))
    status = util.Checkpoint(obj=base.Trackable()).restore(save_path)
    with self.assertRaisesRegex(AssertionError, "foo_attr"):
        status.assert_consumed()
def test_spmd_model_checkpointing(self):
    """Saves SPMD-sharded model weights, changes them, and verifies restore
    brings back the original results."""

    class LinearModel(module.Module):

        def __init__(self, w):
            super(LinearModel, self).__init__()
            self.w = variables.Variable(w)

        def __call__(self, x):
            return math_ops.matmul(x, self.w)

        def change_weights_op(self, w_new):
            return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
        model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
        # Shard the batch across two logical devices along the feature axis.
        x = strategy.experimental_split_to_logical_devices(x, [1, 2])
        return model(x)

    with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        checkpoint.save(file_prefix=checkpoint_prefix)
        # Swap in different weights so restore has an observable effect.
        self.evaluate(model.change_weights_op(w2))
        result = strategy.run(step_fn, args=(x,))
        self.assertAllClose(
            math_ops.matmul(x, w2) * num_replicas,
            self.evaluate(strategy.reduce("SUM", result, axis=None)),
            rtol=5e-3, atol=5e-3)
        status = checkpoint.restore(
            checkpoint_management.latest_checkpoint(checkpoint_dir))
        status.run_restore_ops(sess)  # must run restore op in non-eager mode.
        status.assert_consumed()
        status.assert_existing_objects_matched()
        # After restore the original weights `w1` drive the result again.
        result = strategy.run(step_fn, args=(x,))
        self.assertAllClose(
            math_ops.matmul(x, w1) * num_replicas,
            self.evaluate(strategy.reduce("SUM", result, axis=None)),
            rtol=5e-3, atol=5e-3)
def test_table(self):
    """A vocab-file-backed HashTable survives SavedModel export: the asset is
    copied into the SavedModel and tracks directory moves."""
    initializer = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    root = checkpoint.Checkpoint(
        table=lookup_ops.HashTable(initializer, default_value=-1))
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    self.assertEqual(
        2, self.evaluate(root.table_user(constant_op.constant("gamma"))))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    # Delete the original vocab file: lookups must now be served from the
    # asset copied into the SavedModel.
    file_io.delete_file(self._vocab_path)
    self.assertAllClose(
        {"output_0": [2, 0]},
        _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))
    second_dir = os.path.join(self.get_temp_dir(), "second_dir")
    # Asset paths should track the location the SavedModel is loaded from.
    file_io.rename(save_dir, second_dir)
    self.assertAllClose(
        {"output_0": [2, 1]},
        _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
def testNestedLists(self):
    """Nested tracked lists pick up new trackable elements, ignore plain
    values, and refuse to save after an element is replaced in place."""
    root = autotrackable.AutoTrackable()
    root.l = []
    child_b = autotrackable.AutoTrackable()
    root.l.append([child_b])
    child_c = autotrackable.AutoTrackable()
    root.l[0].append(child_c)
    deps = util.list_objects(root)
    self.assertIn(child_b, deps)
    self.assertIn(child_c, deps)

    # A plain int is not tracked; a new trackable appended later is.
    root.l[0].append(1)
    child_d = autotrackable.AutoTrackable()
    root.l[0].append(child_d)
    deps = util.list_objects(root)
    for tracked in (child_d, child_b, child_c):
        self.assertIn(tracked, deps)
    self.assertNotIn(1, deps)

    # Trackables inside a second nested-list attribute are discovered too.
    child_e = autotrackable.AutoTrackable()
    child_f = autotrackable.AutoTrackable()
    root.l1 = [[], [child_e]]
    root.l1[0].append(child_f)
    deps = util.list_objects(root)
    self.assertIn(child_e, deps)
    self.assertIn(child_f, deps)

    ckpt = util.Checkpoint(a=root)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    ckpt.save(prefix)
    # NoDependency wrappers keep their contents out of tracking; saving again
    # still works.
    root.l[0].append(data_structures.NoDependency([]))
    root.l[0][-1].append(5)
    ckpt.save(prefix)
    # Dirtying the inner list means the root object is unsaveable.
    root.l[0][1] = 2
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
        ckpt.save(prefix)
def test_metrics_v2(self):
    """Checkpoint write/read metrics (counts, sizes, training-time-saved) are
    recorded under the V2 API label."""
    api_label = util._CHECKPOINT_V2
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    with context.eager_mode():
        ckpt = util.Checkpoint(v=variables_lib.Variable(1.))
        self.assertEqual(self._get_time_saved(api_label), 0.0)
        self.assertEqual(self._get_write_histogram_proto(api_label).num, 0.0)
        for i in range(3):
            time_saved = self._get_time_saved(api_label)
            # Sleep so each write advances the wall-clock-based metric.
            time.sleep(1)
            ckpt_path = ckpt.write(file_prefix=prefix)
            filesize = util._get_checkpoint_size(ckpt_path)
            self.assertEqual(
                self._get_checkpoint_size(api_label, filesize), i + 1)
            self.assertGreater(self._get_time_saved(api_label), time_saved)
        self.assertEqual(self._get_write_histogram_proto(api_label).num, 3.0)
        self.assertEqual(self._get_read_histogram_proto(api_label).num, 0.0)
        time_saved = self._get_time_saved(api_label)
    with context.eager_mode():
        ckpt.restore(ckpt_path)
        self.assertEqual(self._get_read_histogram_proto(api_label).num, 1.0)
        # Restoring a checkpoint in the same "job" does not increase training time
        # saved.
        self.assertEqual(self._get_time_saved(api_label), time_saved)
def test_lookup_table_compatibility(self):
    """Newly written lookup-table checkpoints keep the same metadata/keys as a
    stored legacy checkpoint, and the legacy one still loads."""
    table_module = generate_checkpoint.TableModule()
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)
    # Ensure that the checkpoint metadata and keys are the same.
    legacy_metadata = checkpoint.object_metadata(_LEGACY_TABLE_CHECKPOINT_PATH)
    metadata = checkpoint.object_metadata(checkpoint_path)

    def _get_table_node(object_metadata):
        # Finds the proto node for the root's "lookup_table" child; returns
        # None implicitly if the child is absent.
        for child in object_metadata.nodes[0].children:
            if child.local_name == "lookup_table":
                return object_metadata.nodes[child.node_id]

    table_proto = _get_table_node(metadata)
    legacy_table_proto = _get_table_node(legacy_metadata)
    self.assertAllEqual(
        [table_proto.attributes[0].name,
         table_proto.attributes[0].checkpoint_key],
        [legacy_table_proto.attributes[0].name,
         legacy_table_proto.attributes[0].checkpoint_key])

    legacy_reader = checkpoint_utils.load_checkpoint(
        _LEGACY_TABLE_CHECKPOINT_PATH)
    reader = checkpoint_utils.load_checkpoint(checkpoint_path)
    self.assertEqual(
        legacy_reader.get_variable_to_shape_map().keys(),
        reader.get_variable_to_shape_map().keys())
    # Ensure that previous checkpoint can be loaded into current table.
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()
def test_training_loop(self):
    """Repeated restore-train-save cycles with a CheckpointManager on a
    parallel device. (Currently skipped.)"""
    self.skipTest("b/216201668: revisit parallel device and checkpointing")
    for _ in range(5):
        layer = _Dense(5)
        checkpoint = tracking.Checkpoint(layer=layer)
        manager = checkpoint_management.CheckpointManager(
            checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
        # Picks up where the previous outer iteration's save left off.
        manager.restore_or_initialize()
        for _ in range(10):
            x = self.device.pack([
                constant_op.constant([[-0.5]]),
                constant_op.constant([[0.5]])
            ])
            with self.device:
                with backprop.GradientTape() as tape:
                    y = layer(x)
                    loss = (y - math_ops.range(5.))**2.
                parameters = layer.trainable_variables
                unreduced_gradients = tape.gradient(loss, parameters)
                # Sum gradients across device components before applying.
                reduced_gradients = _collective_sum(
                    unreduced_gradients,
                    num_replicas=len(self.device.components))
                for grad, param in zip(reduced_gradients, parameters):
                    param.assign_sub(0.01 * grad)
            manager.save()
def testSaveRestoreMultipleIterator(self):
    """Three iterators checkpointed together each resume from their own saved
    position after restore."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    dataset = dataset_ops.Dataset.from_tensor_slices(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = iter(dataset)
    get_next_1 = iterator_1.get_next
    iterator_2 = iter(dataset)
    get_next_2 = iterator_2.get_next
    dataset_2 = dataset_ops.Dataset.range(10)
    iterator_3 = iter(dataset_2)
    get_next_3 = iterator_3.get_next
    checkpoint = trackable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    # Advance iterators 1 and 3 to distinct positions before saving.
    self.assertAllEqual([1, 4], get_next_1())
    self.assertAllEqual(0, get_next_3())
    self.assertAllEqual(1, get_next_3())
    self.assertAllEqual(2, get_next_3())
    save_path = checkpoint.save(checkpoint_prefix)
    # Advance iterators 2 and 3 past the saved positions.
    self.assertAllEqual([1, 4], get_next_2())
    self.assertAllEqual([9, 16], get_next_2())
    self.assertAllEqual(3, get_next_3())
    checkpoint.restore(save_path).run_restore_ops()
    # Each iterator resumes exactly where it was when the ckpt was taken.
    self.assertAllEqual([9, 16], get_next_1())
    self.assertAllEqual([1, 4], get_next_2())
    self.assertAllEqual(3, get_next_3())
def test_signature_attribute_reserved(self):
    """`signatures` is a reserved attribute: saving fails until it is deleted."""
    root = checkpoint.Checkpoint(signatures=variables.Variable(1.))
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    # Saving fails while the reserved attribute is present...
    with self.assertRaisesRegex(ValueError, "del obj.signatures"):
        save.save(root, export_dir)
    # ...and succeeds once it has been removed.
    del root.signatures
    save.save(root, export_dir)
def save(self, path, compression=None, shard_func=None, checkpoint_args=None):
    """Implements the save function and checkpoint functionality.

    Args:
      path: Directory to write the saved dataset to.
      compression: Optional compression to apply (e.g. "GZIP").
      shard_func: Optional function mapping a dataset element to a shard index.
      checkpoint_args: Optional kwargs forwarded to `CheckpointManager`. When
        provided (eager mode only), the save loop is checkpointed so that an
        interrupted save can resume from `manager.latest_checkpoint`. Must not
        contain a "checkpoint" key — that slot is filled by this method.

    Raises:
      ValueError: If `checkpoint_args` contains a "checkpoint" key.
    """
    if context.executing_eagerly() and checkpoint_args:
        # Validate the arguments before doing any dataset work, and fix the
        # stray leading apostrophe the old message carried.
        if "checkpoint" in checkpoint_args:
            raise ValueError(
                "Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
                "to include 'checkpoint'."
            )
        save_dataset = _SaveDataset(self, path, shard_func, compression)
        save_iterator = iter(save_dataset)
        checkpoint = checkpoint_lib.Checkpoint(iterator=save_iterator)
        # Copy so the caller's dict is not mutated by the injected key.
        manager_args = dict(checkpoint_args)
        manager_args["checkpoint"] = checkpoint
        manager = checkpoint_management.CheckpointManager(**manager_args)
        # Resume a previously interrupted save, if any (no-op when None).
        checkpoint.restore(manager.latest_checkpoint)
        # Iterate directly; the former `enumerate` wrapper served no purpose.
        for _ in save_iterator:
            if "step_counter" in manager_args:
                manager_args["step_counter"].assign_add(delta=1)
            manager.save(check_interval=True)
    else:
        dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
            self, shard_func, path)
        ged_ops.save_dataset(
            dataset._variant_tensor,  # pylint: disable=protected-access
            path=path,
            shard_func_other_args=shard_func.captured_inputs,
            compression=compression,
            shard_func=shard_func,
            use_shard_func=use_shard_func)
def testCheckpointing(self, distribution, synchronization, aggregation, mode):
    """A distributed variable saved under a strategy scope restores its
    pre-save value after being overwritten."""
    if (isinstance(
        distribution,
        collective_all_reduce_strategy.CollectiveAllReduceStrategy) and
        mode == "graph"):
        self.skipTest(
            "MWMS combinations tests do not work well in graph mode.")
    with distribution.scope():
        v = variables_lib.Variable(
            constant_op.constant([1., 2., 3., 4]),
            synchronization=synchronization,
            aggregation=aggregation)
    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())
    # Save random weights into checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
        save_path = checkpoint.save(prefix)
    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)
    # Restore from the checkpoint.
    with self.test_session():
        checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)
def testStatefulExternalPolicy(self):
    """An iterator over a dataset with external state (py_func) can still be
    checkpointed and restored under the WARN external-state policy."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    dataset = dataset_ops.Dataset.range(4)

    def fn(x):
        return x * x

    # eager_py_func makes the pipeline stateful from the runtime's view.
    dataset = dataset.map(
        lambda x: script_ops.eager_py_func(fn, [x], dtypes.int64))
    options = options_lib.Options()
    options.experimental_external_state_policy = (
        options_lib.ExternalStatePolicy.WARN)
    dataset = dataset.with_options(options)
    iterator = iter(dataset)
    get_next = iterator.get_next
    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, get_next().numpy())
    self.assertEqual(1, get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(4, get_next().numpy())
    self.assertEqual(9, get_next().numpy())
    # Restore rewinds the iterator to just after element 1 (squares 4 and 9
    # are produced again).
    checkpoint.restore(save_path).run_restore_ops()
    self.assertEqual(4, get_next().numpy())
    self.assertEqual(9, get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
        get_next()
def test_registered_saver_is_called_before_save_after_load(self):
    """The registered `save_fn` must run before the default saver writes any
    checkpoint files to the target directory."""
    if not context.executing_eagerly():
        self.skipTest("This test must run under eager mode.")

    class RestoreClass(autotrackable.AutoTrackable):
        pass

    def save_fn(trackables, file_prefix):
        del trackables  # Unused.
        # Check that directory is empty: the registered saver is invoked
        # before the default saver has written anything.
        files = gfile.ListDirectory(os.path.dirname(file_prefix.numpy()))
        self.assertEmpty(files)

    def restore_fn(trackables, merged_prefix):
        del merged_prefix  # Unused.
        # BUG FIX: dict views are not iterators, so `next(trackables.values())`
        # raised TypeError. Wrap the view in `iter(...)` first.
        root = next(iter(trackables.values()))
        self.assertEqual(root.v.numpy(), 123)

    registration.register_checkpoint_saver(
        name="OptionalRestore",
        predicate=lambda x: isinstance(x, RestoreClass),
        save_fn=save_fn,
        restore_fn=restore_fn)

    root = RestoreClass()
    root.v = variables.Variable(123.0)
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)
def test_non_strict_predicate(self):
    """A saver registered with `strict_predicate_restore=False` lets objects
    the predicate rejects restore through the default path."""

    class NonStrictPredicateClass(autotrackable.AutoTrackable):
        pass

    registration.register_checkpoint_saver(
        name="NonStrictPredicate",
        predicate=lambda x: isinstance(x, NonStrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=False)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(NonStrictPredicateClass()).write(ckpt_path)
    # This should run without throwing an error.
    util.Checkpoint(autotrackable.AutoTrackable()).read(ckpt_path)
def test_strict_predicate(self):
    """A saver registered with `strict_predicate_restore=True` rejects
    restoring into an object the predicate does not match."""

    class StrictPredicateClass(autotrackable.AutoTrackable):
        pass

    registration.register_checkpoint_saver(
        name="StrictPredicate",
        predicate=lambda x: isinstance(x, StrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=True)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(StrictPredicateClass()).write(ckpt_path)
    # Reading into an unmatched object must fail loudly.
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
        util.Checkpoint(autotrackable.AutoTrackable()).read(ckpt_path)
def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
        strat2 = strat1
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "Default" in strat1_name or "Default" in strat2_name:
        self.skipTest(
            "We don't guarantee consistency between strategy and no-strategy.")
    if ("TPU" in strat1_name or "TPU" in strat2_name) and not jit_replica_fn:
        self.skipTest(
            "TPUStrategy requires the replica function (the function passed to "
            "strategy.run) to be decorated with tf.function")
    # Parameter-server strategies need a cluster coordinator to run replica
    # functions; other strategies run them directly (coord stays None).
    coord1 = None
    if "ParameterServer" in strat1_name:
        coord1 = coordinator_lib.ClusterCoordinator(strat1)
    coord2 = None
    if "ParameterServer" in strat2_name:
        coord2 = coordinator_lib.ClusterCoordinator(strat2)
    fname = os.path.join(self.get_temp_dir(), "checkpoint")

    def uniform(strat, coord, g):
        # Draws per-replica random ints from generator `g` under `strat`.

        def f():
            return g.uniform_full_int([3], dtype=dtypes.int32)

        replica_fn = def_function.function(f) if jit_replica_fn else f
        result = run_on_strategy(replica_fn, strat, coord)
        return strat.experimental_local_results(result)

    with strat1.scope():
        g1 = rng.Generator.from_seed(1)
    with strat2.scope():
        g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)

    def write_restore_compare():
        cp1.write(fname)
        r1 = uniform(strat1, coord1, g1)
        cp2.restore(fname)
        r2 = uniform(strat2, coord2, g2)
        # Tests that overlapping replicas are properly restored.
        n1 = get_num_local_replicas(strat1)
        n2 = get_num_local_replicas(strat2)
        n = min(n1, n2)
        self.assertAllEqual(r1[:n], r2[:n])

    # Run multiple times so that cp1.write is called in various RNG states
    for _ in range(2):
        write_restore_compare()