def test_register_params(self):
    """A params fn registered under a name is returned by get_params and callable."""
    invoked = [False]

    def record_call():
        invoked[0] = True

    registry.register("test_params")(record_call)
    registry.get_params("test_params")()
    self.assertTrue(invoked[0])
def main(_):
    """Training entry point: build params, create the estimator, run training."""
    params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)

    # Optionally cap a TFDS-backed training source to a fixed example count.
    if FLAGS.tfds_train_examples > 0:
        if not params.train_pattern.startswith("tfds:"):
            raise ValueError("expect tfds type dataset.")
        params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples

    estimator = estimator_utils.create_estimator(
        FLAGS.master,
        FLAGS.model_dir,
        FLAGS.use_tpu,
        FLAGS.iterations_per_loop,
        FLAGS.num_shards,
        params,
        train_init_checkpoint=FLAGS.train_init_checkpoint,
        train_warmup_steps=FLAGS.train_warmup_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max)

    # Split training into sessions (workaround for yaqs/5313417304080384):
    # the estimator doesn't respect save_checkpoints_steps when distributed.
    if FLAGS.train_steps_overrides:
        train_steps_list = [
            int(s) for s in FLAGS.train_steps_overrides.split(",") if int(s) > 0
        ]
    else:
        train_steps_list = [params.train_steps]

    for train_steps in train_steps_list:
        train_input_fn = infeed.get_input_fn(
            params.parser,
            params.train_pattern,
            tf.estimator.ModeKeys.TRAIN,
            parallelism=FLAGS.train_infeed_parallelism)
        estimator.train(input_fn=train_input_fn, max_steps=train_steps)
def run_eval(self, params, param_overrides, max_steps=2):
    """Smoke-test decoding: predict a few examples, then run params.eval."""
    # Keep the test fast by forcing a tiny batch size unless one is given.
    if "batch_size=" not in param_overrides:
        param_overrides = "batch_size=2," + param_overrides
    model_params = registry.get_params(params)(param_overrides)
    model_dir = self.create_tempdir().full_path
    estimator = estimator_utils.create_estimator(
        "", model_dir, True, 1000, 1, model_params)
    input_fn = infeed.get_input_fn(
        model_params.parser,
        model_params.test_pattern,
        tf.estimator.ModeKeys.PREDICT,
        parallelism=8)
    # Only take a bounded prefix of the (potentially endless) prediction stream.
    preds = itertools.islice(estimator.predict(input_fn=input_fn), max_steps)
    model_params.eval(preds, model_dir, 0, "", True)
def run_train(self, params, param_overrides, max_steps=2):
    """Smoke-test training: train twice (checkpoint resume) then evaluate once.

    Args:
      params: registered params name.
      param_overrides: comma-separated "key=value" overrides; a small
        batch_size is injected when absent to keep the test fast.
      max_steps: training steps for each estimator.train call.
    """
    if "batch_size=" not in param_overrides:
        param_overrides = "batch_size=2," + param_overrides
    model_params = registry.get_params(params)(param_overrides)
    model_dir = self.create_tempdir().full_path
    estimator = estimator_utils.create_estimator(
        "", model_dir, True, 1000, 1, model_params)
    # BUG FIX: the training input pipeline must be built in TRAIN mode so the
    # parser produces training features/targets; the original used
    # ModeKeys.PREDICT here, unlike every other training path in this project.
    input_fn = infeed.get_input_fn(
        model_params.parser,
        model_params.train_pattern,
        tf.estimator.ModeKeys.TRAIN,
        parallelism=8)
    # Train twice to also exercise restoring from the checkpoint written by
    # the first call.
    estimator.train(input_fn=input_fn, max_steps=max_steps)
    estimator.train(input_fn=input_fn, max_steps=max_steps)
    eval_input_fn = infeed.get_input_fn(
        model_params.parser, model_params.dev_pattern,
        tf.estimator.ModeKeys.EVAL)
    estimator.evaluate(input_fn=eval_input_fn, steps=1, name="eval")
