def main(_):
  """Creates an estimator from command-line flags and runs training.

  The model configuration is looked up in the registry by `FLAGS.params` and
  customised with `FLAGS.param_overrides`. Training runs once per entry in
  `FLAGS.train_steps_overrides` (comma-separated, positive values only), or
  once up to `params.train_steps` when that flag is empty.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If `FLAGS.tfds_train_examples` is positive but the train
      pattern does not start with "tfds:".
  """
  model_params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)

  # A positive --tfds_train_examples appends a "-take_N" suffix to the pattern
  # (presumably limiting the number of examples read; confirm against infeed).
  if FLAGS.tfds_train_examples > 0:
    if not model_params.train_pattern.startswith("tfds:"):
      raise ValueError("expect tfds type dataset.")
    model_params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples

  estimator = estimator_utils.create_estimator(
      FLAGS.master,
      FLAGS.model_dir,
      FLAGS.use_tpu,
      FLAGS.iterations_per_loop,
      FLAGS.num_shards,
      model_params,
      train_init_checkpoint=FLAGS.train_init_checkpoint,
      train_warmup_steps=FLAGS.train_warmup_steps,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max)

  # Split training into sessions as a workaround for yaqs/5313417304080384:
  # the TensorFlow estimator doesn't respect save_checkpoints_steps when
  # running in a distributed environment.
  if FLAGS.train_steps_overrides:
    step_targets = []
    for token in FLAGS.train_steps_overrides.split(","):
      if int(token) > 0:
        step_targets.append(int(token))
  else:
    step_targets = [model_params.train_steps]

  for step_target in step_targets:
    train_input_fn = infeed.get_input_fn(
        model_params.parser,
        model_params.train_pattern,
        tf.estimator.ModeKeys.TRAIN,
        parallelism=FLAGS.train_infeed_parallelism)
    estimator.train(input_fn=train_input_fn, max_steps=step_target)
def run_train(self, params, param_overrides, max_steps=2):
  """Smoke-tests training and evaluation for a registered params config.

  Args:
    params: Name of the params entry to look up in the registry.
    param_overrides: Override string; "batch_size=2," is prepended when the
      string does not already contain "batch_size=".
    max_steps: `max_steps` value passed to each `estimator.train` call.
  """
  overrides = param_overrides
  if "batch_size=" not in overrides:
    overrides = "batch_size=2," + overrides
  hparams = registry.get_params(params)(overrides)

  workdir = self.create_tempdir().full_path
  estimator = estimator_utils.create_estimator(
      "", workdir, True, 1000, 1, hparams)

  # NOTE(review): the training input_fn is built with ModeKeys.PREDICT rather
  # than ModeKeys.TRAIN; kept exactly as in the original - confirm intent.
  train_input_fn = infeed.get_input_fn(
      hparams.parser,
      hparams.train_pattern,
      tf.estimator.ModeKeys.PREDICT,
      parallelism=8)

  # train() is invoked twice with the same max_steps, as in the original.
  for _ in range(2):
    estimator.train(input_fn=train_input_fn, max_steps=max_steps)

  dev_input_fn = infeed.get_input_fn(
      hparams.parser, hparams.dev_pattern, tf.estimator.ModeKeys.EVAL)
  estimator.evaluate(input_fn=dev_input_fn, steps=1, name="eval")
def test_supervised_parser(self, config):
  """Checks the batch shapes yielded for the supervised strings parser.

  Args:
    config: Input pattern forwarded to `infeed.get_input_fn`.
  """

  def make_parser(mode):
    return parsers.supervised_strings_parser(
        _SUBWORDS, "sentencepiece", 30, 10, mode)

  input_fn = infeed.get_input_fn(
      make_parser, config, tf.estimator.ModeKeys.TRAIN)
  dataset = input_fn({"batch_size": 4})
  first_batch = next(iter(dataset))

  # Expected shapes are [batch_size, 30] for inputs and [batch_size, 10] for
  # targets, matching the 30 and 10 passed to the parser above.
  self.assertEqual(first_batch["inputs"].shape, [4, 30])
  self.assertEqual(first_batch["targets"].shape, [4, 10])
def run_eval(self, params, param_overrides, max_steps=2):
  """Smoke-tests prediction and the params' eval function.

  Args:
    params: Name of the params entry to look up in the registry.
    param_overrides: Override string; "batch_size=2," is prepended when the
      string does not already contain "batch_size=".
    max_steps: Maximum number of predictions consumed from the estimator.
  """
  overrides = param_overrides
  if "batch_size=" not in overrides:
    overrides = "batch_size=2," + overrides
  hparams = registry.get_params(params)(overrides)

  workdir = self.create_tempdir().full_path
  estimator = estimator_utils.create_estimator(
      "", workdir, True, 1000, 1, hparams)

  test_input_fn = infeed.get_input_fn(
      hparams.parser,
      hparams.test_pattern,
      tf.estimator.ModeKeys.PREDICT,
      parallelism=8)

  # Only the first `max_steps` predictions are pulled from the lazy iterator.
  limited_predictions = itertools.islice(
      estimator.predict(input_fn=test_input_fn), max_steps)
  hparams.eval(limited_predictions, workdir, 0, "", True)
