def testEagerMultiLearnerCheckpointCompatibility(self):
    """Checks that eager V1, eager V2 and graph-mode checkpoints share keys.

    Builds the 'test.LinearModelParams' model in eager mode and saves it with
    both `EagerCheckpointerV1` and `EagerCheckpointerV2`. It asserts that the V2
    checkpoint holds exactly two more keys than the V1 checkpoint. It then
    switches to graph mode, rebuilds the same model with
    `clear_variable_scope = False` on every learner's optimizer params, and
    saves it with the graph-mode `Checkpointer`. Finally it asserts that the
    graph checkpoint's keys equal the V1 eager checkpoint's keys.

    NOTE(review): `py_utils.SetEagerMode(False)` is called below and never
    reverted inside this test. If that switch is global, later tests in the
    same process run in graph mode. Confirm that this is intended.
    """
    # The first half of the test must run eagerly.
    self.assertTrue(tf.executing_eagerly())
    cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
    mdl = cfg.Instantiate()
    # Presumably a tape is required so the backward pass can compute gradients
    # eagerly; persistent=True allows it to be queried more than once.
    # TODO confirm against ConstructFPropBPropGraph.
    with py_utils.GradientTape(persistent=True):
      mdl.ConstructFPropBPropGraph()

    # Save the same eager model with both eager checkpointer versions.
    eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
    eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
    checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl).Save(gsteps=0)
    checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl).Save(gsteps=0)
    # The two checkpointers use different subdirectory and file-name schemes
    # for step 0 ('ckpt_V1/ckpt-00000000' vs 'ckpt_V2/ckpt-0').
    eager_v1_keys = _GetCheckpointKeys(
        os.path.join(eager_v1_logdir, 'ckpt_V1', 'ckpt-00000000'))
    eager_v2_keys = _GetCheckpointKeys(
        os.path.join(eager_v2_logdir, 'ckpt_V2', 'ckpt-0'))
    # Expecting two more variables in V2 checkpoints:
    # _CHECKPOINTABLE_OBJECT_GRAPH
    # save_counter
    self.assertEqual(len(eager_v1_keys) + 2, len(eager_v2_keys))  # pylint:disable=g-generic-assert

    # Second half: rebuild the model in graph mode and compare its keys.
    py_utils.SetEagerMode(False)
    self.assertFalse(tf.executing_eagerly())
    graph_logdir = os.path.join(self.get_temp_dir(), 'graph')
    os.mkdir(graph_logdir)
    with self.session(graph=tf.Graph()) as sess:
      mdl = cfg.Instantiate()
      # Presumably this keeps optimizer variable names in graph mode aligned
      # with the eager checkpoints' names. TODO confirm.
      for lrn in mdl.GetTask().learners:
        lrn.optimizer.params.clear_variable_scope = False
      mdl.ConstructFPropBPropGraph()
      sess.run(tf.global_variables_initializer())
      checkpointer.Checkpointer(graph_logdir, mdl).Save(sess)
    graph_keys = _GetCheckpointKeys(os.path.join(graph_logdir, 'ckpt'))
    # Graph-mode keys must match the V1 eager keys exactly.
    self.assertEqual(eager_v1_keys, graph_keys)
from lingvo.core import py_utils
from lingvo.core import test_utils
from lingvo.core import var_tmp_wrappers


class VarTmpWrappersTest(test_utils.TestCase):
    """Graph-mode tests for the wrappers in `var_tmp_wrappers`."""

    def testVarWrapperTrackAssign(self):
        """Each assign/assign_add/assign_sub op is recorded by the wrapper."""
        with tf.Graph().as_default():
            base_var = tf.get_variable('v0', shape=[8, 16], dtype=tf.float32)
            tracked = var_tmp_wrappers.VarWrapperTrackAssign(base_var)
            delta = tf.ones_like(tracked)
            # List elements are evaluated left to right, so the ops are
            # created in the order assign, assign_add, assign_sub.
            issued_ops = [
                tracked.assign(delta),
                tracked.assign_add(delta),
                tracked.assign_sub(delta),
            ]
            self.assertSameElements(tracked.previous_assigns(), issued_ops)

    def testStackedVarWrapperWithManualSharding(self):
        """Mutators run, and ones_like(wrapper) has shape [16] for an [8, 16] var."""
        with tf.Graph().as_default():
            base_var = tf.get_variable('v2', shape=[8, 16], dtype=tf.float32)
            stacked = var_tmp_wrappers.StackedVarWrapperWithManualSharding(
                base_var)
            delta = tf.ones_like(stacked)
            # Look up and call each mutator in turn, in the original order.
            for method_name in ('assign', 'assign_add', 'assign_sub'):
                getattr(stacked, method_name)(delta)
            self.assertEqual(delta.shape, [16])


