Example #1
def main(_):
  params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
  if FLAGS.tfds_train_examples > 0:
    if not params.train_pattern.startswith("tfds:"):
      raise ValueError("expect tfds type dataset.")
    params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples
  estimator = estimator_utils.create_estimator(
      FLAGS.master,
      FLAGS.model_dir,
      FLAGS.use_tpu,
      FLAGS.iterations_per_loop,
      FLAGS.num_shards,
      params,
      train_init_checkpoint=FLAGS.train_init_checkpoint,
      train_warmup_steps=FLAGS.train_warmup_steps,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max)

  # Split training into sessions, a workaround for yaqs/5313417304080384:
  # the TensorFlow Estimator doesn't respect save_checkpoints_steps when
  # running in a distributed environment.
  if FLAGS.train_steps_overrides:
    train_steps_list = [
        int(s) for s in FLAGS.train_steps_overrides.split(",") if int(s) > 0
    ]
  else:
    train_steps_list = [params.train_steps]
  for train_steps in train_steps_list:
    estimator.train(
        input_fn=infeed.get_input_fn(
            params.parser,
            params.train_pattern,
            tf.estimator.ModeKeys.TRAIN,
            parallelism=FLAGS.train_infeed_parallelism),
        max_steps=train_steps)
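
The FLAGS referenced in this snippet are declared elsewhere in the training script. A minimal sketch of what those absl flag definitions might look like (the flag names are taken from the code above; defaults and help strings are assumptions):

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("params", None, "Name of a registered params set.")
flags.DEFINE_string("param_overrides", "", "Comma-separated key=value overrides.")
flags.DEFINE_string("master", "", "Address of the TensorFlow master.")
flags.DEFINE_string("model_dir", None, "Directory for checkpoints and summaries.")
flags.DEFINE_bool("use_tpu", False, "Whether to run on TPU.")
flags.DEFINE_integer("iterations_per_loop", 1000, "TPU iterations per loop.")
flags.DEFINE_integer("num_shards", 1, "Number of TPU shards.")
flags.DEFINE_integer("tfds_train_examples", -1,
                     "If > 0, take only this many TFDS training examples.")
flags.DEFINE_string("train_init_checkpoint", None,
                    "Optional checkpoint to initialize training from.")
flags.DEFINE_integer("train_warmup_steps", 10000, "Number of warmup steps.")
flags.DEFINE_integer("save_checkpoints_steps", 1000, "Checkpoint frequency in steps.")
flags.DEFINE_integer("keep_checkpoint_max", 5, "Maximum number of checkpoints to keep.")
flags.DEFINE_string("train_steps_overrides", "",
                    "Comma-separated max_steps values used to split training into sessions.")
flags.DEFINE_integer("train_infeed_parallelism", 32, "Parallelism of the training infeed.")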
Example #2
def run_train(self, params, param_overrides, max_steps=2):
    if "batch_size=" not in param_overrides:
        param_overrides = "batch_size=2," + param_overrides
    model_params = registry.get_params(params)(param_overrides)
    model_dir = self.create_tempdir().full_path
    estimator = estimator_utils.create_estimator("", model_dir, True, 1000,
                                                 1, model_params)
    input_fn = infeed.get_input_fn(model_params.parser,
                                   model_params.train_pattern,
                                   tf.estimator.ModeKeys.PREDICT,
                                   parallelism=8)
    estimator.train(input_fn=input_fn, max_steps=max_steps)
    # Training again with the same max_steps resumes from the saved checkpoint
    # and stops immediately, since global_step already equals max_steps.
    estimator.train(input_fn=input_fn, max_steps=max_steps)
    eval_input_fn = infeed.get_input_fn(model_params.parser,
                                        model_params.dev_pattern,
                                        tf.estimator.ModeKeys.EVAL)
    estimator.evaluate(input_fn=eval_input_fn, steps=1, name="eval")
Example #3
def test_supervised_parser(self, config):

    def parser_fn(mode):
        # Max input length 30 and max target length 10, reflected in the
        # shapes asserted below.
        return parsers.supervised_strings_parser(_SUBWORDS, "sentencepiece", 30,
                                                 10, mode)

    data = infeed.get_input_fn(parser_fn, config, tf.estimator.ModeKeys.TRAIN)({
        "batch_size": 4
    })
    d = next(iter(data))
    self.assertEqual(d["inputs"].shape, [4, 30])
    self.assertEqual(d["targets"].shape, [4, 10])
Example #4
def run_eval(self, params, param_overrides, max_steps=2):
    if "batch_size=" not in param_overrides:
        param_overrides = "batch_size=2," + param_overrides
    model_params = registry.get_params(params)(param_overrides)
    model_dir = self.create_tempdir().full_path
    estimator = estimator_utils.create_estimator("", model_dir, True, 1000,
                                                 1, model_params)
    input_fn = infeed.get_input_fn(model_params.parser,
                                   model_params.test_pattern,
                                   tf.estimator.ModeKeys.PREDICT,
                                   parallelism=8)
    predictions = estimator.predict(input_fn=input_fn)
    predictions = itertools.islice(predictions, max_steps)
    model_params.eval(predictions, model_dir, 0, "", True)
Example #5
def main(_):
    # Disable bfloat16 for evaluation before building the params.
    param_overrides = FLAGS.param_overrides or ""
    param_overrides = param_overrides.replace("use_bfloat16=true",
                                              "use_bfloat16=false")

    params = registry.get_params(FLAGS.params)(param_overrides)
    estimator = estimator_utils.create_estimator(FLAGS.master, FLAGS.model_dir,
                                                 FLAGS.use_tpu,
                                                 FLAGS.iterations_per_loop,
                                                 FLAGS.num_shards, params)

    for _ in contrib_training.checkpoints_iterator(FLAGS.model_dir,
                                                   min_interval_secs=60):
        global_step = estimator.get_variable_value("global_step")
        tf.logging.info("Evaluating at global step %d", global_step)

        input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                       tf.estimator.ModeKeys.PREDICT)
        predictions = estimator.predict(input_fn=input_fn)
        if params.eval_max_predictions > 0:
            eval_max_predictions = params.eval_max_predictions
            predictions = itertools.islice(predictions, eval_max_predictions)
        else:
            eval_max_predictions = None
        params.eval(predictions, FLAGS.model_dir, global_step,
                    "eval_decode_dev", FLAGS.enable_logging)

        # In eval, the topology is 1x1, so the total batch size is the single-core batch size.
        if eval_max_predictions:
            eval_steps = max(
                1,
                eval_max_predictions // params.batch_size // FLAGS.num_shards)
        else:
            eval_steps = None
            if FLAGS.use_tpu:
                raise ValueError(
                    "eval_max_predictions must be set when running on TPU, "
                    "since evaluation with steps=None is not supported.")

        # Token-based metrics (e.g. perplexity, accuracy) calculated on the same set used to train.
        estimator.evaluate(input_fn=infeed.get_input_fn(
            params.parser, params.train_pattern, tf.estimator.ModeKeys.EVAL),
                           steps=eval_steps,
                           name="train")

        # Token-based metrics calculated on the dev set.
        estimator.evaluate(input_fn=infeed.get_input_fn(
            params.parser, params.dev_pattern, tf.estimator.ModeKeys.EVAL),
                           steps=eval_steps,
                           name="dev")

        if global_step >= params.train_steps:
            break

    # Run a final eval on the full test and dev sets.
    input_fn = infeed.get_input_fn(params.parser, params.test_pattern,
                                   tf.estimator.ModeKeys.PREDICT)
    predictions = estimator.predict(input_fn=input_fn)
    params.eval(predictions, FLAGS.model_dir, global_step,
                "eval_decode_final_test", FLAGS.enable_logging)
    input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                   tf.estimator.ModeKeys.PREDICT)
    predictions = estimator.predict(input_fn=input_fn)
    params.eval(predictions, FLAGS.model_dir, global_step,
                "eval_decode_final_dev", FLAGS.enable_logging)
Example #6
def main(_):
    params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
    if FLAGS.tfds_train_examples > 0:
        if not params.train_pattern.startswith("tfds:"):
            raise ValueError("expect tfds type dataset.")
        params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples

    logging.warning("Flag 1: Creating Estimator")
    estimator = estimator_utils.create_estimator(
        FLAGS.master,
        FLAGS.model_dir,
        FLAGS.use_tpu,
        FLAGS.iterations_per_loop,
        FLAGS.num_shards,
        params,
        train_init_checkpoint=FLAGS.train_init_checkpoint,
        train_warmup_steps=FLAGS.train_warmup_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max)

    # Split training into sessions, a workaround for yaqs/5313417304080384:
    # the TensorFlow Estimator doesn't respect save_checkpoints_steps when
    # running in a distributed environment.
    if FLAGS.train_steps_overrides:
        train_steps_list = [
            int(s) for s in FLAGS.train_steps_overrides.split(",")
            if int(s) > 0
        ]
    else:
        train_steps_list = [params.train_steps]

    logging.warning("Flag 2: Training the Estimator")
    if FLAGS.eval_during_training:
        # Evaluation-during-training hooks: periodically compute the NLL on the dev set.
        input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                       tf.estimator.ModeKeys.EVAL)
        evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
            estimator,
            input_fn,
            steps=100,
            hooks=None,
            name="evaluate_dev",
            every_n_iter=1000)
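        # The early-stopping hook below reads the metrics written by the
        # InMemoryEvaluatorHook above; with name="evaluate_dev" they land under
        # <model_dir>/eval_evaluate_dev, which matches eval_dir here.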
        early_stopping = tf.estimator.experimental.stop_if_no_decrease_hook(
            estimator,
            metric_name='loss',
            eval_dir=estimator.eval_dir() + "_evaluate_dev",
            max_steps_without_decrease=3000,
            min_steps=100,
            run_every_secs=None,
            run_every_steps=1000)

        # Train the estimator with evaluation hooks
        for train_steps in train_steps_list:
            estimator.train(input_fn=infeed.get_input_fn(
                params.parser,
                params.train_pattern,
                tf.estimator.ModeKeys.TRAIN,
                parallelism=FLAGS.train_infeed_parallelism),
                            max_steps=train_steps,
                            hooks=[evaluator, early_stopping])

    else:
        # Train the estimator with no evaluation hooks
        for train_steps in train_steps_list:
            estimator.train(input_fn=infeed.get_input_fn(
                params.parser,
                params.train_pattern,
                tf.estimator.ModeKeys.TRAIN,
                parallelism=FLAGS.train_infeed_parallelism),
                            max_steps=train_steps)
Example #7
def main(_):
    if not FLAGS.wait and not tf.compat.v1.train.checkpoint_exists(
            FLAGS.model_dir):
        raise ValueError(("Checkpoints %s doesn't exist " % FLAGS.model_dir,
                          "and evaluation doesn't wait."))

    while True:
        if tf.compat.v1.train.checkpoint_exists(FLAGS.model_dir):

            # If a checkpoint path was provided instead of a directory, use its parent directory as the eval dir.
            if tf.io.gfile.isdir(FLAGS.model_dir):
                eval_dir = FLAGS.model_dir
                if FLAGS.best:
                    checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
                    logging.info("Use best checkpoint id: %d", checkpoint_id)
                    checkpoint_path = os.path.join(
                        FLAGS.model_dir, "model.ckpt-%d" % checkpoint_id)
                else:
                    checkpoint_path = None
            else:
                eval_dir = os.path.dirname(FLAGS.model_dir)
                checkpoint_path = FLAGS.model_dir
                if FLAGS.best:
                    raise ValueError("When evaluating the best checkpoint, "
                                     "a model dir should be provided "
                                     "instead of a specified checkpoint.")

            params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
            if FLAGS.evaluate_test:
                pattern = params.test_pattern
                logging.warning(
                    "Evaluating on test set. "
                    "This should be only used for final number report.")
            else:
                pattern = params.dev_pattern
            input_fn = infeed.get_input_fn(params.parser, pattern,
                                           tf.estimator.ModeKeys.PREDICT)
            estimator = estimator_utils.create_estimator(
                FLAGS.master, eval_dir, FLAGS.use_tpu,
                FLAGS.iterations_per_loop, FLAGS.num_shards, params)
            if checkpoint_path:
                global_step = int(checkpoint_path.split("-")[-1])
            else:
                global_step = estimator.get_variable_value("global_step")

            predictions = estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path)
            if not FLAGS.full:
                predictions = itertools.islice(predictions,
                                               params.eval_max_predictions)

            eval_tag = FLAGS.eval_tag
            if FLAGS.best:
                eval_tag += ".best"
            if FLAGS.evaluate_test:
                eval_tag += ".test"
            else:
                eval_tag += ".dev"
            if FLAGS.full:
                eval_tag += ".full"

            params.eval(predictions, eval_dir, global_step, eval_tag,
                        FLAGS.enable_logging)

            break
        time.sleep(10)
Example #8
def main(_):
    if not FLAGS.wait and not tf.train.checkpoint_exists(FLAGS.model_dir):
        raise ValueError(("Checkpoints %s doesn't exist " % FLAGS.model_dir,
                          "and evaluation doesn't wait."))

    while True:
        if tf.train.checkpoint_exists(FLAGS.model_dir):
            logging.warning("Flag 1: Loading model checkpoint from {}".format(
                FLAGS.model_dir))
            # If a checkpoint path was provided instead of a directory, use its parent directory as the eval dir.
            if tf.io.gfile.isdir(FLAGS.model_dir):
                eval_dir = FLAGS.model_dir
                if FLAGS.best:
                    checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
                    logging.info("Use best checkpoint id: %d", checkpoint_id)
                    checkpoint_path = os.path.join(
                        FLAGS.model_dir, "model.ckpt-%d" % checkpoint_id)
                else:
                    checkpoint_path = None
            else:
                eval_dir = os.path.dirname(FLAGS.model_dir)
                checkpoint_path = FLAGS.model_dir
                if FLAGS.best:
                    raise ValueError("When evaluating the best checkpoint, "
                                     "a model dir should be provided "
                                     "instead of a specified checkpoint.")

            params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
            logging.warning("Flag 2: These are the params: {}".format(params))
            if FLAGS.evaluate_test:
                pattern = params.test_pattern
                logging.warning(
                    "Evaluating on test set. "
                    "This should be only used for final number report.")
            else:
                pattern = params.dev_pattern
            logging.warning("Flag 3: Getting the input function...")
            input_fn = infeed.get_input_fn(params.parser, pattern,
                                           tf.estimator.ModeKeys.PREDICT)
            logging.warning("Flag 4: Creating the estimator...")
            estimator = estimator_utils.create_estimator(
                FLAGS.master, eval_dir, FLAGS.use_tpu,
                FLAGS.iterations_per_loop, FLAGS.num_shards, params)
            if checkpoint_path:
                global_step = int(checkpoint_path.split("-")[-1])
            else:
                global_step = estimator.get_variable_value("global_step")
            logging.warning(
                "Flag 5: Here we define the predictions function to be passed to the estimator..."
            )
            predictions = estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path)
            if not FLAGS.full:
                predictions = itertools.islice(predictions,
                                               params.eval_max_predictions)

            eval_tag = FLAGS.eval_tag
            if FLAGS.best:
                eval_tag += ".best"
            if FLAGS.evaluate_test:
                eval_tag += ".test"
            else:
                eval_tag += ".dev"
            if FLAGS.full:
                eval_tag += ".full"
            logging.warning(
                "Flag 6: Entering predictions -> saving INPUT/TARGET/PREDS to {}"
                .format(FLAGS.model_dir))
            params.eval(predictions, eval_dir, global_step, eval_tag,
                        FLAGS.enable_logging)
            logging.warning(
                "Flag 7: Evaluations completed -> saving text metrics to {}".
                format(FLAGS.model_dir))
            break
        time.sleep(10)