def main(_):
  """Evaluates every new checkpoint in the model dir, then runs a final eval.

  For each checkpoint yielded by the checkpoints iterator this decodes the dev
  set (optionally truncated to `params.eval_max_predictions`), then computes
  estimator metrics on the train and dev patterns. The loop stops once the
  global step reaches `params.train_steps`, after which the test and dev sets
  are decoded without truncation.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If running on TPU while `params.eval_max_predictions` is not
      positive.
  """
  # Force bfloat16 off for evaluation. The sanitised override string (not the
  # raw flag) must be the one handed to the params factory.
  param_overrides = FLAGS.param_overrides or ""
  param_overrides = param_overrides.replace("use_bfloat16=true",
                                            "use_bfloat16=false")
  params = registry.get_params(FLAGS.params)(param_overrides)
  estimator = estimator_utils.create_estimator(
      FLAGS.master, FLAGS.model_dir, FLAGS.use_tpu, FLAGS.iterations_per_loop,
      FLAGS.num_shards, params)

  # The prediction cap and the derived step count do not change between
  # checkpoints, so compute them once and reject an invalid TPU configuration
  # before waiting for checkpoints or running a full decode.
  if params.eval_max_predictions > 0:
    eval_max_predictions = params.eval_max_predictions
    # In eval, topology is 1x1, total batch size is single core batch size.
    eval_steps = max(
        1, eval_max_predictions // params.batch_size // FLAGS.num_shards)
  else:
    if FLAGS.use_tpu:
      raise ValueError(
          "The parameter eval_max_predictions has to be defined on TPU.")
    eval_max_predictions = None
    eval_steps = None

  for _ in contrib_training.checkpoints_iterator(
      FLAGS.model_dir, min_interval_secs=60):
    global_step = estimator.get_variable_value("global_step")
    tf.logging.info("Evaluating at global step %d", global_step)

    # Decode the dev set and hand the predictions to the params' eval function.
    input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                   tf.estimator.ModeKeys.PREDICT)
    predictions = estimator.predict(input_fn=input_fn)
    if eval_max_predictions:
      predictions = itertools.islice(predictions, eval_max_predictions)
    params.eval(predictions, FLAGS.model_dir, global_step, "eval_decode_dev",
                FLAGS.enable_logging)

    # Token-based metrics calculated on the same set used to train.
    estimator.evaluate(
        input_fn=infeed.get_input_fn(params.parser, params.train_pattern,
                                     tf.estimator.ModeKeys.EVAL),
        steps=eval_steps,
        name="train")

    # Token-based metrics (e.g. perplexity, accuracy) calculated on the dev
    # set.
    estimator.evaluate(
        input_fn=infeed.get_input_fn(params.parser, params.dev_pattern,
                                     tf.estimator.ModeKeys.EVAL),
        steps=eval_steps,
        name="dev")

    if global_step >= params.train_steps:
      break

  # Run a final eval on entire dev and test sets.
  input_fn = infeed.get_input_fn(params.parser, params.test_pattern,
                                 tf.estimator.ModeKeys.PREDICT)
  predictions = estimator.predict(input_fn=input_fn)
  params.eval(predictions, FLAGS.model_dir, global_step,
              "eval_decode_final_test", FLAGS.enable_logging)

  input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                 tf.estimator.ModeKeys.PREDICT)
  predictions = estimator.predict(input_fn=input_fn)
  params.eval(predictions, FLAGS.model_dir, global_step,
              "eval_decode_final_dev", FLAGS.enable_logging)
def main(_):
  """Creates an estimator from flags and trains it, optionally with eval hooks.

  When `FLAGS.eval_during_training` is set, an in-memory evaluator over the dev
  pattern and a stop-if-no-decrease hook on the "loss" metric are attached to
  every training call; otherwise training runs without extra hooks.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If `FLAGS.tfds_train_examples` is positive but the train
      pattern does not start with "tfds:".
  """
  model_params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)

  # A positive --tfds_train_examples appends a "-take_N" suffix to the pattern.
  if FLAGS.tfds_train_examples > 0:
    if not model_params.train_pattern.startswith("tfds:"):
      raise ValueError("expect tfds type dataset.")
    model_params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples

  logging.warning("Flag 1: Creating Estimator")
  estimator = estimator_utils.create_estimator(
      FLAGS.master,
      FLAGS.model_dir,
      FLAGS.use_tpu,
      FLAGS.iterations_per_loop,
      FLAGS.num_shards,
      model_params,
      train_init_checkpoint=FLAGS.train_init_checkpoint,
      train_warmup_steps=FLAGS.train_warmup_steps,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max)

  # Split training into sessions as a workaround for yaqs/5313417304080384:
  # the TensorFlow estimator doesn't respect save_checkpoints_steps when
  # running in a distributed environment.
  if FLAGS.train_steps_overrides:
    step_targets = []
    for token in FLAGS.train_steps_overrides.split(","):
      if int(token) > 0:
        step_targets.append(int(token))
  else:
    step_targets = [model_params.train_steps]

  logging.warning("Flag 2: Training the Estimator")

  # Extra keyword arguments for estimator.train; stays empty (so `hooks` is
  # not passed at all) unless evaluation during training is requested.
  extra_train_kwargs = {}
  if FLAGS.eval_during_training:
    # Evaluation-during-training hook over the dev pattern.
    dev_input_fn = infeed.get_input_fn(
        model_params.parser, model_params.dev_pattern,
        tf.estimator.ModeKeys.EVAL)
    evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
        estimator,
        dev_input_fn,
        steps=100,
        hooks=None,
        name="evaluate_dev",
        every_n_iter=1000)
    # Early stopping reads the metrics written under the evaluator's name.
    early_stopping = tf.estimator.experimental.stop_if_no_decrease_hook(
        estimator,
        metric_name="loss",
        eval_dir=estimator.eval_dir() + "_evaluate_dev",
        max_steps_without_decrease=3000,
        min_steps=100,
        run_every_secs=None,
        run_every_steps=1000)
    extra_train_kwargs["hooks"] = [evaluator, early_stopping]

  for step_target in step_targets:
    train_input_fn = infeed.get_input_fn(
        model_params.parser,
        model_params.train_pattern,
        tf.estimator.ModeKeys.TRAIN,
        parallelism=FLAGS.train_infeed_parallelism)
    estimator.train(
        input_fn=train_input_fn, max_steps=step_target, **extra_train_kwargs)
def main(_):
  """Decodes the dev or test set from a checkpoint and runs the params' eval.

  `FLAGS.model_dir` may be a model directory or a specific checkpoint path.
  With `FLAGS.wait` set, the function polls every 10 seconds until a
  checkpoint exists; otherwise a missing checkpoint is an error.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If no checkpoint exists and `FLAGS.wait` is false, or if
      `FLAGS.best` is set while `FLAGS.model_dir` is not a directory.
  """
  if not FLAGS.wait and not tf.compat.v1.train.checkpoint_exists(
      FLAGS.model_dir):
    # Build a single message string; passing a tuple to ValueError would make
    # the error render as a tuple repr instead of a readable sentence.
    raise ValueError(
        "Checkpoints %s doesn't exist and evaluation doesn't wait." %
        FLAGS.model_dir)

  # Poll until a checkpoint shows up.
  while not tf.compat.v1.train.checkpoint_exists(FLAGS.model_dir):
    time.sleep(10)

  # If checkpoint provided instead of dir, convert eval dir to parent dir.
  if tf.io.gfile.isdir(FLAGS.model_dir):
    eval_dir = FLAGS.model_dir
    if FLAGS.best:
      checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
      logging.info("Use best checkpoint id: %d", checkpoint_id)
      checkpoint_path = os.path.join(FLAGS.model_dir,
                                     "model.ckpt-%d" % checkpoint_id)
    else:
      # None is forwarded to estimator.predict as the checkpoint_path.
      checkpoint_path = None
  else:
    eval_dir = os.path.dirname(FLAGS.model_dir)
    checkpoint_path = FLAGS.model_dir
    if FLAGS.best:
      raise ValueError("When evaluating the best checkpoint, "
                       "a model dir should be provided "
                       "instead of a specified checkpoint.")

  params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
  if FLAGS.evaluate_test:
    pattern = params.test_pattern
    logging.warning("Evaluating on test set. "
                    "This should be only used for final number report.")
  else:
    pattern = params.dev_pattern
  input_fn = infeed.get_input_fn(params.parser, pattern,
                                 tf.estimator.ModeKeys.PREDICT)
  estimator = estimator_utils.create_estimator(
      FLAGS.master, eval_dir, FLAGS.use_tpu, FLAGS.iterations_per_loop,
      FLAGS.num_shards, params)

  # An explicit checkpoint encodes its step after the last "-" in its path;
  # otherwise read the step from the estimator.
  if checkpoint_path:
    global_step = int(checkpoint_path.split("-")[-1])
  else:
    global_step = estimator.get_variable_value("global_step")

  predictions = estimator.predict(
      input_fn=input_fn, checkpoint_path=checkpoint_path)
  if not FLAGS.full:
    predictions = itertools.islice(predictions, params.eval_max_predictions)

  # The tag records which checkpoint, split and coverage were evaluated.
  eval_tag = FLAGS.eval_tag
  if FLAGS.best:
    eval_tag += ".best"
  if FLAGS.evaluate_test:
    eval_tag += ".test"
  else:
    eval_tag += ".dev"
  if FLAGS.full:
    eval_tag += ".full"

  params.eval(predictions, eval_dir, global_step, eval_tag,
              FLAGS.enable_logging)
