def postprocess_output(outputs,
                       sentence_id,
                       eos,
                       bpe_delimiter,
                       number_token=None,
                       name_token=None,
                       number_value=b"53",
                       name_value=b"Batman"):
    """Given batch decoding outputs, select a sentence and postprocess it.

    Args:
      outputs: batch of decoded tokens, indexed as ``outputs[sentence_id, :]``.
        Its rows are compared against byte strings, so the entries are
        presumably ``bytes`` from a lookup table (TODO confirm against caller).
      sentence_id: index of the row in ``outputs`` to postprocess.
      eos: end-of-sentence token as ``str`` or ``bytes``. The row is cut at
        its first occurrence. ``None`` or empty disables truncation.
      bpe_delimiter: if truthy, the tokens are rendered with
        ``utils.format_bpe_text(output, bpe_delimiter)``; otherwise with
        ``utils.format_text(output)``.
      number_token: optional placeholder token (``str`` or ``bytes``). Every
        occurrence is replaced with ``number_value``.
      name_token: optional placeholder token (``str`` or ``bytes``). Every
        occurrence is replaced with ``name_value``.
      number_value: value substituted for ``number_token`` (default ``b"53"``).
      name_value: value substituted for ``name_token`` (default ``b"Batman"``).

    Returns:
      The result of the selected ``utils`` formatter for the processed tokens.
    """
    def _as_bytes(token):
        # Lookup-table outputs are bytes, so str tokens must be encoded for the
        # equality checks below. bytes and None pass through unchanged (the
        # original unconditional bytes(...) call crashed on eos=None).
        return token.encode("utf-8") if isinstance(token, str) else token

    # Select the sentence.
    output = outputs[sentence_id, :].tolist()

    # Truncate at the first end-of-sentence marker, if present.
    eos = _as_bytes(eos)
    if eos and eos in output:
        output = output[:output.index(eos)]

    # Two sequential passes (number first, then name) keep the original
    # substitution order, including its behaviour when the tokens overlap.
    if number_token:
        number_token_bytes = _as_bytes(number_token)
        output = [number_value if word == number_token_bytes else word
                  for word in output]
    if name_token:
        name_token_bytes = _as_bytes(name_token)
        output = [name_value if word == name_token_bytes else word
                  for word in output]

    if bpe_delimiter:
        return utils.format_bpe_text(output, bpe_delimiter)
    return utils.format_text(output)
def get_translation(nmt_outputs, sent_id, tgt_eos):
    """Given batch decoding outputs, select a sentence and turn to text.

    Args:
      nmt_outputs: batch of decoded tokens, indexed as
        ``nmt_outputs[sent_id, :]``.
      sent_id: index of the row to extract.
      tgt_eos: end-of-sentence token as ``str``. It is UTF-8 encoded before
        comparison. If falsy, no truncation is done.

    Returns:
      ``utils.format_text`` applied to the (possibly truncated) token list.
    """
    if tgt_eos:
        tgt_eos = tgt_eos.encode("utf-8")
    # Select a sentence.
    output = nmt_outputs[sent_id, :].tolist()

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]

    # Format unconditionally. Previously this ran only inside the `if` above,
    # so rows without an EOS raised UnboundLocalError on return.
    translation = utils.format_text(output)

    return translation
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Pick one decoded sentence out of a batch and render it as text.

    Args:
      nmt_outputs: batch of decoded tokens, indexed as
        ``nmt_outputs[sent_id, :]``.
      sent_id: index of the row to render.
      tgt_eos: end-of-sentence token as ``str``. It is UTF-8 encoded before
        comparison. If falsy, the row is not truncated.
      subword_option: ``"bpe"`` selects ``utils.format_bpe_text``, ``"spm"``
        selects ``utils.format_spm_text``, and anything else falls back to
        ``utils.format_text``.

    Returns:
      The formatter's rendering of the (possibly truncated) token list.
    """
    eos_symbol = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
    tokens = nmt_outputs[sent_id, :].tolist()

    # Keep only the tokens before the first end-of-sentence marker.
    if eos_symbol and eos_symbol in tokens:
        tokens = tokens[:tokens.index(eos_symbol)]

    # Attributes of `utils` are looked up only on the branch taken.
    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens)
Example #4
0
def get_translation(nmt_outputs, sent_id, tgt_sos, tgt_eos, bpe_delimiter):
    """Given batch decoding outputs, select a sentence and turn to text.

    Args:
      nmt_outputs: batch of decoded tokens, indexed as
        ``nmt_outputs[sent_id, :]``.
      sent_id: index of the row to extract.
      tgt_sos: start-of-sentence token as ``str``. If the row begins with it,
        that token is dropped. Falsy disables this.
      tgt_eos: end-of-sentence token as ``str``. The row is cut at its first
        occurrence. Falsy disables this.
      bpe_delimiter: BPE delimiter as ``str``. If truthy, it is UTF-8 encoded
        and passed to ``utils.format_bpe_text``; otherwise
        ``utils.format_text`` is used.

    Returns:
      The formatter's rendering of the processed token list.
    """
    if tgt_sos:
        tgt_sos = tgt_sos.encode("utf-8")
    if tgt_eos:
        tgt_eos = tgt_eos.encode("utf-8")
    if bpe_delimiter:
        bpe_delimiter = bpe_delimiter.encode("utf-8")
    # Select a sentence.
    output = nmt_outputs[sent_id, :].tolist()

    # Drop a leading start-of-sentence marker. The `output` check avoids the
    # IndexError that `output[0]` raised on an empty decoded row.
    if tgt_sos and output and output[0] == tgt_sos:
        output = output[1:]

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]

    if not bpe_delimiter:
        translation = utils.format_text(output)
    else:  # BPE
        translation = utils.format_bpe_text(output, delimiter=bpe_delimiter)

    return translation