Ejemplo n.º 1
0
def get_translation(nmt_outputs, infer_logits, sent_id, tgt_eos,
                    subword_option):
    """Given batch decoding outputs, select a sentence and turn it into text.

    Args:
      nmt_outputs: 2-D indexable batch of decoded tokens; row ``sent_id`` is
        taken with ``[sent_id, :]`` and converted via ``.tolist()``.
        Presumably a numpy array of UTF-8 encoded (bytes) tokens -- TODO
        confirm against callers.
      infer_logits: indexable by ``sent_id``; the selected entry is forwarded
        to ``utils.format_text`` when no subword option applies.
      sent_id: index of the sentence to select within the batch.
      tgt_eos: end-of-sentence symbol as ``str`` or ``bytes``, or a falsy
        value to disable truncation. Non-bytes values are UTF-8 encoded
        before being compared with the tokens.
      subword_option: ``"bpe"``, ``"spm"``, or anything else for plain text.

    Returns:
      The result of ``utils.format_bpe_text``, ``utils.format_spm_text`` or
      ``utils.format_text``, depending on ``subword_option``.
    """
    # Tokens are compared as bytes. Only encode when the symbol is not already
    # bytes: ``bytes`` has no ``.encode`` on Python 3.
    if tgt_eos and not isinstance(tgt_eos, bytes):
        tgt_eos = tgt_eos.encode("utf-8")

    # Select a sentence.
    output = nmt_outputs[sent_id, :].tolist()
    scores = infer_logits[sent_id]

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]

    if subword_option == "bpe":  # BPE
        translation = utils.format_bpe_text(output)
    elif subword_option == "spm":  # SPM
        translation = utils.format_spm_text(output)
    else:
        # NOTE(review): ``output`` has been truncated at EOS but ``scores`` has
        # not; confirm that utils.format_text tolerates the length mismatch.
        translation = utils.format_text(output, scores)

    return translation
Ejemplo n.º 2
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Given batch decoding outputs, select a sentence and turn it into text.

    Args:
      nmt_outputs: 2-D indexable batch of decoded tokens; row ``sent_id`` is
        taken with ``[sent_id, :]`` and converted via ``.tolist()``.
        Presumably a numpy array of UTF-8 encoded (bytes) tokens -- TODO
        confirm against callers.
      sent_id: index of the sentence to select within the batch.
      tgt_eos: end-of-sentence symbol as ``str`` or ``bytes``, or a falsy
        value to disable truncation. Non-bytes values are UTF-8 encoded
        before being compared with the tokens.
      subword_option: ``"bpe"``, ``"spm"``, or anything else for plain text.

    Returns:
      The result of ``utils.format_bpe_text``, ``utils.format_spm_text`` or
      ``utils.format_text``, depending on ``subword_option``.
    """
    # Tokens are compared as bytes. Only encode when the symbol is not already
    # bytes: ``bytes`` has no ``.encode`` on Python 3.
    if tgt_eos and not isinstance(tgt_eos, bytes):
        tgt_eos = tgt_eos.encode("utf-8")

    # Select a sentence.
    output = nmt_outputs[sent_id, :].tolist()

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]

    if subword_option == "bpe":  # BPE
        translation = utils.format_bpe_text(output)
    elif subword_option == "spm":  # SPM
        translation = utils.format_spm_text(output)
    else:
        translation = utils.format_text(output)
    return translation
Ejemplo n.º 3
0
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
  """Given batch decoding outputs, select a sentence and turn it into text.

  Args:
    nmt_outputs: 2-D indexable batch of decoded tokens; row ``sent_id`` is
      taken with ``[sent_id, :]`` and converted via ``.tolist()``.
      Presumably a numpy array of UTF-8 encoded (bytes) tokens -- TODO
      confirm against callers.
    sent_id: index of the sentence to select within the batch.
    tgt_eos: end-of-sentence symbol as ``str`` or ``bytes``, or a falsy value
      to disable truncation. Non-bytes values are UTF-8 encoded before being
      compared with the tokens.
    subword_option: ``"bpe"``, ``"spm"``, or anything else for plain text.

  Returns:
    The result of ``utils.format_bpe_text``, ``utils.format_spm_text`` or
    ``utils.format_text``, depending on ``subword_option``.
  """
  # Tokens are compared as bytes. Only encode when the symbol is not already
  # bytes: ``bytes`` has no ``.encode`` on Python 3.
  if tgt_eos and not isinstance(tgt_eos, bytes):
    tgt_eos = tgt_eos.encode("utf-8")

  # Select a sentence.
  output = nmt_outputs[sent_id, :].tolist()

  # If there is an eos symbol in outputs, cut them at that point.
  if tgt_eos and tgt_eos in output:
    output = output[:output.index(tgt_eos)]

  if subword_option == "bpe":  # BPE
    translation = utils.format_bpe_text(output)
  elif subword_option == "spm":  # SPM
    translation = utils.format_spm_text(output)
  else:
    translation = utils.format_text(output)

  return translation
 def testFormatSPMText(self):
   # Split a UTF-8 encoded SentencePiece line into its space-separated pieces.
   # Judging by the expected string, the U+2581 markers should become word
   # boundaries and the pieces between them should be glued together.
   pieces = u"\u2581This \u2581is \u2581a \u2581 te st .".encode("utf-8").split(b" ")
   want = "This is a test."
   got = misc_utils.format_spm_text(pieces)
   self.assertEqual(want, got)