def _buildGraphAndSaver(logdir, keep_latest_n=5, keep_every_n_hours=None, save_async=False):
  """Builds a LeNet5 training graph and a custom Saver over its variables.

  Args:
    logdir: Directory where the saver writes checkpoints.
    keep_latest_n: Number of most-recent checkpoints to retain.
    keep_every_n_hours: If set, additionally retain one checkpoint per this
      many hours.
    save_async: Whether the saver writes checkpoints asynchronously.

  Returns:
    A (graph, saver, increment_op) tuple; running increment_op bumps the
    global step by one.
  """
  tf.random.set_seed(123)
  graph = tf.Graph()
  with graph.as_default():
    task_params = mnist.LeNet5().Task()
    task_params.input = mnist.LeNet5().Train()
    with cluster_factory.ForTestingWorker(mode='sync', job='controller'):
      _ = task_params.Instantiate()
    global_step = py_utils.GetOrCreateGlobalStepVar()
    increment_op = global_step.assign_add(1)
    model_vars = tf.all_variables()
    # Global step must stay within [0, 10]; every variable must be finite.
    checks = [([global_step], saver.InRange(0, 10))]
    checks += [([v], saver.IsFinite()) for v in model_vars]
    checkpoint_saver = saver.Saver(
        logdir,
        model_vars,
        checks,
        keep_latest_n=keep_latest_n,
        keep_every_n_hours=keep_every_n_hours,
        async_save=save_async)
  return graph, checkpoint_saver, increment_op
def testBasic(self):
  """Exercises Save/Restore plus checkpoint garbage collection end to end."""
  logdir = tempfile.mkdtemp()
  # Create a dummy file that looks like a checkpoint that shouldn't
  # be touched.
  with tf.io.gfile.GFile(logdir + '/ckpt-foo', 'w') as f:
    f.write('contents')
  graph = tf.Graph()
  with graph.as_default():
    task_params = mnist.LeNet5().Task()
    task_params.input = mnist.LeNet5().Train()
    with cluster_factory.ForTestingWorker(mode='sync', job='controller'):
      _ = task_params.Instantiate()
    global_step = py_utils.GetOrCreateGlobalStepVar()
    increment_op = global_step.assign_add(1)
    model_vars = tf.all_variables()
    # Global step must stay within [0, 10]; every variable must be finite.
    checks = [([global_step], saver.InRange(0, 10))]
    checks += [([v], saver.IsFinite()) for v in model_vars]
    sav = saver.Saver(
        logdir, model_vars, checks, keep_latest_n=5, keep_every_n_hours=1e-9)
  with self.session(graph=graph) as sess:
    # Creates a few checkpoints.
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
      sess.run(increment_op)
      _ = sav.Save(sess)
    # Restore to the latest.
    sess.run(tf.global_variables_initializer())
    _ = sav.Restore(sess)
    # Restore to a specific checkpoint.
    sess.run(tf.global_variables_initializer())
    _ = sav.Restore(sess, 6)
    # Increments global_step out of range, Save() fails.
    for _ in range(5):
      sess.run(increment_op)
    with self.assertRaises(tf.errors.AbortedError):
      _ = sav.Save(sess)
  filenames = tf.io.gfile.glob('{}/*'.format(logdir))
  filenames = [x[len(logdir) + 1:] for x in filenames]
  print('\n'.join(filenames))
  self.assertIn('checkpoint', filenames)
  meta_files = [name for name in filenames if name.endswith('.meta')]
  # A .meta for each checkpoint.
  self.assertEqual(len(meta_files), 6)
  # 1 for checkpoint. 3 files per checkpoint. 5 good checkpoints, 1 bad.
  # 1 extra file contains the error message, and 1 dummy file
  self.assertEqual(len(filenames), 1 + (5 + 1) * 3 + 1 + 1)
def __init__(self, logdir, train_params, variables_to_restore_dict=None, async_save=False):
  """Create a tf.train.Saver or a custom_saver.Saver.

  Args:
    logdir: The directory path to save checkpoints to.
    train_params: Training parameters.
    variables_to_restore_dict: A dictionary mapping names to Saveables.
      Typically, used in evaluation for substituting exponential moving
      average weights. If this is set, then tf.train.Saver is used.
    async_save: Save asynchronously. Only works with custom saver.
  """
  self._logdir = logdir
  self._save_path = os.path.join(self._logdir, 'ckpt')
  # A restore-substitution map forces tf.train.Saver; otherwise the flag
  # decides which implementation to use.
  self._use_custom_saver = (
      FLAGS.use_custom_saver and not variables_to_restore_dict)
  if async_save and not self._use_custom_saver:
    tf.logging.warning('Asynchronous saving only works with custom saver.')

  self._keep_latest_n = train_params.save_max_to_keep
  self._keep_every_n_hours = train_params.save_keep_checkpoint_every_n_hours
  self._max_steps = train_params.max_steps
  self._tpu_steps_per_loop = train_params.tpu_steps_per_loop

  if self._use_custom_saver:
    tf.logging.info('Instantiating custom Saver')
    gsv = py_utils.GetOrCreateGlobalStepVar()
    self._var_list = tf.all_variables()
    if self._max_steps and self._tpu_steps_per_loop:
      # The global step may overshoot max_steps by at most one TPU loop.
      sanity_checks = [([gsv],
                        custom_saver.InRange(
                            0, self._max_steps + self._tpu_steps_per_loop))]
    else:
      sanity_checks = []
    if train_params.checkpoint_finite_check:
      sanity_checks.extend(
          ([v], custom_saver.IsFinite()) for v in self._var_list)
    self._saver = custom_saver.Saver(
        logdir,
        variables=self._var_list,
        sanity_checks=sanity_checks,
        keep_latest_n=self._keep_latest_n,
        keep_every_n_hours=self._keep_every_n_hours,
        async_save=async_save)
  else:
    tf.logging.info('Instantiating tf.train.Saver')
    self._saver = tf.train.Saver(
        variables_to_restore_dict,
        sharded=True,
        max_to_keep=self._keep_latest_n,
        keep_checkpoint_every_n_hours=self._keep_every_n_hours,
        pad_step_number=True,  # %08d
        write_version=tf.train.SaverDef.V2)
    self._var_list = self._saver._var_list  # pylint: disable=protected-access