def postprocess_output(outputs,
                       sentence_id,
                       eos,
                       bpe_delimiter,
                       number_token=None,
                       name_token=None):
    """Given batch decoding outputs, select a sentence and postprocess it.

    Args:
        outputs: Batch of decoded tokens. Must support
            ``outputs[sentence_id, :].tolist()``. The tokens are expected to
            be bytes, since they come out of the lookup table that way.
        sentence_id: Index of the sentence to select from the batch.
        eos: End-of-sentence marker (str or bytes). The sentence is cut at
            its first occurrence. A falsy value (``None``, ``""``) disables
            the cut.
        bpe_delimiter: When truthy, the tokens are formatted with
            ``utils.format_bpe_text(output, bpe_delimiter)``; otherwise with
            ``utils.format_text(output)``.
        number_token: Optional placeholder (str or bytes); every token equal
            to it is replaced by ``b"53"``.
        name_token: Optional placeholder (str or bytes); every token equal
            to it is replaced by ``b"Batman"``.

    Returns:
        Whatever the selected ``utils`` formatter returns for the processed
        token list.
    """

    def _as_bytes(token):
        # The tokens in `output` are bytes, so a str marker has to be encoded
        # before an equality check against them can succeed.
        return token if isinstance(token, bytes) else token.encode("utf-8")

    # Select the sentence
    output = outputs[sentence_id, :].tolist()

    # Convert only after checking for a usable marker: encoding None would
    # raise, which previously made the truthiness guard pointless.
    if eos:
        eos_bytes = _as_bytes(eos)
        if eos_bytes in output:
            output = output[:output.index(eos_bytes)]

    if number_token:
        number_token_bytes = _as_bytes(number_token)
        output = [
            b"53" if word == number_token_bytes else word for word in output
        ]
    if name_token:
        name_token_bytes = _as_bytes(name_token)
        output = [
            b"Batman" if word == name_token_bytes else word for word in output
        ]

    if bpe_delimiter:
        return utils.format_bpe_text(output, bpe_delimiter)
    return utils.format_text(output)
Пример #2
0
def get_translation(nmt_outputs, infer_logits, sent_id, tgt_eos,
                    subword_option):
    """Pick one sentence out of the batch decoding outputs and format it.

    Args:
        nmt_outputs: Batch of decoded tokens; indexed as
            ``nmt_outputs[sent_id, :]`` and converted with ``.tolist()``.
        infer_logits: Per-sentence scores; ``infer_logits[sent_id]`` is
            forwarded to ``utils.format_text`` in the plain-text case.
        sent_id: Index of the sentence to extract.
        tgt_eos: End-of-sentence marker as str, or a falsy value to skip
            the cut. It is UTF-8 encoded before being compared to tokens.
        subword_option: ``"bpe"`` or ``"spm"`` selects the matching
            ``utils`` formatter; any other value uses ``utils.format_text``.

    Returns:
        The result of the selected ``utils`` formatter.
    """
    eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos

    tokens = nmt_outputs[sent_id, :].tolist()
    # Looked up unconditionally, even though only the plain-text branch
    # below consumes it. Note it is not truncated together with the tokens.
    sentence_scores = infer_logits[sent_id]

    # Drop the end-of-sentence marker and everything after it.
    if eos_marker and eos_marker in tokens:
        eos_position = tokens.index(eos_marker)
        tokens = tokens[:eos_position]

    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens, sentence_scores)
Пример #3
0
 def testFormatBpeText(self):
     # Input tokens: pieces ending in "@@" are expected to be glued to the
     # piece that follows them.
     tokens = (
         b"En@@ ough to make already reluc@@ tant men hesitate to take screening"
         b" tests .").split(b" ")
     merged = misc_utils.format_bpe_text(tokens)
     # Expected value goes first, matching the assertEqual(expected, actual)
     # order used by this test.
     self.assertEqual(
         b"Enough to make already reluctant men hesitate to take screening tests"
         b" .",
         merged)
Пример #4
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Extract sentence `sent_id` from the decoding outputs and format it.

    A truthy `tgt_eos` is UTF-8 encoded and the token list is cut at its
    first occurrence. `subword_option` chooses the `utils` formatter:
    "bpe", "spm", or plain text for anything else.
    """
    eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
    tokens = nmt_outputs[sent_id, :].tolist()

    # Keep only what precedes the end-of-sentence marker, if present.
    if eos_marker and eos_marker in tokens:
        cut_at = tokens.index(eos_marker)
        tokens = tokens[:cut_at]

    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens)
Пример #5
0
def get_translation(nmt_outputs, sent_id, tgt_eos, bpe_delimiter):
    """Given batch decoding outputs, select a sentence and turn to text.

    Args:
        nmt_outputs: Batch of decoded tokens; indexed as
            ``nmt_outputs[sent_id, :]`` and converted with ``.tolist()``.
        sent_id: Index of the sentence to select.
        tgt_eos: End-of-sentence marker, or a falsy value to skip the cut.
            A str marker also matches its UTF-8 encoded form, so the cut
            works whether the tokens are str or bytes.
        bpe_delimiter: When falsy the tokens go through
            ``utils.format_text``; otherwise through
            ``utils.format_bpe_text`` with this delimiter.

    Returns:
        The result of the selected ``utils`` formatter.
    """
    # Select a sentence
    output = nmt_outputs[sent_id, :].tolist()

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos:
        # A str marker never compares equal to bytes tokens, so the encoded
        # form is accepted too. The marker as given is still matched, which
        # keeps the behaviour for callers whose tokens are str.
        eos_forms = [tgt_eos]
        if isinstance(tgt_eos, str):
            eos_forms.append(tgt_eos.encode("utf-8"))
        for position, token in enumerate(output):
            if token in eos_forms:
                output = output[:position]
                break

    if not bpe_delimiter:
        translation = utils.format_text(output)
    else:  # BPE
        translation = utils.format_bpe_text(output, delimiter=bpe_delimiter)

    return translation
Пример #6
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
  """Select one sentence from the batch decoding outputs and render it.

  Args:
    nmt_outputs: Batch of decoded tokens; indexed as
      `nmt_outputs[sent_id, :]` and converted with `.tolist()`.
    sent_id: Index of the sentence to select.
    tgt_eos: End-of-sentence marker as str (UTF-8 encoded before use), or a
      falsy value to skip the cut.
    subword_option: "bpe" or "spm" selects the matching `utils` formatter;
      anything else uses `utils.format_text`.

  Returns:
    The result of the selected `utils` formatter.
  """
  eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
  tokens = nmt_outputs[sent_id, :].tolist()

  # Discard the end-of-sentence marker and everything that follows it.
  if eos_marker and eos_marker in tokens:
    tokens = tokens[:tokens.index(eos_marker)]

  # Choose the formatter first, then make a single call.
  if subword_option == "bpe":
    formatter = utils.format_bpe_text
  elif subword_option == "spm":
    formatter = utils.format_spm_text
  else:
    formatter = utils.format_text

  return formatter(tokens)