Example #1
def create_experiment_components(hparams, output_dir, data_dir, model_name):
  """Constructs and returns Estimator and train/eval input functions."""
  tf.logging.info("Creating experiment, storing model files in %s", output_dir)

  num_datashards = devices.data_parallelism().n
  train_input_fn = input_fn_builder.build_input_fn(
      mode=tf.contrib.learn.ModeKeys.TRAIN,
      hparams=hparams,
      data_file_patterns=get_data_filepatterns(data_dir,
                                               tf.contrib.learn.ModeKeys.TRAIN),
      num_datashards=num_datashards)

  eval_input_fn = input_fn_builder.build_input_fn(
      mode=tf.contrib.learn.ModeKeys.EVAL,
      hparams=hparams,
      data_file_patterns=get_data_filepatterns(data_dir,
                                               tf.contrib.learn.ModeKeys.EVAL),
      num_datashards=num_datashards)
  estimator = tf.contrib.learn.Estimator(
      model_fn=model_builder.build_model_fn(model_name, hparams=hparams),
      model_dir=output_dir,
      config=tf.contrib.learn.RunConfig(
          master=FLAGS.master,
          model_dir=output_dir,
          gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction,
          session_config=session_config(),
          keep_checkpoint_max=FLAGS.keep_checkpoint_max))
  # Store the hparams in the estimator as well
  estimator.hparams = hparams
  return estimator, {
      tf.contrib.learn.ModeKeys.TRAIN: train_input_fn,
      tf.contrib.learn.ModeKeys.EVAL: eval_input_fn
  }
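These excerpts omit their module context (tf, np, FLAGS, input_fn_builder, model_builder, devices, decoding, data_reader, and related helpers). Below is a plausible preamble and call site; the tensor2tensor module paths and every concrete value are assumptions for illustration, not taken from the source.

# Hypothetical context for the excerpts on this page; the module paths
# are assumed from the identifiers used above.
import numpy as np
import tensorflow as tf
from tensor2tensor.utils import devices
from tensor2tensor.utils import input_fn_builder
from tensor2tensor.utils import model_builder

FLAGS = tf.flags.FLAGS

# Hypothetical call site for Example #1; the paths and model name are
# placeholders, and hparams is assumed to have been built elsewhere.
estimator, input_fns = create_experiment_components(
    hparams=hparams,
    output_dir="/tmp/t2t_train",
    data_dir="/tmp/t2t_data",
    model_name="transformer")
train_input_fn = input_fns[tf.contrib.learn.ModeKeys.TRAIN]
estimator.fit(input_fn=train_input_fn, steps=1000)  # tf.contrib.learn Estimator API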
Example #2
def create_experiment_components(hparams, output_dir, data_dir, model_name):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    output_dir)

    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.TRAIN),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.EVAL),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    autotune = False
    objective = None
    if hasattr(FLAGS, "autotune"):
        autotune = FLAGS.autotune
        objective = FLAGS.objective
    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        autotune=autotune,
        objective=objective)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=output_dir,
        params=hparams,
        config=tf.contrib.learn.RunConfig(
            master=FLAGS.master,
            gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction,
            session_config=session_config(),
            keep_checkpoint_max=FLAGS.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
            save_checkpoints_secs=FLAGS.save_checkpoints_secs))

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Example #3
def create_experiment_components(params,
                                 hparams,
                                 run_config,
                                 problem_instance,
                                 train_preprocess_file_path=None,
                                 dev_preprocess_file_path=None):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    add_problem_hparams(hparams, params.problem_name, params.model_dir,
                        problem_instance)

    # hparams.batch_size is used as the minibatch size instead of tokens per batch
    batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
    num_datashards = 1
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_dir=params.data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        batch_size=batch_size,
        dataset_split=train_preprocess_file_path)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_dir=params.data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        dataset_split=dev_preprocess_file_path)

    model_fn = model_builder.build_model_fn(
        params.model_name,
        problem_names=[params.problem_name],
        train_steps=params.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(params.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
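For this variant the caller passes a params object and a prebuilt RunConfig. A minimal sketch, assuming a simple namedtuple carries the attributes the function reads (problem_name, model_dir, data_dir, model_name, train_steps, decode_hparams); all concrete values are placeholders, and hparams/problem_instance are assumed to exist.

import collections
import tensorflow as tf

# Hypothetical parameter bundle mirroring the attributes read above.
Params = collections.namedtuple(
    "Params", ["problem_name", "model_dir", "data_dir", "model_name",
               "train_steps", "decode_hparams"])
params = Params(
    problem_name="translate_ende_wmt32k",  # placeholder problem
    model_dir="/tmp/t2t_train",            # placeholder
    data_dir="/tmp/t2t_data",              # placeholder
    model_name="transformer",              # placeholder
    train_steps=1000,
    decode_hparams="")                     # override string, assumed empty
run_config = tf.estimator.RunConfig(model_dir=params.model_dir)

estimator, input_fns = create_experiment_components(
    params, hparams, run_config, problem_instance)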
Example #4
def create_experiment_components(data_dir, model_name, hparams, run_config):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    add_problem_hparams(hparams, FLAGS.problems)

    # hparams.batch_size is used as the minibatch size instead of tokens per batch
    batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_dir=data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        batch_size=batch_size)  # the input_fn returns (feature_map, feature_map["targets"])

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_dir=data_dir,
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id,
        dataset_split="test"
        if FLAGS.eval_use_test_set else None)  # evaluate on test dataset
    # input_fn return feature_map

    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
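A hedged call-site sketch for the (data_dir, model_name, hparams, run_config) signature; the directory, model name, and step count below are placeholders:

# Hypothetical invocation; relies on a stock tf.estimator.RunConfig.
run_config = tf.estimator.RunConfig(model_dir="/tmp/t2t_train")  # placeholder dir
estimator, input_fns = create_experiment_components(
    data_dir="/tmp/t2t_data",  # placeholder
    model_name="transformer",  # placeholder
    hparams=hparams,           # assumed built elsewhere
    run_config=run_config)
metrics = estimator.evaluate(
    input_fn=input_fns[tf.estimator.ModeKeys.EVAL], steps=100)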
Example #5
def create_experiment_components(data_dir, model_name, hparams, run_config):
  """Constructs and returns Estimator and train/eval input functions."""
  tf.logging.info("Creating experiment, storing model files in %s",
                  run_config.model_dir)

  add_problem_hparams(hparams, FLAGS.problems)

  # hparams.batch_size is used as the minibatch size instead of tokens per batch
  batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
  num_datashards = devices.data_parallelism(hparams).n
  train_input_fn = input_fn_builder.build_input_fn(
      mode=tf.estimator.ModeKeys.TRAIN,
      hparams=hparams,
      data_dir=data_dir,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      batch_size=batch_size)

  eval_input_fn = input_fn_builder.build_input_fn(
      mode=tf.estimator.ModeKeys.EVAL,
      hparams=hparams,
      data_dir=data_dir,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      dataset_split="test" if FLAGS.eval_use_test_set else None)

  model_fn = model_builder.build_model_fn(
      model_name,
      problem_names=FLAGS.problems.split("-"),
      train_steps=FLAGS.train_steps,
      worker_id=FLAGS.worker_id,
      worker_replicas=FLAGS.worker_replicas,
      eval_run_autoregressive=FLAGS.eval_run_autoregressive,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=run_config.model_dir,
      params=hparams,
      config=run_config)

  return estimator, {
      tf.estimator.ModeKeys.TRAIN: train_input_fn,
      tf.estimator.ModeKeys.EVAL: eval_input_fn
  }
Example #6
def create_experiment_components(data_dir, model_name, hparams, run_config):
    """Constructs and returns Estimator and train/eval input functions."""
    tf.logging.info("Creating experiment, storing model files in %s",
                    run_config.model_dir)

    hparams = add_problem_hparams(hparams, FLAGS.problems)

    num_datashards = devices.data_parallelism().n
    train_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.TRAIN),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    eval_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.EVAL,
        hparams=hparams,
        data_file_patterns=get_data_filepatterns(data_dir,
                                                 tf.estimator.ModeKeys.EVAL),
        num_datashards=num_datashards,
        worker_replicas=FLAGS.worker_replicas,
        worker_id=FLAGS.worker_id)

    model_fn = model_builder.build_model_fn(
        model_name,
        problem_names=FLAGS.problems.split("-"),
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=run_config.model_dir,
                                       params=hparams,
                                       config=run_config)

    return estimator, {
        tf.estimator.ModeKeys.TRAIN: train_input_fn,
        tf.estimator.ModeKeys.EVAL: eval_input_fn
    }
Example #7
def decode_from_dataset(estimator,
                        problem_names,
                        return_beams=False,
                        beam_size=1,
                        max_predictions=-1,
                        decode_to_file=None,
                        save_images=False,
                        identity_output=False):
    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_names))
    hparams = estimator.hparams

    for problem_idx, problem_name in enumerate(problem_names):
        # Build the inference input function
        infer_problems_data = data_reader.get_data_filepatterns(
            problem_name, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.contrib.learn.ModeKeys.INFER,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=problem_idx)

        # Get the predictions as an iterable
        predictions = estimator.predict(input_fn=infer_input_fn,
                                        as_iterable=True)

        # Prepare output file writers if decode_to_file passed
        if decode_to_file:
            output_filepath = decode_to_file + ".outputs." + problem_name
            target_filepath = decode_to_file + ".targets." + problem_name

            output_file = tf.gfile.Open(output_filepath, "w")
            target_file = tf.gfile.Open(target_filepath, "w")

        problem_hparams = hparams.problems[problem_idx]
        inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
        targets_vocab = problem_hparams.vocabulary["targets"]
        for num_predictions, prediction in enumerate(predictions):
            inputs = prediction["inputs"]
            targets = prediction["targets"]
            outputs = prediction["outputs"]

            # Log predictions
            decoded_outputs = []
            if return_beams:
                output_beams = np.split(outputs, beam_size, axis=0)
                for i, beam in enumerate(output_beams):
                    tf.logging.info("BEAM %d:" % i)
                    decoded = _decode_from_dataset_log_results(
                        inputs, targets, beam, problem_name, num_predictions,
                        inputs_vocab, targets_vocab, save_images,
                        estimator.model_dir, identity_output)
                    decoded_outputs.append(decoded)
            else:
                decoded = _decode_from_dataset_log_results(
                    inputs, targets, outputs, problem_name, num_predictions,
                    inputs_vocab, targets_vocab, save_images,
                    estimator.model_dir, identity_output)
                decoded_outputs.append(decoded)

            # Write out predictions if decode_to_file passed
            if decode_to_file:
                for decoded_output, decoded_target in decoded_outputs:
                    output_file.write(str(decoded_output) + "\n")
                    target_file.write(str(decoded_target) + "\n")

            if max_predictions >= 0 and num_predictions >= max_predictions:
                break

        if decode_to_file:
            output_file.close()
            target_file.close()

        tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
Example #8
def decode_from_dataset(estimator):
    hparams = estimator.hparams
    for i, problem in enumerate(FLAGS.problems.split("-")):
        inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None)
        targets_vocab = hparams.problems[i].vocabulary["targets"]
        tf.logging.info("Performing local inference.")
        infer_problems_data = data_reader.get_data_filepatterns(
            FLAGS.problems, hparams.data_dir, tf.contrib.learn.ModeKeys.INFER)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.contrib.learn.ModeKeys.INFER,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=i)

        def log_fn(inputs,
                   targets,
                   outputs,
                   problem,
                   j,
                   inputs_vocab=inputs_vocab,
                   targets_vocab=targets_vocab):
            """Log inference results."""
            if "image" in problem and FLAGS.decode_save_images:
                save_path = os.path.join(estimator.model_dir,
                                         "%s_prediction_%d.jpg" % (problem, j))
                show_and_save_image(inputs / 255., save_path)
            elif inputs_vocab:
                decoded_inputs = inputs_vocab.decode(
                    _save_until_eos(inputs.flatten()))
                tf.logging.info("Inference results INPUT: %s" % decoded_inputs)

            if FLAGS.identity_output:
                decoded_outputs = " ".join(map(str, outputs.flatten()))
                decoded_targets = " ".join(map(str, targets.flatten()))
            else:
                decoded_outputs = targets_vocab.decode(
                    _save_until_eos(outputs.flatten()))
                decoded_targets = targets_vocab.decode(
                    _save_until_eos(targets.flatten()))

            tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
            tf.logging.info("Inference results TARGET: %s" % decoded_targets)
            if FLAGS.decode_to_file:
                output_filepath = FLAGS.decode_to_file + ".outputs." + problem
                output_file = tf.gfile.Open(output_filepath, "a")
                output_file.write(decoded_outputs + "\n")
                target_filepath = FLAGS.decode_to_file + ".targets." + problem
                target_file = tf.gfile.Open(target_filepath, "a")
                target_file.write(decoded_targets + "\n")

        result_iter = estimator.predict(input_fn=infer_input_fn,
                                        as_iterable=True)
        count = 0
        for result in result_iter:
            # Predictions from the test input; use them to log inputs and decodes.
            inputs = result["inputs"]
            targets = result["targets"]
            outputs = result["outputs"]
            if FLAGS.decode_return_beams:
                output_beams = np.split(outputs,
                                        FLAGS.decode_beam_size,
                                        axis=0)
                for k, beam in enumerate(output_beams):
                    tf.logging.info("BEAM %d:" % k)
                    log_fn(inputs, targets, beam, problem, count)
            else:
                log_fn(inputs, targets, outputs, problem, count)

            count += 1
            if FLAGS.decode_num_samples != -1 and count >= FLAGS.decode_num_samples:
                break
        tf.logging.info("Completed inference on %d samples." % count)
Example #9
def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None,
                        dataset_split=None):
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_names))
  hparams = estimator.params
  # We assume that worker_id corresponds to shard number.
  shard = decode_hp.shard_id if decode_hp.shards > 1 else None

  for problem_idx, problem_name in enumerate(problem_names):
    # Build the inference input function
    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.PREDICT,
        hparams=hparams,
        data_dir=hparams.data_dir,
        num_datashards=devices.data_parallelism().n,
        fixed_problem=problem_idx,
        batch_size=decode_hp.batch_size,
        dataset_split=dataset_split,
        shard=shard)

    # Get the predictions as an iterable
    predictions = estimator.predict(infer_input_fn)

    # Prepare output file writers if decode_to_file passed
    if decode_to_file:
      if decode_hp.shards > 1:
        decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
      else:
        decode_filename = decode_to_file
      output_filepath = _decode_filename(decode_filename, problem_name,
                                         decode_hp)
      parts = output_filepath.split(".")
      parts[-1] = "targets"
      target_filepath = ".".join(parts)

      output_file = tf.gfile.Open(output_filepath, "w")
      target_file = tf.gfile.Open(target_filepath, "w")

    problem_hparams = hparams.problems[problem_idx]
    # Inputs vocabulary is set to targets if there are no inputs in the problem,
    # e.g., for language models where the inputs are just a prefix of targets.
    has_input = "inputs" in problem_hparams.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]
    targets_vocab = problem_hparams.vocabulary["targets"]
    for num_predictions, prediction in enumerate(predictions):
      num_predictions += 1
      inputs = prediction["inputs"]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      if decode_hp.return_beams:
        output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          decoded = log_decode_results(
              inputs,
              beam,
              problem_name,
              num_predictions,
              inputs_vocab,
              targets_vocab,
              save_images=decode_hp.save_images,
              model_dir=estimator.model_dir,
              identity_output=decode_hp.identity_output,
              targets=targets)
          decoded_outputs.append(decoded)
      else:
        decoded = log_decode_results(
            inputs,
            outputs,
            problem_name,
            num_predictions,
            inputs_vocab,
            targets_vocab,
            save_images=decode_hp.save_images,
            model_dir=estimator.model_dir,
            identity_output=decode_hp.identity_output,
            targets=targets)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for decoded_output, decoded_target in decoded_outputs:
          output_file.write(str(decoded_output) + decode_hp.delimiter)
          target_file.write(str(decoded_target) + decode_hp.delimiter)

      if (decode_hp.num_samples >= 0 and
          num_predictions >= decode_hp.num_samples):
        break

    if decode_to_file:
      output_file.close()
      target_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
Example #10
def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None,
                        dataset_split=None):
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_names))
  hparams = estimator.params
  # We assume that worker_id corresponds to shard number.
  shard = decode_hp.shard_id if decode_hp.shards > 1 else None

  for problem_idx, problem_name in enumerate(problem_names):
    # Build the inference input function
    infer_input_fn = input_fn_builder.build_input_fn(
        mode=tf.estimator.ModeKeys.PREDICT,
        hparams=hparams,
        data_dir=hparams.data_dir,
        num_datashards=devices.data_parallelism(hparams).n,
        fixed_problem=problem_idx,
        batch_size=decode_hp.batch_size,
        dataset_split=dataset_split,
        shard=shard)

    # Get the predictions as an iterable
    predictions = estimator.predict(infer_input_fn)

    # Prepare output file writers if decode_to_file passed
    if decode_to_file:
      if decode_hp.shards > 1:
        decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
      else:
        decode_filename = decode_to_file
      output_filepath = _decode_filename(decode_filename, problem_name,
                                         decode_hp)
      parts = output_filepath.split(".")
      parts[-1] = "targets"
      target_filepath = ".".join(parts)

      output_file = tf.gfile.Open(output_filepath, "w")
      target_file = tf.gfile.Open(target_filepath, "w")

    problem_hparams = hparams.problems[problem_idx]
    # Inputs vocabulary is set to targets if there are no inputs in the problem,
    # e.g., for language models where the inputs are just a prefix of targets.
    has_input = "inputs" in problem_hparams.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]
    targets_vocab = problem_hparams.vocabulary["targets"]
    for num_predictions, prediction in enumerate(predictions):
      num_predictions += 1
      inputs = prediction["inputs"]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      if decode_hp.return_beams:
        output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          decoded = log_decode_results(
              inputs,
              beam,
              problem_name,
              num_predictions,
              inputs_vocab,
              targets_vocab,
              save_images=decode_hp.save_images,
              model_dir=estimator.model_dir,
              identity_output=decode_hp.identity_output,
              targets=targets)
          decoded_outputs.append(decoded)
      else:
        decoded = log_decode_results(
            inputs,
            outputs,
            problem_name,
            num_predictions,
            inputs_vocab,
            targets_vocab,
            save_images=decode_hp.save_images,
            model_dir=estimator.model_dir,
            identity_output=decode_hp.identity_output,
            targets=targets)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for decoded_output, decoded_target in decoded_outputs:
          output_file.write(str(decoded_output) + decode_hp.delimiter)
          target_file.write(str(decoded_target) + decode_hp.delimiter)

      if (decode_hp.num_samples >= 0 and
          num_predictions >= decode_hp.num_samples):
        break

    if decode_to_file:
      output_file.close()
      target_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
Example #11
def decode_from_dataset(estimator,
                        problem_names,
                        decode_hp,
                        decode_to_file=None):
    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_names))
    hparams = estimator.params

    for problem_idx, problem_name in enumerate(problem_names):
        # Build the inference input function
        infer_problems_data = data_reader.get_data_filepatterns(
            problem_name, hparams.data_dir, tf.estimator.ModeKeys.PREDICT)

        infer_input_fn = input_fn_builder.build_input_fn(
            mode=tf.estimator.ModeKeys.PREDICT,
            hparams=hparams,
            data_file_patterns=infer_problems_data,
            num_datashards=devices.data_parallelism().n,
            fixed_problem=problem_idx)

        # Get the predictions as an iterable
        predictions = estimator.predict(infer_input_fn)

        # Prepare output file writers if decode_to_file passed
        if decode_to_file:
            output_filepath = _decode_filename(decode_to_file, problem_name,
                                               decode_hp)
            parts = output_filepath.split(".")
            parts[-1] = "targets"
            target_filepath = ".".join(parts)

            output_file = tf.gfile.Open(output_filepath, "w")
            target_file = tf.gfile.Open(target_filepath, "w")

        problem_hparams = hparams.problems[problem_idx]
        inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
        targets_vocab = problem_hparams.vocabulary["targets"]
        for num_predictions, prediction in enumerate(predictions):
            inputs = prediction["inputs"]
            targets = prediction["targets"]
            outputs = prediction["outputs"]

            # Log predictions
            decoded_outputs = []
            if decode_hp.return_beams:
                output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
                for i, beam in enumerate(output_beams):
                    tf.logging.info("BEAM %d:" % i)
                    decoded = log_decode_results(
                        inputs,
                        beam,
                        problem_name,
                        num_predictions,
                        inputs_vocab,
                        targets_vocab,
                        save_images=decode_hp.save_images,
                        model_dir=estimator.model_dir,
                        identity_output=decode_hp.identity_output,
                        targets=targets)
                    decoded_outputs.append(decoded)
            else:
                decoded = log_decode_results(
                    inputs,
                    outputs,
                    problem_name,
                    num_predictions,
                    inputs_vocab,
                    targets_vocab,
                    save_images=decode_hp.save_images,
                    model_dir=estimator.model_dir,
                    identity_output=decode_hp.identity_output,
                    targets=targets)
                decoded_outputs.append(decoded)

            # Write out predictions if decode_to_file passed
            if decode_to_file:
                for decoded_output, decoded_target in decoded_outputs:
                    output_file.write(str(decoded_output) + "\n")
                    target_file.write(str(decoded_target) + "\n")

            if (decode_hp.num_samples >= 0
                    and num_predictions >= decode_hp.num_samples):
                break

        if decode_to_file:
            output_file.close()
            target_file.close()

        tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable