Example #1
0
    def train_eval_and_decode(self):
        """Does eval and decode after training every eval_freq_in_steps.

        Training runs in chunks of ``hparams.eval_freq_in_steps`` steps until
        ``train_spec.max_steps`` steps have been requested in total; the last
        chunk is shortened so the total never exceeds ``max_steps``. After
        each chunk the estimator is evaluated (named "eval") and then
        ``self.decode`` runs on the EVAL split.

        For problems whose name contains "_packed", training and evaluation
        use the packed problem while decoding uses the problem with "_packed"
        removed from its name; ``self._hparams.problem`` and
        ``self._hparams.problem_hparams`` are swapped accordingly in every
        iteration (and are left on the unpacked problem after the loop).
        """
        eval_steps = self._hparams.eval_freq_in_steps
        max_steps = self._train_spec.max_steps
        packed_dataset = "_packed" in self._hparams.problem.name

        def _switch_problem(name):
            # Resolve the problem by name and refresh its derived hparams.
            # get_hparams is called before the assignment, as before.
            problem = registry.problem(name)
            p_hparams = problem.get_hparams(self._hparams)
            self._hparams.problem = problem
            self._hparams.problem_hparams = p_hparams

        for i in range(0, max_steps, eval_steps):
            if packed_dataset and i > 0:
                # The previous iteration switched to the unpacked problem for
                # decoding; switch back to the packed one for training.
                _switch_problem(self._hparams.problem.name + "_packed")
            # Never request more than max_steps in total: the final chunk is
            # shorter when max_steps is not a multiple of eval_steps.
            chunk_steps = min(eval_steps, max_steps - i)
            self._estimator.train(self._train_spec.input_fn,
                                  steps=chunk_steps,
                                  hooks=self._train_spec.hooks)
            self._set_eval_dir_name("eval")
            self._estimator.evaluate(self._eval_spec.input_fn,
                                     steps=self._eval_spec.steps,
                                     hooks=self._eval_spec.hooks,
                                     name="eval")
            if packed_dataset:
                # Decode with the unpacked variant of the problem.
                _switch_problem(
                    self._hparams.problem.name.replace("_packed", ""))
            self.decode(dataset_split=tf.estimator.ModeKeys.EVAL)
Example #2
0
def generate_data():
  """Generate data for the problem named by --problem.

  Expands --data_dir and --tmp_dir, creates both directories if they do not
  exist, and calls the registered problem's generate_data with them.
  """
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)

  problem_name = FLAGS.problem
  # Lazy %-args: the message is only formatted if the record is emitted.
  tf.logging.info("Generating data for %s", problem_name)
  registry.problem(problem_name).generate_data(data_dir, tmp_dir)
Example #3
0
def generate_data_for_registered_problem(problem_name):
    """Generate data for a registered problem.

    When --task_id is negative and the problem's ``multiprocess_generate`` is
    true, the tasks in [task_id_start, task_id_end) are distributed over a
    multiprocessing pool. The range comes from --task_id_start/--task_id_end
    when --task_id_start is not -1, and otherwise defaults to
    [0, problem.num_generate_tasks). In every other case ``generate_data`` runs
    in this process, with task id None when --task_id is negative.

    Args:
      problem_name: name of the problem to look up in the registry.

    Raises:
      ValueError: if --num_shards is set, or if --task_id_start is set while
        --task_id_end is left at -1.
    """
    tf.logging.info("Generating data for %s.", problem_name)
    if FLAGS.num_shards:
        raise ValueError(
            "--num_shards should not be set for registered Problem.")
    problem = registry.problem(problem_name)
    task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

    if task_id is not None or not problem.multiprocess_generate:
        problem.generate_data(data_dir, tmp_dir, task_id)
        return

    if FLAGS.task_id_start != -1:
        # Explicit check rather than assert: asserts vanish under `python -O`.
        if FLAGS.task_id_end == -1:
            raise ValueError(
                "--task_id_end must be set when --task_id_start is set.")
        task_id_start = FLAGS.task_id_start
        task_id_end = FLAGS.task_id_end
    else:
        task_id_start = 0
        task_id_end = problem.num_generate_tasks

    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    try:
        problem.prepare_to_generate(data_dir, tmp_dir)
        args = [(problem_name, data_dir, tmp_dir, tid)
                for tid in range(task_id_start, task_id_end)]
        pool.map(generate_data_in_process, args)
    finally:
        # Always release the worker processes, even if preparation or a
        # task fails.
        pool.close()
        pool.join()
