def postprocess_output(outputs,
                       sentence_id,
                       eos,
                       bpe_delimiter,
                       number_token=None,
                       name_token=None):
    """Select one sentence from batch decoding outputs and postprocess it.

    Args:
        outputs: Batch of decoded tokens; row ``outputs[sentence_id, :]`` is
            selected and converted with ``tolist()``. The tokens are expected
            to be ``bytes`` (the lookup table emits bytes, not ``str``).
        sentence_id: Index of the sentence to select from the batch.
        eos: End-of-sentence marker as ``str`` or ``bytes``. ``None`` or an
            empty value disables truncation.
        bpe_delimiter: If truthy, the tokens are joined with
            ``utils.format_bpe_text``; otherwise with ``utils.format_text``.
        number_token: Optional placeholder token; every occurrence is
            replaced by ``b"53"``.
        name_token: Optional placeholder token; every occurrence is replaced
            by ``b"Batman"``.

    Returns:
        Whatever the selected ``utils`` formatting helper returns.
    """

    def as_bytes(token):
        # Tokens in the decoded output are bytes, so a ``str`` marker must be
        # encoded before it can compare equal to any of them. ``None`` and
        # values that are already bytes are passed through untouched.
        return token.encode("utf-8") if isinstance(token, str) else token

    # Select the sentence.
    output = outputs[sentence_id, :].tolist()

    # Cut the sentence at the first end-of-sentence marker, if there is one.
    eos = as_bytes(eos)
    if eos and eos in output:
        output = output[:output.index(eos)]

    # The two substitutions are kept as separate passes so that their
    # relative order (numbers first, then names) is preserved.
    if number_token:
        number_token_bytes = as_bytes(number_token)
        output = [
            b"53" if word == number_token_bytes else word for word in output
        ]
    if name_token:
        name_token_bytes = as_bytes(name_token)
        output = [
            b"Batman" if word == name_token_bytes else word for word in output
        ]

    if bpe_delimiter:
        response = utils.format_bpe_text(output, bpe_delimiter)
    else:
        response = utils.format_text(output)

    return response
예제 #2
0
파일: nmt_utils.py 프로젝트: ml-lab/Pun-GAN
def get_translation(nmt_outputs, infer_logits, sent_id, tgt_eos,
                    subword_option):
    """Pick sentence `sent_id` from the decoded batch and render it as text.

    The tokens are truncated at the first end-of-sentence marker (a truthy
    `tgt_eos` is UTF-8 encoded before the comparison). `subword_option`
    chooses the formatter: "bpe", "spm", or anything else for plain text.
    Only the plain-text formatter receives the per-sentence logits.
    """
    eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos

    tokens = nmt_outputs[sent_id, :].tolist()
    sentence_scores = infer_logits[sent_id]

    # Drop the end-of-sentence marker and everything after it.
    if eos_marker:
        try:
            cut = tokens.index(eos_marker)
        except ValueError:
            pass
        else:
            tokens = tokens[:cut]

    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens, sentence_scores)
예제 #3
0
def get_translation(nmt_outputs, sent_id, tgt_eos):
  """Pick sentence `sent_id` from the decoded batch and render it as text.

  The tokens are truncated at the first occurrence of `tgt_eos`, which is
  compared exactly as passed in (no encoding is applied here).
  """
  tokens = nmt_outputs[sent_id, :].tolist()
  if tgt_eos:
    try:
      cut = tokens.index(tgt_eos)
    except ValueError:
      pass
    else:
      tokens = tokens[:cut]
  return utils.format_text(tokens)
예제 #4
0
def decode_inference_indices(model, sess):
    """Decode with `model` in `sess` and format the single result as text.

    The decoded batch must contain exactly one sentence; the second value
    returned by `model.decode` is ignored.
    """
    decoded, _ = model.decode(sess)
    assert decoded.shape[0] == 1

    tokens = decoded[0, :].tolist()
    return utils.format_text(tokens)