if __name__ == '__main__':
    # Disable eager execution before handing control to the TF test runner;
    # the tests in this module create their ops inside explicit
    # `tf.Graph().as_default()` scopes.
    py_utils.SetEagerMode(False)
    tf.test.main()
Example #3
0
        self.MaybeConfigRunDistributed()
        self.MaybeConfigCloudTpu()
        self.MaybeLaunchTensorFlow()

        if FLAGS.job.startswith('evaler_once_'):
            # E.g., trainer --model=foo.bar.Model --logdir=...
            # --run_locally=cpu --mode=sync --job=evaler_once_test@65200
            self.RunEvalerOnce()
            return

        self.StartRunners(
            self.CreateRunners(FLAGS.job.split(','), FLAGS.logdir))


def main(unused_argv):
    """Entry point handed to `tf.app.run`.

    Creates a `RunnerManager` for the model named by --model and starts it.

    Args:
      unused_argv: Remaining command-line arguments; ignored.
    """
    manager = RunnerManager(FLAGS.model)
    manager.Start()


if __name__ == '__main__':
    # --model must be supplied on the command line.
    tf.flags.mark_flag_as_required('model')
    # Parse argv early with known_only=True, so flags not yet defined are not
    # errors. This lets the FLAGS values below be read before tf.app.run().
    FLAGS(sys.argv, known_only=True)
    if FLAGS.disable_tf2:
        tf.disable_v2_behavior()
    py_utils.SetEagerMode(FLAGS.use_eager)
    tf.config.run_functions_eagerly(FLAGS.run_functions_eagerly)
    if FLAGS.enable_tf_data_debug_mode:
        tf.data.experimental.enable_debug_mode()
    # Import the module(s) needed for FLAGS.model's params. Presumably this
    # can define further flags, which is why the early parse is undone below.
    # TODO confirm.
    model_imports.ImportParams(FLAGS.model)
    # Undo the early parse so that tf.app.run() parses argv afresh.
    FLAGS.unparse_flags()
    tf.app.run(main)
Example #4
0
      return vars2_intermediate, vars2_1, grads2_1, grads2_2

    vars2_intermediate, vars2_1, grads2_1, grads2_2 = _Apply2(proj_layer, opt)
    # Unlike Graph mode, grads2_1['w'][0]/grads2_2['w'][0] returned from
    # `tf.function` are variables after updates. As a result we cannot compare
    # them with e.g. `vars1`.

    self.assertAllClose(vars1, vars2)

    self.assertAllClose(grads1_1, grads2_1)
    self.assertAllClose(grads1_2, grads2_2)

    self.assertAllClose(vars1, vars2_intermediate)

    lr = lr()
    self.assertAllClose(
        vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]), vars1_1[0])
    self.assertAllClose(
        vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]), vars2_1[0])

    self.assertAllClose(vars2, vars2_intermediate)
    self.assertAllClose(vars1_1, vars2_1)
    # TODO(jiaweix): Add checks for the event files from tf.summary
    # once we migrate summary_utils to TF2


if __name__ == '__main__':
  # Run this test module with eager execution enabled.
  py_utils.SetEagerMode(True)
  tf.test.main()
Example #5
0
def main(*args, **kwargs):
    """Sets eager mode from --enable_eager_execution, then runs `tf.test.main`.

    argv is parsed early (tolerating unknown flags) only so that the eager
    flag can be read before the test runner starts. The parse is then undone,
    presumably so the runner can parse argv itself.

    Args:
      *args: Forwarded unchanged to `tf.test.main`.
      **kwargs: Forwarded unchanged to `tf.test.main`.
    """
    FLAGS(sys.argv, known_only=True)
    use_eager = FLAGS.enable_eager_execution
    py_utils.SetEagerMode(use_eager)
    FLAGS.unparse_flags()
    tf.test.main(*args, **kwargs)
      return np.array(float(s), dtype=np.float32)

    bucket_fn = lambda x: 1

    # A record processor written in TF graph.
    def _process(source_id, record):
      num, = tf.py_func(str_to_num, [record], [tf.float32])
      num = tf.stack([num, tf.square(num)])
      return py_utils.NestedMap(
          source_id=source_id, record=record, num=num), bucket_fn(num)

    # pylint: disable=protected-access
    len_before = len(generic_input._GENERIC_CACHE_V2)
    _ = generic_input.GenericInput(
        file_pattern='tfrecord:' + tmp,
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process,
        generic_input_v2_key=mock_op_key)

    # pylint: disable=protected-access
    len_after = len(generic_input._GENERIC_CACHE_V2)
    self.assertEqual(len_after, len_before + 1)

if __name__ == '__main__':
  # NOTE(review): SetEagerMode is called with no argument, so this relies on
  # its default value (not visible here). Confirm which execution mode these
  # tests expect.
  py_utils.SetEagerMode()
  tf.test.main()