Пример #1
0
    def __init__(self):
        """Build a Tensor2Tensor estimator from ``FLAGS`` and decode a file.

        All configuration (usr dir, data/output dirs, hparams set, problems,
        model, decode settings, input/output files) is read from the global
        ``FLAGS``. After the estimator is built, ``decoding.decode_from_file``
        is called once.

        Attributes set:
            data_dir: ``FLAGS.data_dir`` with ``~`` expanded.
            output_dir: ``FLAGS.output_dir`` with ``~`` expanded.
            hparams: model hparams built from ``FLAGS.hparams_set``.
            estimator: estimator returned by
                ``trainer_utils.create_experiment_components``.
            decode_hp: decoding hparams with ``shards`` and ``shard_id`` added.
            output_sentence: whatever ``decoding.decode_from_file`` returned
                (previously computed and then discarded).
        """
        tf.logging.set_verbosity(tf.logging.INFO)
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        trainer_utils.log_registry()
        # NOTE(review): trainer_utils.validate_flags() was disabled here in the
        # original, presumably because its checks do not fit this decode-only
        # entry point -- confirm before re-enabling.
        self.data_dir = os.path.expanduser(FLAGS.data_dir)
        self.output_dir = os.path.expanduser(FLAGS.output_dir)

        self.hparams = trainer_utils.create_hparams(
            FLAGS.hparams_set, self.data_dir, passed_hparams=FLAGS.hparams)
        # The return value is ignored, so this presumably updates
        # self.hparams in place.
        trainer_utils.add_problem_hparams(self.hparams, FLAGS.problems)
        self.estimator, _ = trainer_utils.create_experiment_components(
            data_dir=self.data_dir,
            model_name=FLAGS.model,
            hparams=self.hparams,
            run_config=trainer_utils.create_run_config(self.output_dir))

        self.decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
        self.decode_hp.add_hparam("shards", FLAGS.decode_shards)
        self.decode_hp.add_hparam("shard_id", FLAGS.worker_id)
        # Keep the decoder's result on the instance instead of dropping it in
        # an unused local, so callers can actually read it.
        self.output_sentence = decoding.decode_from_file(
            self.estimator,
            FLAGS.decode_from_file,
            self.decode_hp,
            FLAGS.decode_to_file,
            input_sentence=FLAGS.input_sentence)