예제 #5
0
def get_translation_cut_both(nmt_outputs, sent_id, start_token, end_token):
  """Pick sentence `sent_id` from the decoded batch and render it as text.

  The tokens are truncated at the first `start_token` (if given and present),
  and the remainder is then truncated at the first `end_token` (if given and
  present). In both cases the marker and everything after it are dropped.
  """
  tokens = nmt_outputs[sent_id, :].tolist()

  # Same truncation rule for both markers, applied in this order.
  for marker in (start_token, end_token):
    if marker and marker in tokens:
      tokens = tokens[:tokens.index(marker)]

  return utils.format_text(tokens)
예제 #6
0
def _get_translation(outputs, sent_id, tgt_eos, subword_option):
    """Pick sentence `sent_id` from the decoded batch and render it as text.

    A truthy `tgt_eos` is UTF-8 encoded and the tokens are truncated at its
    first occurrence. `subword_option` is accepted but not used here; the
    plain-text formatter is always applied.
    """
    eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
    tokens = outputs[sent_id, :].tolist()

    # Drop the end-of-sentence marker and everything after it.
    if eos_marker:
        try:
            cut = tokens.index(eos_marker)
        except ValueError:
            pass
        else:
            tokens = tokens[:cut]

    return utils.format_text(tokens)
예제 #7
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Pick sentence `sent_id` from the decoded batch and render it as text.

    A truthy `tgt_eos` is UTF-8 encoded and the tokens are truncated at its
    first occurrence. `subword_option` selects the formatter: "bpe", "spm",
    or anything else for plain text.
    """
    eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
    tokens = nmt_outputs[sent_id, :].tolist()

    if eos_marker:
        try:
            cut = tokens.index(eos_marker)
        except ValueError:
            pass
        else:
            tokens = tokens[:cut]

    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens)
예제 #8
0
파일: nmt_utils.py 프로젝트: zhang197/nslt
def get_translation(nmt_outputs, sent_id, tgt_eos, bpe_delimiter):
    """Pick sentence `sent_id` from the decoded batch and render it as text.

    The tokens are truncated at the first occurrence of `tgt_eos`, which is
    compared exactly as passed in (no encoding is applied here). When
    `bpe_delimiter` is truthy the BPE formatter is used with that delimiter;
    otherwise the plain-text formatter is used.
    """
    tokens = nmt_outputs[sent_id, :].tolist()

    # Drop the end-of-sentence marker and everything after it.
    if tgt_eos:
        try:
            cut = tokens.index(tgt_eos)
        except ValueError:
            pass
        else:
            tokens = tokens[:cut]

    if bpe_delimiter:
        return utils.format_bpe_text(tokens, delimiter=bpe_delimiter)
    return utils.format_text(tokens)
예제 #9
0
def decode_inference_indices(model, sess):
    """Decode with `model` in `sess` and format the single result as text.

    The decoded batch must contain exactly one sentence; the second value
    returned by `model.decode` is ignored. The tokens are truncated at the
    first "</s>" marker when one is present.
    """
    decoded, _ = model.decode(sess)
    assert decoded.shape[0] == 1

    tokens = decoded[0, :].tolist()

    # NOTE(review): the marker is compared as a str; if the decoded tokens
    # are bytes this never matches -- confirm against the model's output.
    try:
        cut = tokens.index("</s>")
    except ValueError:
        pass
    else:
        tokens = tokens[:cut]

    return utils.format_text(tokens)
예제 #10
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
  """Pick sentence `sent_id` from the decoded batch and render it as text.

  A truthy `tgt_eos` is UTF-8 encoded and the tokens are truncated at its
  first occurrence. `subword_option` selects the formatter: "bpe", "spm",
  or anything else for plain text.
  """
  eos_marker = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
  tokens = nmt_outputs[sent_id, :].tolist()

  # Drop the end-of-sentence marker and everything after it.
  if eos_marker:
    try:
      cut = tokens.index(eos_marker)
    except ValueError:
      pass
    else:
      tokens = tokens[:cut]

  if subword_option == "bpe":
    return utils.format_bpe_text(tokens)
  if subword_option == "spm":
    return utils.format_spm_text(tokens)
  return utils.format_text(tokens)