예제 #1
0
def decode(estimator, hparams, decode_hp):
    """Decode from estimator. Interactive, from file, or from dataset.

    The mode is chosen by flags: ``FLAGS.decode_interactive`` takes
    precedence, then ``FLAGS.decode_from_file``; otherwise the dataset of
    ``FLAGS.problem`` is decoded (the "test" split when
    ``FLAGS.eval_use_test_set`` is set).

    Args:
        estimator: estimator to decode with; ``estimator.config.use_tpu`` is
            checked before interactive decoding.
        hparams: model hyperparameters forwarded to the decoding helpers.
        decode_hp: decoding hyperparameters forwarded to the decoding helpers.

    Raises:
        ValueError: if interactive decoding is requested on TPU.
    """
    if FLAGS.decode_interactive:
        if estimator.config.use_tpu:
            raise ValueError("TPU can only decode from dataset.")
        decoding.decode_interactively(estimator,
                                      hparams,
                                      decode_hp,
                                      checkpoint_path=FLAGS.checkpoint_path)
    elif FLAGS.decode_from_file:
        decoding.decode_from_file(estimator,
                                  FLAGS.decode_from_file,
                                  hparams,
                                  decode_hp,
                                  FLAGS.decode_to_file,
                                  checkpoint_path=FLAGS.checkpoint_path)
        # Copy the checkpoint's modification time onto the output file.
        # This needs an explicit output path: os.utime(None, ...) raises
        # TypeError. When decode_to_file is unset, the output file name is
        # presumably chosen inside decoding.decode_from_file and is not known here.
        if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
                and FLAGS.decode_to_file):
            ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
            os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
    else:
        decoding.decode_from_dataset(
            estimator,
            FLAGS.problem,
            hparams,
            decode_hp,
            decode_to_file=FLAGS.decode_to_file,
            dataset_split="test" if FLAGS.eval_use_test_set else None)
예제 #2
0
def decode(estimator):
  """Dispatch decoding for `estimator` according to command-line flags.

  Checks, in order: interactive decoding, decoding the file named by
  `FLAGS.decode_from_file` (when it is not None), and dataset decoding when
  `FLAGS.decode_from_dataset` is set. If none apply, nothing happens.
  """
  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator)
    return
  source_file = FLAGS.decode_from_file
  if source_file is not None:
    decoding.decode_from_file(estimator, source_file)
    return
  if FLAGS.decode_from_dataset:
    decoding.decode_from_dataset(estimator)
예제 #3
0
 def decode(self, dataset_split=None):
     """Run dataset decoding for this object's problem.

     Args:
         dataset_split: forwarded unchanged to
             ``decoding.decode_from_dataset``.
     """
     estimator = self._estimator
     problem_name = self._hparams.problem.name
     decoding.decode_from_dataset(
         estimator,
         problem_name,
         self._hparams,
         self._decode_hparams,
         dataset_split=dataset_split,
     )