def run(self):
    """Decode a dict dataset with the latest checkpoint and print the text.

    Loads a numpy-pickled dict dataset from self.test_dict_dataset_path
    (assumed to carry "inputs"/"targets" keys compatible with the parser —
    TODO confirm against the file producer), runs prediction with the latest
    checkpoint in self.model_dir, and prints decoded inputs/targets/outputs.

    Cleanup: removed unused example-string locals and commented-out code, and
    folded the redundant two-step checkpoint_path assignment into one call.
    """
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    params = registry.get_params(self.params_transformer)(self.param_overrides)
    parser, shapes = params.parser(mode=tf.estimator.ModeKeys.PREDICT)
    estimator = estimator_utils.create_estimator(
        self.master, self.model_dir, self.use_tpu, self.iterations_per_loop,
        self.num_shards, params)
    encoder = public_parsing_ops.create_text_encoder(
        self.encoder_type, self.vocab_filename)

    def input_function(params):
        # Load a dict saved with np.save(..., allow_pickle=True) and turn it
        # into a padded-batch dataset matching the parser's shapes.
        read_dictionary_data = np.load(
            self.test_dict_dataset_path, allow_pickle='TRUE').item()
        dataset = tf.data.Dataset.from_tensor_slices(
            read_dictionary_data).map(parser)
        dataset = dataset.unbatch()
        dataset = dataset.padded_batch(
            params["batch_size"], padded_shapes=shapes, drop_remainder=True)
        return dataset

    predictions = estimator.predict(
        input_fn=input_function, checkpoint_path=checkpoint_path)
    for i in predictions:
        print(
            "======================================================================================================================================="
        )
        print("inputs: " + text_eval.ids2str(encoder, i['inputs'], None))
        print("targets: " + text_eval.ids2str(encoder, i['targets'], None))
        print("outputs: " + text_eval.ids2str(encoder, i['outputs'], None))
def main(_):
    """Continuous-eval loop: decode and score the dev set at each checkpoint.

    BUG FIX: the bfloat16-disabling override string was computed and then
    discarded — params were built from the raw FLAGS.param_overrides. Params
    are now built from the sanitized local `param_overrides`.
    """
    # Eval does not run with bfloat16; force it off in the overrides.
    param_overrides = FLAGS.param_overrides or ""
    param_overrides = param_overrides.replace("use_bfloat16=true",
                                              "use_bfloat16=false")
    params = registry.get_params(FLAGS.params)(param_overrides)
    estimator = estimator_utils.create_estimator(FLAGS.master, FLAGS.model_dir,
                                                 FLAGS.use_tpu,
                                                 FLAGS.iterations_per_loop,
                                                 FLAGS.num_shards, params)
    for _ in contrib_training.checkpoints_iterator(
        FLAGS.model_dir, min_interval_secs=60):
        global_step = estimator.get_variable_value("global_step")
        tf.logging.info("Evaluating at global step %d", global_step)

        # Text-level (decode) metrics on the dev set.
        input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                       tf.estimator.ModeKeys.PREDICT)
        predictions = estimator.predict(input_fn=input_fn)
        if params.eval_max_predictions > 0:
            eval_max_predictions = params.eval_max_predictions
            predictions = itertools.islice(predictions, eval_max_predictions)
        else:
            eval_max_predictions = None
        params.eval(predictions, FLAGS.model_dir, global_step,
                    "eval_decode_dev", FLAGS.enable_logging)

        # In eval, topology is 1x1, total batch size is single core batch size.
        if eval_max_predictions:
            eval_steps = max(
                1,
                eval_max_predictions // params.batch_size // FLAGS.num_shards)
        else:
            eval_steps = None
            if FLAGS.use_tpu:
                raise ValueError(
                    "The parameter eval_max_predictions has to be defined on TPU."
                )

        # Token-based metrics (e.g. perplexity, accuracy) on the train set.
        estimator.evaluate(
            input_fn=infeed.get_input_fn(params.parser, params.train_pattern,
                                         tf.estimator.ModeKeys.EVAL),
            steps=eval_steps,
            name="train")
        # Token-based metrics on the dev set.
        estimator.evaluate(
            input_fn=infeed.get_input_fn(params.parser, params.dev_pattern,
                                         tf.estimator.ModeKeys.EVAL),
            steps=eval_steps,
            name="dev")
        if global_step >= params.train_steps:
            break

    # Run a final eval on entire dev and test sets.
    input_fn = infeed.get_input_fn(params.parser, params.test_pattern,
                                   tf.estimator.ModeKeys.PREDICT)
    predictions = estimator.predict(input_fn=input_fn)
    params.eval(predictions, FLAGS.model_dir, global_step,
                "eval_decode_final_test", FLAGS.enable_logging)
    input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                   tf.estimator.ModeKeys.PREDICT)
    predictions = estimator.predict(input_fn=input_fn)
    params.eval(predictions, FLAGS.model_dir, global_step,
                "eval_decode_final_dev", FLAGS.enable_logging)