Пример #2
0
def main(_):
    """Decode with a trained Tensor2Tensor model configured by ``FLAGS``.

    Builds hparams and an estimator from the command-line flags, then runs
    exactly one decoding mode:

    * ``FLAGS.decode_interactive`` set: ``decoding.decode_interactively``.
    * else ``FLAGS.decode_from_file`` set: ``decoding.decode_from_file``,
      writing to ``FLAGS.decode_to_file``.
    * otherwise: ``decoding.decode_from_dataset`` over the problems named in
      ``FLAGS.problems`` (split on ``"-"``), writing to
      ``FLAGS.decode_to_file``.

    Args:
      _: Unused; presumably the argv list passed in by ``tf.app.run``.

    Raises:
      ValueError: if ``FLAGS.schedule`` is not ``"train_and_evaluate"``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    trainer_utils.log_registry()
    trainer_utils.validate_flags()
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if FLAGS.schedule != "train_and_evaluate":
        raise ValueError(
            "Expected --schedule=train_and_evaluate, got %r" % (FLAGS.schedule,))
    data_dir = os.path.expanduser(FLAGS.data_dir)
    output_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = trainer_utils.create_hparams(FLAGS.hparams_set,
                                           data_dir,
                                           passed_hparams=FLAGS.hparams)
    # NOTE(review): relies on add_problem_hparams returning the hparams
    # object; confirm for the installed tensor2tensor version.
    hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
    estimator, _ = trainer_utils.create_experiment_components(
        data_dir=data_dir,
        model_name=FLAGS.model,
        hparams=hparams,
        run_config=trainer_utils.create_run_config(output_dir))

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", FLAGS.decode_shards)
    if FLAGS.decode_interactive:
        decoding.decode_interactively(estimator, decode_hp)
    elif FLAGS.decode_from_file:
        decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                                  FLAGS.decode_to_file)
    else:
        decoding.decode_from_dataset(estimator, FLAGS.problems.split("-"),
                                     decode_hp, FLAGS.decode_to_file)
Пример #3
0
 def testSingleStep(self):
   """Build a transformer experiment limited to one train and one eval step and run its test()."""
   data_dir = TrainerUtilsTest.data_dir
   hparams = trainer_utils.create_hparams("transformer_test", data_dir)
   # Return value ignored: hparams is presumably updated in place.
   trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
   run_config = trainer_utils.create_run_config(
       output_dir=tf.test.get_temp_dir())
   experiment = trainer_utils.create_experiment(
       data_dir=data_dir,
       model_name="transformer",
       train_steps=1,
       eval_steps=1,
       hparams=hparams,
       run_config=run_config)
   experiment.test()
Пример #4
0
 def testSingleStep(self):
     """Run a single-step (1 train, 1 eval) transformer experiment via test()."""
     data_dir = TrainerUtilsTest.data_dir
     base_hparams = trainer_utils.create_hparams("transformer_test", data_dir)
     hparams = trainer_utils.add_problem_hparams(base_hparams, FLAGS.problems)
     temp_dir = tf.test.get_temp_dir()
     run_config = trainer_utils.create_run_config(output_dir=temp_dir)
     experiment = trainer_utils.create_experiment(
         data_dir=data_dir,
         model_name="transformer",
         train_steps=1,
         eval_steps=1,
         hparams=hparams,
         run_config=run_config)
     experiment.test()
Пример #5
0
  def __prepare_decode_model(self):
    """Create the estimator and decoding hparams needed for inference.

    Hparams come from ``self.params`` (``hparams_set``, ``data_dir``,
    ``hparams``). The estimator is built by
    ``g2p_trainer_utils.create_experiment_components`` with a run config
    rooted at ``self.params.model_dir`` and ``self.problem`` as the problem
    instance.

    Returns:
      An ``(estimator, decode_hp)`` pair. ``decode_hp`` is parsed from
      ``self.params.decode_hparams`` and has ``shards`` set to 1.
    """
    params = self.params
    hparams = trainer_utils.create_hparams(
        params.hparams_set, params.data_dir, passed_hparams=params.hparams)
    run_config = trainer_utils.create_run_config(params.model_dir)
    estimator, _ = g2p_trainer_utils.create_experiment_components(
        params=params,
        hparams=hparams,
        run_config=run_config,
        problem_instance=self.problem)

    decode_hp = decoding.decode_hparams(params.decode_hparams)
    # Decoding here is not sharded, so a single shard is used.
    decode_hp.add_hparam("shards", 1)
    return estimator, decode_hp
Пример #6
0
    def __init__(self, str_tokens, eval_tokens=None, batch_size=1000):
        """Validate token inputs and build a T2T estimator for decoding.

        Model, problem and decode settings are read from the global
        ``FLAGS``; only the inputs and the decode batch size are passed in.

        Args:
            str_tokens: the original token inputs, in the form
                ['t1', 't2', ...]. Must be a non-empty ``list``. Only the
                first item is checked to be a ``str``.
            eval_tokens: optional. If not None, it must have the same length
                as ``str_tokens``, and its first item must be a ``str``. Used
                for similarity comparisons.
            batch_size: batch size used for encoding; stored as
                ``decode_hp.batch_size``.

        Raises:
            AssertionError: if any input check above fails, or if
                ``FLAGS.schedule`` is not "train_and_evaluate". These are
                plain ``assert`` statements, so they are skipped when
                Python runs with -O.
        """
        # Input validation via assert (disabled under `python -O`).
        # Exact type() checks: subclasses of list/str are rejected.
        assert type(str_tokens) is list
        assert len(str_tokens) > 0
        assert type(str_tokens[0]) is str
        self.str_tokens = str_tokens
        if eval_tokens is not None:
            assert (len(eval_tokens) == len(str_tokens)
                    and type(eval_tokens[0]) is str)
        self.eval_tokens = eval_tokens
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.logging.info('tf logging set to INFO by: %s' %
                        self.__class__.__name__)

        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        trainer_utils.log_registry()
        trainer_utils.validate_flags()
        assert FLAGS.schedule == "train_and_evaluate"
        data_dir = os.path.expanduser(FLAGS.data_dir)
        out_dir = os.path.expanduser(FLAGS.output_dir)

        hparams = trainer_utils.create_hparams(FLAGS.hparams_set,
                                               data_dir,
                                               passed_hparams=FLAGS.hparams)

        # Return value ignored: hparams is presumably updated in place.
        trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
        # NOTE(review): presumably makes evaluation read the test split
        # rather than the dev split -- confirm against the T2T version in use.
        hparams.eval_use_test_set = True

        self.estimator, _ = trainer_utils.create_experiment_components(
            data_dir=data_dir,
            model_name=FLAGS.model,
            hparams=hparams,
            run_config=trainer_utils.create_run_config(out_dir))

        decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
        decode_hp.add_hparam("shards", FLAGS.decode_shards)
        # Override the flag-derived decode batch size with the caller's value.
        decode_hp.batch_size = batch_size
        self.decode_hp = decode_hp
        # Results holder, starts empty; presumably filled by a later
        # encoding call (not visible here).
        self.arr_results = None
        # NOTE(review): meaning of this counter is not evident from this block.
        self._encoding_len = 1