예제 #4
0
파일: t2t-decoder.py 프로젝트: ling60/coies
def main(_):
    """Build a model from flags and decode with it.

    Creates hparams and an estimator from ``FLAGS``, then decodes
    interactively, from ``FLAGS.decode_from_file``, or from the datasets of
    the "-"-separated problem names in ``FLAGS.problems``.

    Args:
        _: unused positional argument (presumably argv from the app runner;
            confirm against the caller).

    Raises:
        ValueError: if ``FLAGS.schedule`` is not ``"train_and_evaluate"``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    trainer_utils.log_registry()
    trainer_utils.validate_flags()
    # Explicit check instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this flag validation.
    if FLAGS.schedule != "train_and_evaluate":
        raise ValueError(
            "Decoding requires --schedule=train_and_evaluate, got %r"
            % (FLAGS.schedule,))
    data_dir = os.path.expanduser(FLAGS.data_dir)
    output_dir = os.path.expanduser(FLAGS.output_dir)

    hparams = trainer_utils.create_hparams(FLAGS.hparams_set,
                                           data_dir,
                                           passed_hparams=FLAGS.hparams)
    hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
    estimator, _ = trainer_utils.create_experiment_components(
        data_dir=data_dir,
        model_name=FLAGS.model,
        hparams=hparams,
        run_config=trainer_utils.create_run_config(output_dir))

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    # Record the requested shard count on the decoding hparams.
    decode_hp.add_hparam("shards", FLAGS.decode_shards)
    if FLAGS.decode_interactive:
        decoding.decode_interactively(estimator, decode_hp)
    elif FLAGS.decode_from_file:
        decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                                  FLAGS.decode_to_file)
    else:
        decoding.decode_from_dataset(estimator, FLAGS.problems.split("-"),
                                     decode_hp, FLAGS.decode_to_file)
예제 #5
0
 def decode(self, dataset_split=None, decode_from_file=False):
     """Decode either a file or this object's problem dataset.

     Args:
         dataset_split: forwarded to dataset decoding; not used when
             decoding from a file.
         decode_from_file: when true, decode the file named by
             ``self._decode_hparams.decode_from_file``, passing along
             ``self._decode_hparams.decode_to_file``; otherwise decode the
             dataset of ``self._hparams.problem``.
     """
     estimator = self._estimator
     hparams = self._hparams
     decode_hp = self._decode_hparams
     if not decode_from_file:
         decoding.decode_from_dataset(estimator,
                                      hparams.problem.name,
                                      hparams,
                                      decode_hp,
                                      dataset_split=dataset_split)
         return
     decoding.decode_from_file(estimator,
                               decode_hp.decode_from_file,
                               hparams,
                               decode_hp,
                               decode_hp.decode_to_file)
예제 #6
0
def decode(estimator, hparams, decode_hp):
    """Decode with ``estimator`` in the mode selected by command-line flags.

    Interactive decoding wins over file decoding. When neither
    ``FLAGS.decode_interactive`` nor ``FLAGS.decode_from_file`` is set, the
    datasets of the "-"-separated problem names in ``FLAGS.problems`` are
    decoded, with ``dataset_split="test"`` if ``FLAGS.eval_use_test_set``.
    """
    if FLAGS.decode_interactive:
        decoding.decode_interactively(estimator, hparams, decode_hp)
        return

    if FLAGS.decode_from_file:
        decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                                  decode_hp, FLAGS.decode_to_file)
        return

    problem_names = FLAGS.problems.split("-")
    output_file = FLAGS.decode_to_file
    split_name = "test" if FLAGS.eval_use_test_set else None
    decoding.decode_from_dataset(estimator, problem_names, hparams, decode_hp,
                                 decode_to_file=output_file,
                                 dataset_split=split_name)
예제 #7
0
def decode(estimator, hparams, decode_hp):
  """Run decoding for `estimator` using the mode chosen by flags.

  Priority: interactive decoding, then the file `FLAGS.decode_from_file`;
  otherwise every "-"-separated problem in `FLAGS.problems` is decoded from
  its dataset (dataset_split="test" when `FLAGS.eval_use_test_set`, else None).
  """
  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator, hparams, decode_hp)
  elif not FLAGS.decode_from_file:
    problems = FLAGS.problems.split("-")
    target = FLAGS.decode_to_file
    if FLAGS.eval_use_test_set:
      which_split = "test"
    else:
      which_split = None
    decoding.decode_from_dataset(estimator, problems, hparams, decode_hp,
                                 decode_to_file=target,
                                 dataset_split=which_split)
  else:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file)
예제 #8
0
def decode(estimator, hparams, decode_hp):
  """Decode from estimator. Interactive, from file, or from dataset.

  `FLAGS.decode_interactive` takes precedence, then `FLAGS.decode_from_file`;
  otherwise the datasets of the "-"-separated problems in `FLAGS.problems`
  are decoded (dataset_split="test" when `FLAGS.eval_use_test_set`).

  Args:
    estimator: estimator to decode with.
    hparams: model hyperparameters forwarded to the decoding helpers.
    decode_hp: decoding hyperparameters forwarded to the decoding helpers.
  """
  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator, hparams, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file,
                              checkpoint_path=FLAGS.checkpoint_path)
    # Copy the checkpoint's modification time onto the output file. This
    # needs an explicit output path: os.utime(None, ...) raises TypeError.
    if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
        and FLAGS.decode_to_file):
      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problems.split("-"),
        hparams,
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None)
예제 #9
0
def decode(estimator, hparams, decode_hp):
  """Decode from estimator. Interactive, from file, or from dataset.

  `FLAGS.decode_interactive` takes precedence, then `FLAGS.decode_from_file`;
  otherwise the datasets of the "-"-separated problems in `FLAGS.problems`
  are decoded (dataset_split="test" when `FLAGS.eval_use_test_set`).

  Args:
    estimator: estimator to decode with.
    hparams: model hyperparameters forwarded to the decoding helpers.
    decode_hp: decoding hyperparameters forwarded to the decoding helpers.
  """
  if FLAGS.decode_interactive:
    decoding.decode_interactively(estimator, hparams, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file,
                              checkpoint_path=FLAGS.checkpoint_path)
    # Copy the checkpoint's modification time onto the output file. This
    # needs an explicit output path: os.utime(None, ...) raises TypeError.
    if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
        and FLAGS.decode_to_file):
      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problems.split("-"),
        hparams,
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None)
예제 #10
0
 def decode(self,
            dataset_split=None,
            decode_from_file=False,
            checkpoint_path=None):
   """Decodes from dataset or file.

   Args:
     dataset_split: forwarded to dataset decoding; not used when decoding
       from a file.
     decode_from_file: when true, decode the file named by
       `self._decode_hparams.decode_from_file`, passing along
       `self._decode_hparams.decode_to_file`; otherwise decode the dataset of
       `self._hparams.problem`.
     checkpoint_path: forwarded as `checkpoint_path=` to whichever decoding
       helper runs (presumably the checkpoint to restore; None is passed
       through unchanged).
   """
   if decode_from_file:
     # Forward checkpoint_path on this branch too, so an explicitly requested
     # checkpoint is honoured for file decoding as well as dataset decoding.
     decoding.decode_from_file(self._estimator,
                               self._decode_hparams.decode_from_file,
                               self._hparams,
                               self._decode_hparams,
                               self._decode_hparams.decode_to_file,
                               checkpoint_path=checkpoint_path)
   else:
     decoding.decode_from_dataset(
         self._estimator,
         self._hparams.problem.name,
         self._hparams,
         self._decode_hparams,
         dataset_split=dataset_split,
         checkpoint_path=checkpoint_path)
예제 #11
0
def decode(estimator, hparams, decode_hp):
    """Decode from estimator. Interactive, from file, or from dataset.

    ``FLAGS.decode_interactive`` takes precedence, then
    ``FLAGS.decode_from_file``; otherwise the dataset of ``FLAGS.problem`` is
    decoded. With ``FLAGS.fathom_output_predictions`` set, the predictions
    returned by dataset decoding are passed to the problem's
    ``output_predictions``.

    Args:
        estimator: estimator to decode with; ``estimator.config.use_tpu`` is
            checked before interactive and file decoding.
        hparams: model hyperparameters; its ``problem`` attribute, when
            present, is reused for writing predictions.
        decode_hp: decoding hyperparameters forwarded to the decoding helpers.

    Raises:
        ValueError: if interactive or file decoding is requested on TPU, or
            if prediction output is requested while ``FLAGS.problems``
            contains more than one "-"-separated problem.
    """
    if FLAGS.decode_interactive:
        if estimator.config.use_tpu:
            raise ValueError("TPU can only decode from dataset.")
        decoding.decode_interactively(estimator,
                                      hparams,
                                      decode_hp,
                                      checkpoint_path=FLAGS.checkpoint_path)
    elif FLAGS.decode_from_file:
        if estimator.config.use_tpu:
            raise ValueError("TPU can only decode from dataset.")
        decoding.decode_from_file(estimator,
                                  FLAGS.decode_from_file,
                                  hparams,
                                  decode_hp,
                                  FLAGS.decode_to_file,
                                  checkpoint_path=FLAGS.checkpoint_path)
        # Copy the checkpoint's modification time onto the output file. This
        # needs an explicit output path: os.utime(None, ...) raises TypeError.
        if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
                and FLAGS.decode_to_file):
            ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
            os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
    else:
        # Fathom
        predictions = decoding.decode_from_dataset(
            estimator,
            FLAGS.problem,
            hparams,
            decode_hp,
            decode_to_file=FLAGS.decode_to_file,
            dataset_split=dataset_to_t2t_mode(FLAGS.dataset_split),
            return_generator=FLAGS.fathom_output_predictions,
            # save logs/summaries to a directory with the same name as decode_output_file
            # in situations where we are calling decode without write permissions
            # to the model directory
            output_dir=os.path.splitext(FLAGS.decode_output_file)[0])

        # Fathom
        if FLAGS.fathom_output_predictions:
            print('Assuming only one problem...')
            # Explicit check instead of `assert`, which is stripped when
            # Python runs with -O. "-" separates problem names in
            # FLAGS.problems.
            if '-' in FLAGS.problems:
                raise ValueError(
                    "fathom_output_predictions supports a single problem, "
                    "got FLAGS.problems=%r" % (FLAGS.problems,))
            # if we already have built problem instance in hparams, no need to create
            # it second time (as it's downloading files from gcs)
            if hasattr(hparams, 'problem'):
                problem = hparams.problem
            else:
                problem = registry.problem(FLAGS.problems)
            problem.output_predictions(predictions=predictions,
                                       num_examples=FLAGS.num_examples)
예제 #12
0
def decode(estimator, hparams, decode_hp):
  """Decode from estimator. Interactive, from file, or from dataset.

  `FLAGS.decode_interactive` takes precedence, then `FLAGS.decode_from_file`;
  otherwise the dataset of `FLAGS.problem` is decoded (dataset_split="test"
  when `FLAGS.eval_use_test_set`).

  Args:
    estimator: estimator to decode with; `estimator.config.use_tpu` is
      checked before interactive and file decoding.
    hparams: model hyperparameters forwarded to the decoding helpers.
    decode_hp: decoding hyperparameters forwarded to the decoding helpers.

  Raises:
    ValueError: if interactive or file decoding is requested on TPU.
  """
  if FLAGS.decode_interactive:
    if estimator.config.use_tpu:
      raise ValueError("TPU can only decode from dataset.")
    decoding.decode_interactively(estimator, hparams, decode_hp,
                                  checkpoint_path=FLAGS.checkpoint_path)
  elif FLAGS.decode_from_file:
    if estimator.config.use_tpu:
      raise ValueError("TPU can only decode from dataset.")
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file,
                              checkpoint_path=FLAGS.checkpoint_path)
    # Copy the checkpoint's modification time onto the output file. This
    # needs an explicit output path: os.utime(None, ...) raises TypeError.
    if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
        and FLAGS.decode_to_file):
      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problem,
        hparams,
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None)
예제 #13
0
 def decode(self):
     """Run dataset decoding for ``self._hparams.problem``.

     Uses this object's estimator, model hparams and decode hparams.
     """
     estimator = self._estimator
     problem_name = self._hparams.problem.name
     decoding.decode_from_dataset(
         estimator,
         problem_name,
         self._hparams,
         self._decode_hparams,
     )
예제 #14
0
 def decode(self):
   """Decode the dataset of this object's configured problem."""
   estimator, hparams = self._estimator, self._hparams
   decoding.decode_from_dataset(
       estimator,
       hparams.problem.name,
       hparams,
       self._decode_hparams,
   )