Example #4
0
def score_file(filename):
  """Score each line in a file and return the scores.

  Each line is either ``targets`` or ``inputs<TAB>targets``. Both parts are
  stripped, encoded with the problem's feature encoders, and given a trailing
  EOS id. The model (built in EVAL mode and restored from a checkpoint) is
  then run on them, and ``losses["training"]`` is collected.

  Args:
    filename: path of the file to score, opened with tf.gfile.

  Returns:
    A list with one ``sess.run(losses["training"], ...)`` value per line, in
    file order.

  Raises:
    ValueError: if a line has more than one tab; if the problem has inputs but
      a line has no tab separator; or if FLAGS.checkpoint_path is None and no
      checkpoint is found in FLAGS.output_dir.
  """
  # Prepare model.
  hparams = create_hparams()
  encoders = registry.problem(FLAGS.problem).feature_encoders(FLAGS.data_dir)
  has_inputs = "inputs" in encoders

  # Prepare features for feeding into the model. Each placeholder holds a
  # single sequence of ids and is reshaped to [1, length, 1, 1].
  features = {}
  if has_inputs:
    inputs_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
    features["inputs"] = tf.reshape(inputs_ph, [1, -1, 1, 1])  # Make it 4D.
  targets_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
  features["targets"] = tf.reshape(targets_ph, [1, -1, 1, 1])  # Make it 4D.

  # Prepare the model and the graph when model runs on features.
  model = registry.model(FLAGS.model)(hparams, tf.estimator.ModeKeys.EVAL)
  _, losses = model(features)
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # Load weights from checkpoint.
    if FLAGS.checkpoint_path is None:
      ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
      # get_checkpoint_state returns None when no checkpoint state exists.
      if ckpts is None or not ckpts.model_checkpoint_path:
        raise ValueError("No checkpoint found in %s; set --checkpoint_path."
                         % FLAGS.output_dir)
      ckpt = ckpts.model_checkpoint_path
    else:
      ckpt = FLAGS.checkpoint_path
    saver.restore(sess, ckpt)

    # Run on each line.
    with tf.gfile.Open(filename) as f:
      lines = f.readlines()
    results = []
    for line in lines:
      tab_split = line.split("\t")
      if len(tab_split) > 2:
        raise ValueError("Each line must have at most one tab separator.")
      # The targets are always the last field (the only one without a tab).
      targets = tab_split[-1].strip()
      # Run encoders and append EOS symbol.
      targets_numpy = encoders["targets"].encode(
          targets) + [text_encoder.EOS_ID]
      feed = {targets_ph: targets_numpy}
      if has_inputs:
        # Without this check a tab-less line would reuse the previous line's
        # inputs (or raise NameError on the first line).
        if len(tab_split) != 2:
          raise ValueError(
              "Each line must be 'inputs<TAB>targets' when the problem "
              "has inputs.")
        inputs = tab_split[0].strip()
        feed[inputs_ph] = (
            encoders["inputs"].encode(inputs) + [text_encoder.EOS_ID])
      # Get the score.
      results.append(sess.run(losses["training"], feed))
  return results
Example #5
0
def add_problem_hparams(hparams, problem_name_or_instance):
    """Attach a problem and its derived hparams to ``hparams``.

    Args:
      hparams: hparams object whose ``problem`` and ``problem_hparams``
        attributes are overwritten.
      problem_name_or_instance: a ``problem_lib.Problem`` instance (used as
        is) or a name that is resolved through ``registry.problem``.
    """
    is_instance = isinstance(problem_name_or_instance, problem_lib.Problem)
    resolved = (problem_name_or_instance if is_instance
                else registry.problem(problem_name_or_instance))
    # Derive the problem hparams before mutating ``hparams``.
    derived = resolved.get_hparams(hparams)
    hparams.problem = resolved
    hparams.problem_hparams = derived
def main(_):
    """Create (or load) the vocabulary for --problem and save it in --data_dir.

    Raises:
      TypeError: if the registered problem is not a Text2TextProblem.
    """
    problem = registry.problem(FLAGS.problem)

    # We make the assumption that the problem is a subclass of
    # Text2TextProblem. Check explicitly: an assert is stripped under -O.
    if not isinstance(problem, text_problems.Text2TextProblem):
        raise TypeError("Problem %s must be a Text2TextProblem, got %s."
                        % (FLAGS.problem, type(problem).__name__))

    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

    tf.gfile.MakeDirs(data_dir)
    tf.gfile.MakeDirs(tmp_dir)

    tf.logging.info("Saving vocabulary to data_dir: %s", data_dir)

    problem.get_or_create_vocab(data_dir, tmp_dir)

    tf.logging.info("Saved vocabulary file: %s",
                    os.path.join(data_dir, problem.vocab_filename))
Example #7
0
def generate_data_in_process(arg):
    """Run a single data-generation task.

    Args:
      arg: a ``(problem_name, data_dir, tmp_dir, task_id)`` tuple, packed into
        one argument so the function can be used with ``Pool.map``.
    """
    name, out_dir, scratch_dir, shard_id = arg
    registry.problem(name).generate_data(out_dir, scratch_dir, shard_id)
def problem(name):
    """Look up ``name`` with ``registry.problem`` and return the result."""
    found = registry.problem(name)
    return found