def main(_):
    """Training entry point with optional in-training dev evaluation and
    early stopping."""
    params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)

    # Optionally cap a TFDS-backed training source to a fixed example count.
    if FLAGS.tfds_train_examples > 0:
        if not params.train_pattern.startswith("tfds:"):
            raise ValueError("expect tfds type dataset.")
        params.train_pattern += "-take_%d" % FLAGS.tfds_train_examples

    logging.warning("Flag 1: Creating Estimator")
    estimator = estimator_utils.create_estimator(
        FLAGS.master,
        FLAGS.model_dir,
        FLAGS.use_tpu,
        FLAGS.iterations_per_loop,
        FLAGS.num_shards,
        params,
        train_init_checkpoint=FLAGS.train_init_checkpoint,
        train_warmup_steps=FLAGS.train_warmup_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max)

    # Split training into sessions (workaround for yaqs/5313417304080384):
    # the estimator doesn't respect save_checkpoints_steps when distributed.
    if FLAGS.train_steps_overrides:
        train_steps_list = [
            int(s) for s in FLAGS.train_steps_overrides.split(",") if int(s) > 0
        ]
    else:
        train_steps_list = [params.train_steps]

    logging.warning("Flag 2: Training the Estimator")

    # When requested, periodically evaluate the dev set in-process and stop
    # early once the loss stops decreasing. hooks=None (the Estimator.train
    # default) is passed when in-training evaluation is disabled.
    train_hooks = None
    if FLAGS.eval_during_training:
        dev_input_fn = infeed.get_input_fn(params.parser, params.dev_pattern,
                                           tf.estimator.ModeKeys.EVAL)
        evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(
            estimator,
            dev_input_fn,
            steps=100,
            hooks=None,
            name="evaluate_dev",
            every_n_iter=1000)
        early_stopping = tf.estimator.experimental.stop_if_no_decrease_hook(
            estimator,
            metric_name='loss',
            eval_dir=estimator.eval_dir() + "_evaluate_dev",
            max_steps_without_decrease=3000,
            min_steps=100,
            run_every_secs=None,
            run_every_steps=1000)
        train_hooks = [evaluator, early_stopping]

    for train_steps in train_steps_list:
        estimator.train(
            input_fn=infeed.get_input_fn(
                params.parser,
                params.train_pattern,
                tf.estimator.ModeKeys.TRAIN,
                parallelism=FLAGS.train_infeed_parallelism),
            max_steps=train_steps,
            hooks=train_hooks)
def main(_):
    """Evaluate a checkpoint: decode predictions, then run the text metrics.

    FLAGS.model_dir may be a model directory (optionally with --best to pick
    the best checkpoint id) or a specific checkpoint path. When --wait is set,
    polls every 10 seconds until a checkpoint exists.

    BUG FIX: the no-checkpoint ValueError was constructed with a tuple of two
    strings, producing a garbled message; it is now one formatted string.
    """
    if not FLAGS.wait and not tf.compat.v1.train.checkpoint_exists(
        FLAGS.model_dir):
        raise ValueError("Checkpoints %s doesn't exist "
                         "and evaluation doesn't wait." % FLAGS.model_dir)
    while True:
        if tf.compat.v1.train.checkpoint_exists(FLAGS.model_dir):
            # If checkpoint provided instead of dir, convert eval dir to parent dir.
            if tf.io.gfile.isdir(FLAGS.model_dir):
                eval_dir = FLAGS.model_dir
                if FLAGS.best:
                    checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
                    logging.info("Use best checkpoint id: %d", checkpoint_id)
                    checkpoint_path = os.path.join(
                        FLAGS.model_dir, "model.ckpt-%d" % checkpoint_id)
                else:
                    checkpoint_path = None
            else:
                eval_dir = os.path.dirname(FLAGS.model_dir)
                checkpoint_path = FLAGS.model_dir
                if FLAGS.best:
                    raise ValueError("When evaluating the best checkpoint, "
                                     "a model dir should be provided "
                                     "instead of a specified checkpoint.")
            params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
            if FLAGS.evaluate_test:
                pattern = params.test_pattern
                logging.warning(
                    "Evaluating on test set. "
                    "This should be only used for final number report.")
            else:
                pattern = params.dev_pattern
            input_fn = infeed.get_input_fn(params.parser, pattern,
                                           tf.estimator.ModeKeys.PREDICT)
            estimator = estimator_utils.create_estimator(
                FLAGS.master, eval_dir, FLAGS.use_tpu,
                FLAGS.iterations_per_loop, FLAGS.num_shards, params)
            if checkpoint_path:
                # Checkpoint filenames end with the global step: model.ckpt-<step>.
                global_step = int(checkpoint_path.split("-")[-1])
            else:
                global_step = estimator.get_variable_value("global_step")
            predictions = estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path)
            if not FLAGS.full:
                predictions = itertools.islice(predictions,
                                               params.eval_max_predictions)
            # Tag the eval output with the checkpoint/split selection.
            eval_tag = FLAGS.eval_tag
            if FLAGS.best:
                eval_tag += ".best"
            if FLAGS.evaluate_test:
                eval_tag += ".test"
            else:
                eval_tag += ".dev"
            if FLAGS.full:
                eval_tag += ".full"
            params.eval(predictions, eval_dir, global_step, eval_tag,
                        FLAGS.enable_logging)
            break
        time.sleep(10)