def main(_):
  """Decodes the dev or test set from a checkpoint, logging each stage.

  `FLAGS.model_dir` may be a model directory or a specific checkpoint path.
  With `FLAGS.wait` set, the function polls every 10 seconds until a
  checkpoint exists; otherwise a missing checkpoint is an error. Progress is
  reported through numbered "Flag N" warning-level log lines.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If no checkpoint exists and `FLAGS.wait` is false, or if
      `FLAGS.best` is set while `FLAGS.model_dir` is not a directory.
  """
  if not FLAGS.wait and not tf.train.checkpoint_exists(FLAGS.model_dir):
    # Build a single message string; passing a tuple to ValueError would make
    # the error render as a tuple repr instead of a readable sentence.
    raise ValueError(
        "Checkpoints %s doesn't exist and evaluation doesn't wait." %
        FLAGS.model_dir)

  # Poll until a checkpoint shows up.
  while not tf.train.checkpoint_exists(FLAGS.model_dir):
    time.sleep(10)

  logging.warning("Flag 1: Loading model checkpoint from {}".format(
      FLAGS.model_dir))

  # If checkpoint provided instead of dir, convert eval dir to parent dir.
  if tf.io.gfile.isdir(FLAGS.model_dir):
    eval_dir = FLAGS.model_dir
    if FLAGS.best:
      checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
      logging.info("Use best checkpoint id: %d", checkpoint_id)
      checkpoint_path = os.path.join(FLAGS.model_dir,
                                     "model.ckpt-%d" % checkpoint_id)
    else:
      # None is forwarded to estimator.predict as the checkpoint_path.
      checkpoint_path = None
  else:
    eval_dir = os.path.dirname(FLAGS.model_dir)
    checkpoint_path = FLAGS.model_dir
    if FLAGS.best:
      raise ValueError("When evaluating the best checkpoint, "
                       "a model dir should be provided "
                       "instead of a specified checkpoint.")

  params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
  logging.warning("Flag 2: These are the params: {}".format(params))
  if FLAGS.evaluate_test:
    pattern = params.test_pattern
    logging.warning("Evaluating on test set. "
                    "This should be only used for final number report.")
  else:
    pattern = params.dev_pattern

  logging.warning("Flag 3: Getting the input function...")
  input_fn = infeed.get_input_fn(params.parser, pattern,
                                 tf.estimator.ModeKeys.PREDICT)

  logging.warning("Flag 4: Creating the estimator...")
  estimator = estimator_utils.create_estimator(
      FLAGS.master, eval_dir, FLAGS.use_tpu, FLAGS.iterations_per_loop,
      FLAGS.num_shards, params)

  # An explicit checkpoint encodes its step after the last "-" in its path;
  # otherwise read the step from the estimator.
  if checkpoint_path:
    global_step = int(checkpoint_path.split("-")[-1])
  else:
    global_step = estimator.get_variable_value("global_step")

  logging.warning(
      "Flag 5: Here we define the predictions function to be passed to the estimator..."
  )
  predictions = estimator.predict(
      input_fn=input_fn, checkpoint_path=checkpoint_path)
  if not FLAGS.full:
    predictions = itertools.islice(predictions, params.eval_max_predictions)

  # The tag records which checkpoint, split and coverage were evaluated.
  eval_tag = FLAGS.eval_tag
  if FLAGS.best:
    eval_tag += ".best"
  if FLAGS.evaluate_test:
    eval_tag += ".test"
  else:
    eval_tag += ".dev"
  if FLAGS.full:
    eval_tag += ".full"

  logging.warning(
      "Flag 6: Entering predictions -> saving INPUT/TARGET/PREDS to {}"
      .format(FLAGS.model_dir))
  params.eval(predictions, eval_dir, global_step, eval_tag,
              FLAGS.enable_logging)
  logging.warning(
      "Flag 7: Evaluations completed -> saving text metrics to {}".format(
          FLAGS.model_dir))