def testEagerMultiLearnerCheckpointCompatibility(self):
  """Compares checkpoint keys written in eager mode and in graph mode.

  While executing eagerly, the 'test.LinearModelParams' model is saved once
  with EagerCheckpointerV1 and once with EagerCheckpointerV2. Eager mode is
  then switched off, the model is rebuilt inside a session on a fresh graph
  and saved with Checkpointer. The test asserts that the V2 checkpoint holds
  exactly two more keys than the V1 checkpoint, and that the V1 keys equal
  the graph-mode keys.
  """
  self.assertTrue(tf.executing_eagerly())
  params = model_registry.GetParams('test.LinearModelParams', 'Train')
  model = params.Instantiate()
  with py_utils.GradientTape(persistent=True):
    model.ConstructFPropBPropGraph()

  # Eager phase: write one checkpoint per eager checkpointer flavor.
  v1_dir = os.path.join(self.get_temp_dir(), 'eager_v1')
  v2_dir = os.path.join(self.get_temp_dir(), 'eager_v2')
  checkpointer.EagerCheckpointerV1(v1_dir, model).Save(gsteps=0)
  checkpointer.EagerCheckpointerV2(v2_dir, model).Save(gsteps=0)

  v1_ckpt_path = os.path.join(v1_dir, 'ckpt_V1', 'ckpt-00000000')
  v1_keys = _GetCheckpointKeys(v1_ckpt_path)
  v2_ckpt_path = os.path.join(v2_dir, 'ckpt_V2', 'ckpt-0')
  v2_keys = _GetCheckpointKeys(v2_ckpt_path)

  # Expecting two more variables in V2 checkpoints:
  #   _CHECKPOINTABLE_OBJECT_GRAPH
  #   save_counter
  expected_v2_key_count = len(v1_keys) + 2
  self.assertEqual(expected_v2_key_count, len(v2_keys))  # pylint:disable=g-generic-assert

  # Graph phase: rebuild the same model with eager mode turned off.
  py_utils.SetEagerMode(False)
  self.assertFalse(tf.executing_eagerly())
  graph_dir = os.path.join(self.get_temp_dir(), 'graph')
  os.mkdir(graph_dir)
  with self.session(graph=tf.Graph()) as sess:
    model = params.Instantiate()
    for learner in model.GetTask().learners:
      learner.optimizer.params.clear_variable_scope = False
    model.ConstructFPropBPropGraph()
    sess.run(tf.global_variables_initializer())
    checkpointer.Checkpointer(graph_dir, model).Save(sess)
  graph_ckpt_path = os.path.join(graph_dir, 'ckpt')
  graph_keys = _GetCheckpointKeys(graph_ckpt_path)

  self.assertEqual(v1_keys, graph_keys)
def TrainFunc():
  """Builds the FProp/BProp graph under a persistent tape and returns outputs.

  `model` and `task` are not defined here; presumably they are captured from
  the enclosing scope (TODO confirm against the surrounding code).

  Returns:
    A 2-tuple `(task.eval_metrics, task.per_example_tensors)`.
  """
  with py_utils.GradientTape(persistent=True):
    model.ConstructFPropBPropGraph()
  # Read the attributes in the same order the tuple is returned in.
  metrics = task.eval_metrics
  per_example = task.per_example_tensors
  return metrics, per_example
def ModelFunc():
  """Builds the FProp/BProp graph with the summary writer set as default.

  `self`, `model` and `task` are not defined here; presumably they are
  captured from the enclosing scope (TODO confirm against the surrounding
  code).

  Returns:
    `task.eval_metrics`.
  """
  # A single multi-item `with` enters the summary writer first and the
  # gradient tape second, the same nesting order as two stacked `with` blocks.
  with self._summary_writer.as_default(), py_utils.GradientTape(
      persistent=True):
    model.ConstructFPropBPropGraph()
  metrics = task.eval_metrics
  return metrics
def _Update():
  """Runs `mdl.ConstructFPropBPropGraph()` under a persistent gradient tape.

  `mdl` is not defined here; presumably it is captured from the enclosing
  scope (TODO confirm against the surrounding code). Nothing is returned.
  """
  tape_scope = py_utils.GradientTape(persistent=True)
  with tape_scope:
    mdl.ConstructFPropBPropGraph()