def main(_):
    """Evaluate a checkpoint with verbose "Flag N" progress logging.

    Same flow as the standard eval entry point: resolve the checkpoint (or the
    best one with --best), decode predictions, and run the text metrics;
    optionally polls every 10 seconds when --wait is set.

    BUG FIX: the no-checkpoint ValueError was constructed with a tuple of two
    strings, producing a garbled message; it is now one formatted string.
    """
    if not FLAGS.wait and not tf.train.checkpoint_exists(FLAGS.model_dir):
        raise ValueError("Checkpoints %s doesn't exist "
                         "and evaluation doesn't wait." % FLAGS.model_dir)
    while True:
        if tf.train.checkpoint_exists(FLAGS.model_dir):
            logging.warning("Flag 1: Loading model checkpoint from {}".format(
                FLAGS.model_dir))
            # If checkpoint provided instead of dir, convert eval dir to parent dir.
            if tf.io.gfile.isdir(FLAGS.model_dir):
                eval_dir = FLAGS.model_dir
                if FLAGS.best:
                    checkpoint_id = _get_best_checkpoint_id(FLAGS.model_dir)
                    logging.info("Use best checkpoint id: %d", checkpoint_id)
                    checkpoint_path = os.path.join(
                        FLAGS.model_dir, "model.ckpt-%d" % checkpoint_id)
                else:
                    checkpoint_path = None
            else:
                eval_dir = os.path.dirname(FLAGS.model_dir)
                checkpoint_path = FLAGS.model_dir
                if FLAGS.best:
                    raise ValueError("When evaluating the best checkpoint, "
                                     "a model dir should be provided "
                                     "instead of a specified checkpoint.")
            params = registry.get_params(FLAGS.params)(FLAGS.param_overrides)
            logging.warning("Flag 2: These are the params: {}".format(params))
            if FLAGS.evaluate_test:
                pattern = params.test_pattern
                logging.warning(
                    "Evaluating on test set. "
                    "This should be only used for final number report.")
            else:
                pattern = params.dev_pattern
            logging.warning("Flag 3: Getting the input function...")
            input_fn = infeed.get_input_fn(params.parser, pattern,
                                           tf.estimator.ModeKeys.PREDICT)
            logging.warning("Flag 4: Creating the estimator...")
            estimator = estimator_utils.create_estimator(
                FLAGS.master, eval_dir, FLAGS.use_tpu,
                FLAGS.iterations_per_loop, FLAGS.num_shards, params)
            if checkpoint_path:
                # Checkpoint filenames end with the global step: model.ckpt-<step>.
                global_step = int(checkpoint_path.split("-")[-1])
            else:
                global_step = estimator.get_variable_value("global_step")
            logging.warning(
                "Flag 5: Here we define the predictions function to be passed to the estimator..."
            )
            predictions = estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path)
            if not FLAGS.full:
                predictions = itertools.islice(predictions,
                                               params.eval_max_predictions)
            # Tag the eval output with the checkpoint/split selection.
            eval_tag = FLAGS.eval_tag
            if FLAGS.best:
                eval_tag += ".best"
            if FLAGS.evaluate_test:
                eval_tag += ".test"
            else:
                eval_tag += ".dev"
            if FLAGS.full:
                eval_tag += ".full"
            logging.warning(
                "Flag 6: Entering predictions -> saving INPUT/TARGET/PREDS to {}"
                .format(FLAGS.model_dir))
            params.eval(predictions, eval_dir, global_step, eval_tag,
                        FLAGS.enable_logging)
            logging.warning(
                "Flag 7: Evaluations completed -> saving text metrics to {}".
                format(FLAGS.model_dir))
            break
        time.sleep(10)