def decode_from_text_file(estimator,
                          problem_name,
                          filename,
                          hparams,
                          decode_hp,
                          decode_to_file=None,
                          checkpoint_path=None):
    """Compute predictions on entries in filename and write them out."""
    if not decode_hp.batch_size:
        decode_hp.batch_size = 32
        tf.logging.info("decode_hp.batch_size not specified; default=%d" %
                        decode_hp.batch_size)

    # Inputs vocabulary is set to targets if there are no inputs in the problem,
    # e.g., for language models where the inputs are just a prefix of targets.
    p_hp = hparams.problem_hparams
    has_input = "inputs" in p_hp.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = p_hp.vocabulary[inputs_vocab_key]
    targets_vocab = p_hp.vocabulary["targets"]
    filename = decoding._add_shard_to_filename(filename, decode_hp)
    tf.logging.info("Performing decoding from file (%s)." % filename)
    if has_input:
        sorted_inputs, sorted_keys = decoding._get_sorted_inputs(
            filename, decode_hp.delimiter)
    else:
        sorted_inputs = decoding._get_language_modeling_inputs(
            filename, decode_hp.delimiter, repeat=decode_hp.num_decodes)
        sorted_keys = range(len(sorted_inputs))

    # If decode_to_file was provided use it as the output filename without change
    # (except for adding shard_id if using more shards for decoding).
    # Otherwise, use the input filename plus model, hp, problem, beam, alpha.
    decode_filename = decode_to_file if decode_to_file else filename
    if not decode_to_file:
        decode_filename = decoding._decode_filename(decode_filename,
                                                    problem_name, decode_hp)
    else:
        decode_filename = decoding._add_shard_to_filename(
            decode_filename, decode_hp)
    tf.logging.info("Writing decodes into %s" % decode_filename)

    # Check for decoding checkpoint.
    decodes = []
    shuffle_file_path = decode_filename + '.shuffle.txt'
    if tf.gfile.Exists(shuffle_file_path):
        with tf.gfile.Open(shuffle_file_path, 'r') as f:
            decodes = [line.strip() for line in f.readlines()]
        tf.logging.info('Read {} sentences from checkpoint.'.format(
            len(decodes)))

    all_sorted_inputs = sorted_inputs
    # We only need to decode these inputs:
    sorted_inputs = sorted_inputs[len(decodes):]

    # We don't need to waste computation on empty lines:
    num_empty_lines = 0
    while sorted_inputs and sorted_inputs[-1] == '':
        num_empty_lines += 1
        sorted_inputs.pop(-1)

    num_sentences = len(sorted_inputs)
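    # Ceiling division: enough batches to cover every sentence.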
    num_decode_batches = (num_sentences - 1) // decode_hp.batch_size + 1

    if estimator.config.use_tpu:
        length = getattr(hparams, "length", 0) or hparams.max_length
        batch_ids = []
        for line in sorted_inputs:
            if has_input:
                ids = inputs_vocab.encode(line.strip()) + [1]
            else:
                ids = targets_vocab.encode(line)
            if len(ids) < length:
                ids.extend([0] * (length - len(ids)))
            else:
                ids = ids[:length]
            batch_ids.append(ids)
        np_ids = np.array(batch_ids, dtype=np.int32)

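        # TPU requires fixed-shape inputs: each example was padded/truncated
        # to `length` above and is reshaped to the [length, 1, 1] layout the
        # model expects.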
        def input_fn(params):
            batch_size = params["batch_size"]
            dataset = tf.data.Dataset.from_tensor_slices({"inputs": np_ids})
            dataset = dataset.map(
                lambda ex:
                {"inputs": tf.reshape(ex["inputs"], (length, 1, 1))})
            dataset = dataset.batch(batch_size)
            return dataset
    else:

        def input_fn():
            input_gen = decoding._decode_batch_input_fn(
                num_decode_batches,
                sorted_inputs,
                inputs_vocab,
                decode_hp.batch_size,
                decode_hp.max_input_size,
                task_id=-1,
                has_input=has_input)
            gen_fn = decoding.make_input_fn_from_generator(input_gen)
            example = gen_fn()
            return decoding._decode_input_tensor_to_features_dict(
                example, hparams)

    result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

    start_time = time.time()
    total_time_per_step = 0
    total_cnt = 0

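    # Wrap the prediction iterator so each item is paired with the seconds
    # spent producing it; used for the timing stats reported below.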
    def timer(gen):
        while True:
            try:
                start_time = time.time()
                item = next(gen)
                elapsed_time = time.time() - start_time
                yield elapsed_time, item
            except StopIteration:
                break

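    # Append to the existing checkpoint file when resuming, otherwise start
    # a fresh one.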
    writing_mode = 'a' if tf.gfile.Exists(shuffle_file_path) else 'w'
    shuffle_file = tf.gfile.Open(shuffle_file_path, writing_mode)
    count = 0
    for elapsed_time, result in timer(result_iter):
        if decode_hp.return_beams:
            beam_decodes = []
            beam_scores = []
            output_beams = np.split(result["outputs"],
                                    decode_hp.beam_size,
                                    axis=0)
            scores = None
            if "scores" in result:
                if np.isscalar(result["scores"]):
                    result["scores"] = result["scores"].reshape(1)
                scores = np.split(result["scores"],
                                  decode_hp.beam_size,
                                  axis=0)
            for k, beam in enumerate(output_beams):
                tf.logging.info("BEAM %d:" % k)
                score = scores and scores[k]
                _, decoded_outputs, _ = decoding.log_decode_results(
                    result["inputs"],
                    beam,
                    problem_name,
                    None,
                    inputs_vocab,
                    targets_vocab,
                    log_results=decode_hp.log_results,
                    skip_eos_postprocess=decode_hp.skip_eos_postprocess)
                beam_decodes.append(decoded_outputs)
                if decode_hp.write_beam_scores:
                    beam_scores.append(score)
            if decode_hp.write_beam_scores:
                decodes.append("\t".join([
                    "\t".join([d, "%.2f" % s])
                    for d, s in zip(beam_decodes, beam_scores)
                ]))
            else:
                decodes.append("\t".join(beam_decodes))
        else:
            _, decoded_outputs, _ = decoding.log_decode_results(
                result["inputs"],
                result["outputs"],
                problem_name,
                None,
                inputs_vocab,
                targets_vocab,
                log_results=decode_hp.log_results,
                skip_eos_postprocess=decode_hp.skip_eos_postprocess)
            decodes.append(decoded_outputs)

        # Write decoded text to checkpoint
        new_decode = decodes[-1]
        shuffle_file.write(new_decode + '\n')

        # Flush checkpoint to storage.
        count += 1
        if count % decode_hp.batch_size == 0:
            tf.logging.info('Done {}/{}. Flushing.'.format(
                count, len(sorted_inputs)))
            shuffle_file.flush()
            shuffle_file.close()
            shuffle_file = tf.gfile.Open(shuffle_file_path, 'a')

        total_time_per_step += elapsed_time
        total_cnt += result["outputs"].shape[-1]

    for _ in range(num_empty_lines):
        decodes.append('')
        shuffle_file.write('\n')

    # Write the final output to file.
    outfile = tf.gfile.Open(decode_filename, "w")
    for index in range(len(all_sorted_inputs)):
        outfile.write("%s%s" %
                      (decodes[sorted_keys[index]], decode_hp.delimiter))
    outfile.flush()
    outfile.close()

    # Close and remove checkpoint.
    shuffle_file.flush()
    shuffle_file.close()
    tf.gfile.Remove(shuffle_file_path)

    # Print some decoding stats.
    duration = time.time() - start_time
    if total_cnt:
        tf.logging.info("Elapsed Time: %5.5f" % duration)
        tf.logging.info(
            "Averaged Single Token Generation Time: %5.7f "
            "(time %5.7f count %d)" %
            (total_time_per_step / total_cnt, total_time_per_step, total_cnt))
    if decode_hp.batch_size == 1:
        tf.logging.info("Inference time %.4f seconds "
                        "(Latency = %.4f ms/setences)" %
                        (duration, 1000.0 * duration / num_sentences))
    else:
        tf.logging.info("Inference time %.4f seconds "
                        "(Throughput = %.4f sentences/second)" %
                        (duration, num_sentences / duration))
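
# Standalone illustration of the `timer` helper above: it wraps any iterator
# so each item comes back paired with the seconds spent producing it. A
# minimal sketch using only the standard library; the names below are
# illustrative, not part of the function above.
import time

def timed(gen):
    while True:
        try:
            start = time.time()
            item = next(gen)
            yield time.time() - start, item
        except StopIteration:
            break

# Example: each item takes roughly 10 ms to produce.
slow_numbers = (time.sleep(0.01) or n for n in range(3))
for elapsed, n in timed(slow_numbers):
    print("%.4fs -> item %d" % (elapsed, n))
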
def decode_from_dataset(estimator,
                        problem_name,
                        hparams,
                        decode_hp,
                        decode_to_file=None,
                        dataset_split=None):
    """Perform decoding from dataset."""
    tf.logging.info("Performing local inference from dataset for %s.",
                    str(problem_name))

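    # We assume that worker_id corresponds to shard number.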
    shard = decode_hp.shard_id if decode_hp.shards > 1 else None

    output_dir = os.path.join(estimator.model_dir, "decode")
    tf.gfile.MakeDirs(output_dir)

    if decode_hp.batch_size:
        hparams.batch_size = decode_hp.batch_size
        hparams.use_fixed_batch_size = True

    dataset_kwargs = {
        "shard": shard,
        "dataset_split": dataset_split,
        "max_records": decode_hp.num_samples
    }

    problem = hparams.problem
    infer_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.PREDICT, hparams, dataset_kwargs=dataset_kwargs)

    predictions = estimator.predict(infer_input_fn)

    decode_to_file = decode_to_file or decode_hp.decode_to_file
    if decode_to_file:
        if decode_hp.shards > 1:
            decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
        else:
            decode_filename = decode_to_file

        output_filepath = decoding._decode_filename(decode_filename,
                                                    problem_name, decode_hp)
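        # Derive sibling output files for targets, inputs, and encoder
        # states by swapping the filename extension.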
        parts = output_filepath.split(".")
        parts[-1] = "targets"
        target_filepath = ".".join(parts)
        parts[-1] = "inputs"
        input_filepath = ".".join(parts)
        parts[-1] = "enc_state"
        encoder_state_file_path = ".".join(parts)

        input_file = tf.gfile.Open(input_filepath, "w")

    problem_hparams = hparams.problem_hparams
    has_input = "inputs" in problem_hparams.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]

    ##### Modified #####
    # Encoder outputs list created

    encoder_outputs = []
    decoded_inputs = []

    for num_predictions, prediction in enumerate(predictions, start=1):
        inputs = prediction["inputs"]
        encoder_output = prediction["encoder_outputs"]
        decoded_input = inputs_vocab.decode(
            decoding._save_until_eos(inputs, False))

        encoder_outputs.append(encoder_output)
        decoded_inputs.append(decoded_input)

        ##### Modified #####
        # Write only this prediction's decoded input; re-writing the whole
        # accumulated list on every iteration would duplicate earlier lines.

        if decode_to_file:
            input_file.write("{}:\t{}".format(
                num_predictions - 1,
                str(decoded_input) + decode_hp.delimiter))

        if 0 <= decode_hp.num_samples <= num_predictions:
            break

    if decode_to_file:
        # Save all collected encoder outputs once decoding is complete.
        np.save(encoder_state_file_path, np.array(encoder_outputs))
        input_file.close()

    decoding.run_postdecode_hooks(
        decoding.DecodeHookArgs(estimator=estimator,
                                problem=problem,
                                output_dir=output_dir,
                                hparams=hparams,
                                decode_hparams=decode_hp))

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
def decode_from_file(estimator,
                     filename,
                     hparams,
                     decode_hp,
                     decode_to_file=None,
                     checkpoint_path=None):
    """Compute predictions on entries in filename and write them out."""
    if not decode_hp.batch_size:
        decode_hp.batch_size = 32
        tf.logging.info("decode_hp.batch_size not specified; default=%d" %
                        decode_hp.batch_size)

    p_hp = hparams.problem_hparams
    has_input = "inputs" in p_hp.vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = p_hp.vocabulary[inputs_vocab_key]
    problem_name = FLAGS.problem
    tf.logging.info("Performing decoding from a file.")
    sorted_inputs, sorted_keys = decoding._get_sorted_inputs(
        filename, decode_hp.shards, decode_hp.delimiter)
    num_decode_batches = (len(sorted_inputs) - 1) // decode_hp.batch_size + 1

    def input_fn():
        input_gen = decoding._decode_batch_input_fn(num_decode_batches,
                                                    sorted_inputs,
                                                    inputs_vocab,
                                                    decode_hp.batch_size,
                                                    decode_hp.max_input_size)
        gen_fn = decoding.make_input_fn_from_generator(input_gen)
        example = gen_fn()
        return decoding._decode_input_tensor_to_features_dict(example, hparams)

    ##### Modified #####
    # Encoder outputs list created

    decoded_inputs = []
    encoder_outputs = []
    result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

    start_time = time.time()
    total_time_per_step = 0
    total_cnt = 0

    def timer(gen):
        while True:
            try:
                start_time = time.time()
                item = next(gen)
                elapsed_time = time.time() - start_time
                yield elapsed_time, item
            except StopIteration:
                break

    for elapsed_time, result in timer(result_iter):
        decoded_input = inputs_vocab.decode(
            decoding._save_until_eos(result["inputs"], False))
        decoded_inputs.append(decoded_input)
        encoder_outputs.append(np.array(result["encoder_outputs"]))

        total_time_per_step += elapsed_time
        total_cnt += result["outputs"].shape[-1]
    tf.logging.info("Elapsed Time: %5.5f" % (time.time() - start_time))
    tf.logging.info("Averaged Single Token Generation Time: %5.7f" %
                    (total_time_per_step / total_cnt))

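    # Reversing the decoded inputs and outputs because they were reversed in
    # _decode_batch_input_fn.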
    decoded_inputs.reverse()
    encoder_outputs.reverse()

    decode_filename = decode_to_file if decode_to_file else filename

    if decode_hp.shards > 1:
        decode_filename += "%.2d" % decode_hp.shard_id
    if not decode_to_file:
        decode_filename = decoding._decode_filename(decode_filename,
                                                    problem_name, decode_hp)

    base = os.path.basename(decode_filename).split('.')
    dirname = os.path.dirname(decode_filename)
    encode_filename = os.path.join(dirname, '{}{}'.format(base[0], '.npy'))

    tf.logging.info("Writing inputs into %s" % decode_filename)
    tf.logging.info("Writing encoder outputs into %s" % encode_filename)
    print("Writing encoder outputs into %s" % encode_filename)
    outfile = tf.gfile.Open(decode_filename, "w")

    ##### Modified #####
    # Write the filtered decoded inputs and save the encoder outputs. The
    # file is written unconditionally, since `decode_filename` falls back to
    # a name derived from the input file when `decode_to_file` is not given.

    for d_input in decoded_inputs:
        outfile.write(' '.join(
            word for word in str(d_input).strip().split()
            if word.strip() not in ('', '<unk>')) + decode_hp.delimiter)
    outfile.close()

    np.save(encode_filename, np.array(encoder_outputs))
def decode_from_babi_dataset(estimator,
                             problem_name,
                             hparams,
                             decode_hp,
                             decode_to_file=None):
  """Perform decoding from dataset."""
  tf.logging.info("Performing local inference from dataset for %s.",
                  str(problem_name))
  # We assume that worker_id corresponds to shard number.
  shard = decode_hp.shard_id if decode_hp.shards > 1 else None

  # Setup decode output directory for any artifacts that may be written out
  output_dir = os.path.join(estimator.model_dir, "decode")
  tf.gfile.MakeDirs(output_dir)

  # If decode_hp.batch_size is specified, use a fixed batch size
  if decode_hp.batch_size:
    hparams.batch_size = decode_hp.batch_size
    hparams.use_fixed_batch_size = True

  dataset_kwargs = {
      "shard": shard,
      "dataset_split": None,
      "max_records": decode_hp.num_samples
  }

  # Build the inference input function
  problem = hparams.problem
  infer_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.PREDICT, hparams, dataset_kwargs=dataset_kwargs)

  predictions, output_dirs = [], []
  for decode_id in range(decode_hp.num_decodes):
    tf.logging.info("Decoding {}".format(decode_id))

    # Create decode directory if not in-memory decoding.
    if not decode_hp.decode_in_memory:
      output_dir = os.path.join(estimator.model_dir, "decode_%05d" % decode_id)
      tf.gfile.MakeDirs(output_dir)
      output_dirs.append(output_dir)

    # Get the predictions as an iterable
    predictions = estimator.predict(infer_input_fn)

    # Prepare output file writers if decode_to_file passed
    decode_to_file = decode_to_file or decode_hp.decode_to_file
    if decode_to_file:
      if decode_hp.shards > 1:
        decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
      else:
        decode_filename = decode_to_file
      output_filepath = decoding._decode_filename(decode_filename, problem_name,
                                                  decode_hp)
      parts = output_filepath.split(".")

      parts[-1] = "decoded"
      decoded_filepath = ".".join(parts)

      decoded_dir = os.path.dirname(decoded_filepath)
      if not tf.gfile.Exists(decoded_dir):
        tf.gfile.MakeDirs(decoded_dir)

      decoded_file = tf.gfile.Open(decoded_filepath, "w")

    problem_hparams = hparams.problem_hparams

    story_vocab = problem_hparams.vocabulary[babi_qa.FeatureNames.STORY]
    question_vocab = problem_hparams.vocabulary[babi_qa.FeatureNames.QUESTION]
    targets_vocab = problem_hparams.vocabulary["targets"]

    for num_predictions, prediction in enumerate(predictions, start=1):
      story = prediction[babi_qa.FeatureNames.STORY]
      question = prediction[babi_qa.FeatureNames.QUESTION]
      targets = prediction["targets"]
      outputs = prediction["outputs"]

      # Log predictions
      decoded_outputs = []
      decoded_scores = []
      if decode_hp.return_beams:
        output_beams = np.split(outputs, decode_hp.beam_size, axis=0)
        scores = None
        if "scores" in prediction:
          scores = np.split(prediction["scores"], decode_hp.beam_size, axis=0)
        for i, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % i)
          score = scores and scores[i]
          decoded = log_decode_results(
              story,
              question,
              beam,
              story_vocab,
              question_vocab,
              targets_vocab,
              identity_output=decode_hp.identity_output,
              targets=targets)
          decoded_outputs.append(decoded)
          if decode_hp.write_beam_scores:
            decoded_scores.append(score)
      else:
        decoded = log_decode_results(
            story,
            question,
            outputs,
            story_vocab,
            question_vocab,
            targets_vocab,
            identity_output=decode_hp.identity_output,
            targets=targets)
        decoded_outputs.append(decoded)

      # Write out predictions if decode_to_file passed
      if decode_to_file:
        for i, (d_story, d_question,
                d_output, d_target) in enumerate(decoded_outputs):
          beam_score_str = ""
          if decode_hp.write_beam_scores:
            beam_score_str = "\t%.2f" % decoded_scores[i]

          decoded_file.write('STORY: \n%s\n' % _make_story_pretty(str(d_story)))
          decoded_file.write('QUESTION: %s\n' % _remove_pad(str(d_question)))
          decoded_file.write('ANSWER: %s\n' % _remove_pad(str(d_target)))
          decoded_file.write('OUTPUT: %s\n' %
                             (str(d_output) + beam_score_str +
                              decode_hp.delimiter))
          decoded_file.write('==================================================================\n')

      if (decode_hp.num_samples >= 0 and
          num_predictions >= decode_hp.num_samples):
        break

    if decode_to_file:
      tf.logging.info("Decoded results are written in: %s" % decoded_filepath)
      decoded_file.close()

    tf.logging.info("Completed inference on %d samples." % num_predictions)  # pylint: disable=undefined-loop-variable
def decode_from_file_search_based(estimator,
                                  filename,
                                  hparams,
                                  decode_hp,
                                  decode_to_file=None,
                                  checkpoint_path=None):
    """Compute predictions on entries in filename and write them out."""
    if not decode_hp.batch_size:
        decode_hp.batch_size = 32
        tf.logging.info(
                "decode_hp.batch_size not specified; default=%d" % decode_hp.batch_size)

    problem_id = decode_hp.problem_idx
    # Inputs vocabulary is set to targets if there are no inputs in the problem,
    # e.g., for language models where the inputs are just a prefix of targets.
    has_input = "inputs" in hparams.problems[problem_id].vocabulary
    inputs_vocab_key = "inputs" if has_input else "targets"
    inputs_vocab = hparams.problems[problem_id].vocabulary[inputs_vocab_key]
    targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
    problem_name = FLAGS.problems.split("-")[problem_id]
    tf.logging.info("Performing decoding from a file.")
    sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards,
                                                    decode_hp.delimiter)
    num_decode_batches = (len(sorted_inputs) - 1) // decode_hp.batch_size + 1

    data_dir = '/'.join(filename.split('/')[:-1])
    table_path = os.path.join(data_dir, '../../search_engine/big_table.txt')
    he_search_path = os.path.join(data_dir, 'he.search.txt')
    en_search_path = os.path.join(data_dir, 'en.search.txt')
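    # Build the retrieval components (presumably a search index over the
    # parallel table and a lookup-based translator) consumed by the
    # search-based batch input function below.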
    searcher = Searcher(table_path, he_search_path)
    translator = Translator(data_dir, he_search_path)

    def input_fn():
        input_gen = _decode_batch_input_fn_search_based(
                problem_id, num_decode_batches, sorted_inputs, inputs_vocab, targets_vocab,
                decode_hp.batch_size, decode_hp.max_input_size, searcher, translator,
                hparams.problems[problem_id])
        gen_fn = make_input_fn_from_generator(input_gen)
        example = gen_fn()
        return _decode_input_tensor_to_features_dict(example, hparams)

    decodes = []
    result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
    for result in result_iter:
        if decode_hp.return_beams:
            beam_decodes = []
            beam_scores = []
            output_beams = np.split(result["outputs"], decode_hp.beam_size, axis=0)
            scores = None
            if "scores" in result:
                scores = np.split(result["scores"], decode_hp.beam_size, axis=0)
            for k, beam in enumerate(output_beams):
                tf.logging.info("BEAM %d:" % k)
                score = scores and scores[k]
                _, decoded_outputs, _ = log_decode_results(result["inputs"], beam,
                                                           problem_name, None,
                                                           inputs_vocab, targets_vocab)
                beam_decodes.append(decoded_outputs)
                if decode_hp.write_beam_scores:
                    beam_scores.append(score)
            if decode_hp.write_beam_scores:
                decodes.append("\t".join(
                        ["\t".join([d, "%.2f" % s]) for d, s
                         in zip(beam_decodes, beam_scores)]))
            else:
                decodes.append("\t".join(beam_decodes))
        else:
            _, decoded_outputs, _ = log_decode_results(
                    result["inputs"], result["outputs"], problem_name,
                    None, inputs_vocab, targets_vocab)
            decodes.append(decoded_outputs)

    # Reversing the decoded inputs and outputs because they were reversed in
    # _decode_batch_input_fn
    sorted_inputs.reverse()
    decodes.reverse()
    # If decode_to_file was provided use it as the output filename without change
    # (except for adding shard_id if using more shards for decoding).
    # Otherwise, use the input filename plus model, hp, problem, beam, alpha.
    decode_filename = decode_to_file if decode_to_file else filename
    if decode_hp.shards > 1:
        decode_filename += "%.2d" % decode_hp.shard_id
    if not decode_to_file:
        decode_filename = _decode_filename(decode_filename, problem_name, decode_hp)
    tf.logging.info("Writing decodes into %s" % decode_filename)
    outfile = tf.gfile.Open(decode_filename, "w")
    for index in range(len(sorted_inputs)):
        outfile.write("%s%s" % (decodes[sorted_keys[index]], decode_hp.delimiter))
    outfile.flush()
    outfile.close()
def decode_from_file_fn(estimator,
                        filename,
                        hparams,
                        decode_hp,
                        decode_to_file=None,
                        checkpoint_path=None):
  """Compute predictions on entries in filename and write them out."""
  if not decode_hp.batch_size:
    decode_hp.batch_size = 32
    tf.logging.info(
        "decode_hp.batch_size not specified; default=%d" % decode_hp.batch_size)

  # Inputs vocabulary is set to targets if there are no inputs in the problem,
  # e.g., for language models where the inputs are just a prefix of targets.
  p_hp = hparams.problem_hparams
  has_input = "inputs" in p_hp.vocabulary
  inputs_vocab_key = "inputs" if has_input else "targets"
  inputs_vocab = p_hp.vocabulary[inputs_vocab_key]
  targets_vocab = p_hp.vocabulary["targets"]
  problem_name = FLAGS.problem
  filename = decoding._add_shard_to_filename(filename, decode_hp)
  tf.logging.info("Performing decoding from file (%s)." % filename)
  if has_input:
    sorted_inputs, sorted_keys = decoding._get_sorted_inputs(
        filename, decode_hp.delimiter)
  else:
    sorted_inputs = decoding._get_language_modeling_inputs(
        filename, decode_hp.delimiter, repeat=decode_hp.num_decodes)
    sorted_keys = range(len(sorted_inputs))
  num_sentences = len(sorted_inputs)
  num_decode_batches = (num_sentences - 1) // decode_hp.batch_size + 1

  if estimator.config.use_tpu:
    length = getattr(hparams, "length", 0) or hparams.max_length
    batch_ids = []
    for line in sorted_inputs:
      if has_input:
        ids = inputs_vocab.encode(line.strip()) + [1]
      else:
        ids = targets_vocab.encode(line)
      if len(ids) < length:
        ids.extend([0] * (length - len(ids)))
      else:
        ids = ids[:length]
      batch_ids.append(ids)
    np_ids = np.array(batch_ids, dtype=np.int32)
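    # TPU requires fixed-shape inputs: each example was padded/truncated to
    # `length` above and is reshaped to the [length, 1, 1] layout the model
    # expects.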
    def input_fn(params):
      batch_size = params["batch_size"]
      dataset = tf.data.Dataset.from_tensor_slices({"inputs": np_ids})
      dataset = dataset.map(
          lambda ex: {"inputs": tf.reshape(ex["inputs"], (length, 1, 1))})
      dataset = dataset.batch(batch_size)
      return dataset
  else:
    def input_fn():
      input_gen = decoding._decode_batch_input_fn(
          num_decode_batches, sorted_inputs,
          inputs_vocab, decode_hp.batch_size,
          decode_hp.max_input_size,
          task_id=decode_hp.multiproblem_task_id, has_input=has_input)
      gen_fn = decoding.make_input_fn_from_generator(input_gen)
      example = gen_fn()
      return decoding._decode_input_tensor_to_features_dict(
          example, hparams, decode_hp)
  decodes = []
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

  start_time = time.time()
  total_time_per_step = 0
  total_cnt = 0

  def timer(gen):
    while True:
      try:
        start_time = time.time()
        item = next(gen)
        elapsed_time = time.time() - start_time
        yield elapsed_time, item
      except StopIteration:
        break

  for elapsed_time, result in timer(result_iter):
    if decode_hp.return_beams:
      beam_decodes = []
      beam_scores = []
      output_beams = np.split(result["outputs"], decode_hp.beam_size, axis=0)
      scores = None
      if "scores" in result:
        if np.isscalar(result["scores"]):
          result["scores"] = result["scores"].reshape(1)
        scores = np.split(result["scores"], decode_hp.beam_size, axis=0)
      for k, beam in enumerate(output_beams):
        tf.logging.info("BEAM %d:" % k)
        score = scores and scores[k]
        _, decoded_outputs, _ = decoding.log_decode_results(
            result["inputs"],
            beam,
            problem_name,
            None,
            inputs_vocab,
            targets_vocab,
            log_results=decode_hp.log_results,
            skip_eos_postprocess=decode_hp.skip_eos_postprocess)
        beam_decodes.append(decoded_outputs)
        if decode_hp.write_beam_scores:
          beam_scores.append(score)
      if decode_hp.write_beam_scores:
        decodes.append("\t".join([
            "\t".join([d, "%.2f" % s])
            for d, s in zip(beam_decodes, beam_scores)
        ]))
      else:
        decodes.append("\t".join(beam_decodes))
    else:
      _, decoded_outputs, _ = decoding.log_decode_results(
          result["inputs"],
          result["outputs"],
          problem_name,
          None,
          inputs_vocab,
          targets_vocab,
          log_results=decode_hp.log_results,
          skip_eos_postprocess=decode_hp.skip_eos_postprocess)
      decodes.append(decoded_outputs)
    total_time_per_step += elapsed_time
    total_cnt += result["outputs"].shape[-1]
  duration = time.time() - start_time
  tf.logging.info("Elapsed Time: %5.5f" % duration)
  tf.logging.info("Averaged Single Token Generation Time: %5.7f "
                  "(time %5.7f count %d)" %
                  (total_time_per_step / total_cnt,
                   total_time_per_step, total_cnt))
  if decode_hp.batch_size == 1:
    tf.logging.info("Inference time %.4f seconds "
                    "(Latency = %.4f ms/setences)" %
                    (duration, 1000.0*duration/num_sentences))
  else:
    tf.logging.info("Inference time %.4f seconds "
                    "(Throughput = %.4f sentences/second)" %
                    (duration, num_sentences/duration))

  # If decode_to_file was provided use it as the output filename without change
  # (except for adding shard_id if using more shards for decoding).
  # Otherwise, use the input filename plus model, hp, problem, beam, alpha.
  decode_filename = decode_to_file if decode_to_file else filename
  if not decode_to_file:
    decode_filename = decoding._decode_filename(decode_filename, problem_name, decode_hp)
  else:
    decode_filename = decoding._add_shard_to_filename(decode_filename, decode_hp)
  tf.logging.info("Writing decodes into %s" % decode_filename)
  outfile = tf.gfile.Open(decode_filename, "w")
  special_chars = ["\a", "\n", "\f", "\r", "\b"]
  for index in range(len(sorted_inputs)):
    # Strip control characters that would break the line-delimited output.
    output = decodes[sorted_keys[index]]
    for c in special_chars:
      output = output.replace(c, ' ')
    try:
      outfile.write("%s%s" % (output, decode_hp.delimiter))
    except Exception:
      # Fall back to an empty line if the decode cannot be written.
      outfile.write("%s" % decode_hp.delimiter)
  outfile.flush()
  outfile.close()
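
# The fixed-length padding used in the TPU branches above, isolated as a
# standalone sketch (pure Python; the helper name is illustrative, not part
# of the functions above):

def pad_or_truncate(ids, length, pad_id=0):
    """Pad `ids` with `pad_id` up to `length`, or cut it down to `length`."""
    if len(ids) < length:
        return ids + [pad_id] * (length - len(ids))
    return ids[:length]

assert pad_or_truncate([5, 6], 4) == [5, 6, 0, 0]
assert pad_or_truncate([5, 6, 7, 8, 9], 4) == [5, 6, 7, 8]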