def __init__(self):
  """Builds the estimator and decoding hparams from command-line FLAGS."""
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  # trainer_utils.validate_flags()
  self.data_dir = os.path.expanduser(FLAGS.data_dir)
  self.output_dir = os.path.expanduser(FLAGS.output_dir)
  self.hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set, self.data_dir, passed_hparams=FLAGS.hparams)
  trainer_utils.add_problem_hparams(self.hparams, FLAGS.problems)
  self.estimator, _ = trainer_utils.create_experiment_components(
      data_dir=self.data_dir,
      model_name=FLAGS.model,
      hparams=self.hparams,
      run_config=trainer_utils.create_run_config(self.output_dir))
  self.decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  self.decode_hp.add_hparam("shards", FLAGS.decode_shards)
  self.decode_hp.add_hparam("shard_id", FLAGS.worker_id)
  # Decode immediately; keep the result on the instance so callers can use it.
  self.output_sentence = decoding.decode_from_file(
      self.estimator, FLAGS.decode_from_file, self.decode_hp,
      FLAGS.decode_to_file, input_sentence=FLAGS.input_sentence)
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  trainer_utils.validate_flags()
  assert FLAGS.schedule == "train_and_evaluate"
  data_dir = os.path.expanduser(FLAGS.data_dir)
  output_dir = os.path.expanduser(FLAGS.output_dir)
  hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
  hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
  estimator, _ = trainer_utils.create_experiment_components(
      data_dir=data_dir,
      model_name=FLAGS.model,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(output_dir))
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.add_hparam("shards", FLAGS.decode_shards)
  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                              FLAGS.decode_to_file)
  else:
    decoding.decode_from_dataset(estimator, FLAGS.problems.split("-"),
                                 decode_hp, FLAGS.decode_to_file)
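# A minimal sketch of how this main() would typically be invoked as a script
# entry point, assuming TF1-style flags (tf.app.run() parses the command-line
# FLAGS referenced above and then calls main); the guard below is illustrative
# and not taken from the source.
if __name__ == "__main__":
  tf.app.run()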
def testSingleStep(self):
  model_name = "transformer"
  data_dir = TrainerUtilsTest.data_dir
  hparams = trainer_utils.create_hparams("transformer_test", data_dir)
  trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
  exp = trainer_utils.create_experiment(
      data_dir=data_dir,
      model_name=model_name,
      train_steps=1,
      eval_steps=1,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(
          output_dir=tf.test.get_temp_dir()))
  exp.test()
def testSingleStep(self):
  model_name = "transformer"
  data_dir = TrainerUtilsTest.data_dir
  hparams = trainer_utils.create_hparams("transformer_test", data_dir)
  hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
  exp = trainer_utils.create_experiment(
      data_dir=data_dir,
      model_name=model_name,
      train_steps=1,
      eval_steps=1,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(
          output_dir=tf.test.get_temp_dir()))
  exp.test()
def __prepare_decode_model(self):
  """Prepare utilities for decoding."""
  hparams = trainer_utils.create_hparams(
      self.params.hparams_set,
      self.params.data_dir,
      passed_hparams=self.params.hparams)
  estimator, _ = g2p_trainer_utils.create_experiment_components(
      params=self.params,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(self.params.model_dir),
      problem_instance=self.problem)
  decode_hp = decoding.decode_hparams(self.params.decode_hparams)
  decode_hp.add_hparam("shards", 1)
  return estimator, decode_hp
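# A hedged sketch of how the (estimator, decode_hp) pair returned above might
# be consumed by a method of the same class. The `decode` name and its file
# arguments are illustrative assumptions, not part of the source; it simply
# forwards to the tensor2tensor file-decoding entry point.
def decode(self, decode_from_file, decode_to_file=None):
  estimator, decode_hp = self.__prepare_decode_model()
  decoding.decode_from_file(
      estimator, decode_from_file, decode_hp, decode_to_file)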
def __init__(self, str_tokens, eval_tokens=None, batch_size=1000):
  """Sets up a tensor2tensor estimator for encoding the given tokens.

  Args:
    str_tokens: the original token inputs, in the format ['t1', 't2', ...].
      The items within should be strings.
    eval_tokens: if not None, should be the same length as str_tokens; used
      for similarity comparisons.
    batch_size: batch size used for encoding.
  """
  assert isinstance(str_tokens, list)
  assert len(str_tokens) > 0
  assert isinstance(str_tokens[0], str)
  self.str_tokens = str_tokens
  if eval_tokens is not None:
    assert (len(eval_tokens) == len(str_tokens)
            and isinstance(eval_tokens[0], str))
  self.eval_tokens = eval_tokens

  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('tf logging set to INFO by: %s' % self.__class__.__name__)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  trainer_utils.log_registry()
  trainer_utils.validate_flags()
  assert FLAGS.schedule == "train_and_evaluate"
  data_dir = os.path.expanduser(FLAGS.data_dir)
  out_dir = os.path.expanduser(FLAGS.output_dir)
  hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
  trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
  # print(hparams)
  hparams.eval_use_test_set = True
  self.estimator, _ = trainer_utils.create_experiment_components(
      data_dir=data_dir,
      model_name=FLAGS.model,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(out_dir))
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.add_hparam("shards", FLAGS.decode_shards)
  decode_hp.batch_size = batch_size
  self.decode_hp = decode_hp
  self.arr_results = None
  self._encoding_len = 1
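# A hedged usage sketch for the constructor above, assuming the enclosing
# class is named TokenEncoder (the real class name is not shown here) and
# that the t2t command-line FLAGS have already been parsed.
tokens = ["the", "quick", "brown", "fox"]
encoder = TokenEncoder(str_tokens=tokens, eval_tokens=None, batch_size=32)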