Ejemplo n.º 1
0
 def _buildGraphAndSaver(logdir,
                         keep_latest_n=5,
                         keep_every_n_hours=None,
                         save_async=False):
   """Builds a LeNet5 training graph and a custom Saver over its variables.

   The saver is configured with an InRange(0, 10) sanity check on the global
   step and an IsFinite check on every variable.

   Args:
     logdir: Directory the saver writes checkpoints to.
     keep_latest_n: How many recent checkpoints to retain.
     keep_every_n_hours: Keep one checkpoint per this many hours, if set.
     save_async: Whether the saver writes checkpoints asynchronously.

   Returns:
     A (graph, saver, increment_global_step_op) tuple.
   """
   tf.random.set_seed(123)
   graph = tf.Graph()
   with graph.as_default():
     task_params = mnist.LeNet5().Task()
     task_params.input = mnist.LeNet5().Train()
     with cluster_factory.ForTestingWorker(mode='sync', job='controller'):
       _ = task_params.Instantiate()
     global_step = py_utils.GetOrCreateGlobalStepVar()
     increment_op = global_step.assign_add(1)
     all_vars = tf.all_variables()
     checks = [([global_step], saver.InRange(0, 10))]
     checks.extend(([v], saver.IsFinite()) for v in all_vars)
     ckpt_saver = saver.Saver(
         logdir,
         all_vars,
         checks,
         keep_latest_n=keep_latest_n,
         keep_every_n_hours=keep_every_n_hours,
         async_save=save_async)
   return graph, ckpt_saver, increment_op
Ejemplo n.º 2
0
 def testSingleCheckpoint(self):
     """A single Save() with no sanity checks completes without error."""
     ckpt_dir = tempfile.mkdtemp()
     graph = tf.Graph()
     with graph.as_default():
         _ = py_utils.GetOrCreateGlobalStepVar()
         ckpt_saver = saver.Saver(
             ckpt_dir, tf.all_variables(), [], keep_latest_n=1)
     with self.session(graph=graph) as sess:
         sess.run(tf.global_variables_initializer())
         _ = ckpt_saver.Save(sess)
Ejemplo n.º 3
0
    def testFProp(self):
        """Forward prop yields the golden loss and the expected variable set."""
        with self.session(use_gpu=False):
            tf.set_random_seed(93820985)
            params = self._testParams()
            model = params.Instantiate()
            model.FPropDefaultTheta()
            tf.global_variables_initializer().run()
            # Loss must match the recorded golden value for this seed.
            test_utils.CompareToGoldenSingleFloat(self, 4.472597,
                                                  model.loss.eval())

            actual_names = [v.name for v in tf.all_variables()]
            print('all vars \n', '\n'.join(actual_names))
            expected_names = [
                'global_step:0', 'test_mdl/enc/conv_L0/w/var:0',
                'test_mdl/enc/conv_L0/beta/var:0',
                'test_mdl/enc/conv_L0/gamma/var:0',
                'test_mdl/enc/conv_L0/moving_mean/var:0',
                'test_mdl/enc/conv_L0/moving_variance/var:0',
                'test_mdl/enc/conv_L1/w/var:0',
                'test_mdl/enc/conv_L1/beta/var:0',
                'test_mdl/enc/conv_L1/gamma/var:0',
                'test_mdl/enc/conv_L1/moving_mean/var:0',
                'test_mdl/enc/conv_L1/moving_variance/var:0',
                'test_mdl/enc/f_conv_lstm_0/wm/var:0',
                'test_mdl/enc/f_conv_lstm_0/b/var:0',
                'test_mdl/enc/b_conv_lstm_0/wm/var:0',
                'test_mdl/enc/b_conv_lstm_0/b/var:0',
                'test_mdl/enc/conv_lstm_cnn_0/w/var:0',
                'test_mdl/enc/conv_lstm_cnn_0/beta/var:0',
                'test_mdl/enc/conv_lstm_cnn_0/gamma/var:0',
                'test_mdl/enc/conv_lstm_cnn_0/moving_mean/var:0',
                'test_mdl/enc/conv_lstm_cnn_0/moving_variance/var:0',
                'test_mdl/enc/fwd_rnn_L0/wm/var:0',
                'test_mdl/enc/fwd_rnn_L0/b/var:0',
                'test_mdl/enc/bak_rnn_L0/wm/var:0',
                'test_mdl/enc/bak_rnn_L0/b/var:0',
                'test_mdl/enc/proj_L0/w/var:0',
                'test_mdl/enc/proj_L0/beta/var:0',
                'test_mdl/enc/proj_L0/gamma/var:0',
                'test_mdl/enc/proj_L0/moving_mean/var:0',
                'test_mdl/enc/proj_L0/moving_variance/var:0',
                'test_mdl/enc/fwd_rnn_L1/wm/var:0',
                'test_mdl/enc/fwd_rnn_L1/b/var:0',
                'test_mdl/enc/bak_rnn_L1/wm/var:0',
                'test_mdl/enc/bak_rnn_L1/b/var:0',
                'test_mdl/enc/proj_L1/w/var:0',
                'test_mdl/enc/proj_L1/beta/var:0',
                'test_mdl/enc/proj_L1/gamma/var:0',
                'test_mdl/enc/proj_L1/moving_mean/var:0',
                'test_mdl/enc/proj_L1/moving_variance/var:0',
                'test_mdl/enc/fwd_rnn_L2/wm/var:0',
                'test_mdl/enc/fwd_rnn_L2/b/var:0',
                'test_mdl/enc/bak_rnn_L2/wm/var:0',
                'test_mdl/enc/bak_rnn_L2/b/var:0',
                'test_mdl/dec/emb/var_0/var:0',
                'test_mdl/dec/rnn_cell/wm/var:0',
                'test_mdl/dec/rnn_cell/b/var:0',
                'test_mdl/dec/atten/source_var/var:0',
                'test_mdl/dec/atten/query_var/var:0',
                'test_mdl/dec/atten/hidden_var/var:0',
                'test_mdl/dec/softmax/weight_0/var:0',
                'test_mdl/dec/softmax/bias_0/var:0'
            ]
            # Order-insensitive comparison of the two variable-name sets.
            self.assertEqual(sorted(expected_names), sorted(actual_names))
Ejemplo n.º 4
0
    def testBasic(self):
        """Save/Restore round trips, GC counts, and sanity-check aborts."""
        logdir = tempfile.mkdtemp()
        # Create a dummy file that looks like a checkpoint that shouldn't
        # be touched.
        with tf.io.gfile.GFile(logdir + '/ckpt-foo', 'w') as f:
            f.write('contents')

        graph = tf.Graph()
        with graph.as_default():
            task_params = mnist.LeNet5().Task()
            task_params.input = mnist.LeNet5().Train()
            with cluster_factory.ForTestingWorker(mode='sync',
                                                  job='controller'):
                _ = task_params.Instantiate()
            global_step = py_utils.GetOrCreateGlobalStepVar()
            inc_op = global_step.assign_add(1)
            all_vars = tf.all_variables()
            checks = [([global_step], saver.InRange(0, 10))]
            checks.extend(([v], saver.IsFinite()) for v in all_vars)
            sav = saver.Saver(logdir,
                              all_vars,
                              checks,
                              keep_latest_n=5,
                              keep_every_n_hours=1e-9)

        with self.session(graph=graph) as sess:
            # Creates a few checkpoints, bumping global_step before each.
            sess.run(tf.global_variables_initializer())
            for _ in range(10):
                sess.run(inc_op)
                _ = sav.Save(sess)

            # Restore to the latest.
            sess.run(tf.global_variables_initializer())
            _ = sav.Restore(sess)

            # Restore to a specific checkpoint.
            sess.run(tf.global_variables_initializer())
            _ = sav.Restore(sess, 6)

            # Increments global_step out of range, Save() fails.
            for _ in range(5):
                sess.run(inc_op)
            with self.assertRaises(tf.errors.AbortedError):
                _ = sav.Save(sess)

        filenames = tf.io.gfile.glob('{}/*'.format(logdir))
        filenames = [x[len(logdir) + 1:] for x in filenames]
        print('\n'.join(filenames))
        self.assertIn('checkpoint', filenames)

        # A .meta for each checkpoint.
        meta_files = [name for name in filenames if name.endswith('.meta')]
        self.assertEqual(len(meta_files), 6)

        # 1 for checkpoint. 3 files per checkpoint. 5 good checkpoints, 1 bad.
        # 1 extra file contains the error message, and 1 dummy file
        self.assertEqual(len(filenames), 1 + (5 + 1) * 3 + 1 + 1)
Ejemplo n.º 5
0
    def __init__(self,
                 logdir,
                 train_params,
                 variables_to_restore_dict=None,
                 async_save=False):
        """Create a tf.train.Saver or a custom_saver.Saver.

        Args:
          logdir: The directory path to save checkpoints to.
          train_params: Training parameters.
          variables_to_restore_dict: A dictionary mapping names to Saveables.
            Typically, used in evaluation for substituting exponential moving
            average weights.  If this is set, then tf.train.Saver is used.
          async_save: Save asynchronously. Only works with custom saver.
        """
        self._logdir = logdir
        self._save_path = os.path.join(self._logdir, 'ckpt')
        # A restore dict forces the stock tf.train.Saver even when the
        # custom-saver flag is on.
        self._use_custom_saver = (FLAGS.use_custom_saver
                                  and not variables_to_restore_dict)
        if async_save and not self._use_custom_saver:
            tf.logging.warning(
                'Asynchronous saving only works with custom saver.')

        self._keep_latest_n = train_params.save_max_to_keep
        self._keep_every_n_hours = train_params.save_keep_checkpoint_every_n_hours
        self._max_steps = train_params.max_steps
        self._tpu_steps_per_loop = train_params.tpu_steps_per_loop

        if self._use_custom_saver:
            tf.logging.info('Instantiating custom Saver')
            global_step = py_utils.GetOrCreateGlobalStepVar()
            self._var_list = tf.all_variables()

            # Bound the global step only when both limits are configured.
            if self._max_steps and self._tpu_steps_per_loop:
                upper = self._max_steps + self._tpu_steps_per_loop
                sanity_checks = [([global_step],
                                  custom_saver.InRange(0, upper))]
            else:
                sanity_checks = []

            if train_params.checkpoint_finite_check:
                sanity_checks.extend(
                    ([v], custom_saver.IsFinite()) for v in self._var_list)

            self._saver = custom_saver.Saver(
                logdir,
                variables=self._var_list,
                sanity_checks=sanity_checks,
                keep_latest_n=self._keep_latest_n,
                keep_every_n_hours=self._keep_every_n_hours,
                async_save=async_save)
        else:
            tf.logging.info('Instantiating tf.train.Saver')
            self._saver = tf.train.Saver(
                variables_to_restore_dict,
                sharded=True,
                max_to_keep=self._keep_latest_n,
                keep_checkpoint_every_n_hours=self._keep_every_n_hours,
                pad_step_number=True,  # %08d
                write_version=tf.train.SaverDef.V2)
            self._var_list = self._saver._var_list  # pylint: disable=protected-access