def postprocess_output(outputs, sentence_id, eos, bpe_delimiter,
                       number_token=None, name_token=None):
    """Given batch decoding outputs, select a sentence and postprocess it.

    Args:
      outputs: batch of decoded tokens, indexed as ``outputs[sentence_id, :]``
        and converted with ``.tolist()``. The elements are compared against
        UTF-8 ``bytes`` (presumably lookup-table output -- confirm with caller).
      sentence_id: row index of the sentence to extract.
      eos: end-of-sentence symbol as ``str``. The output is truncated at its
        first occurrence. ``None`` or ``""`` disables truncation.
      bpe_delimiter: if truthy, tokens are rendered with
        ``utils.format_bpe_text(output, bpe_delimiter)``; otherwise with
        ``utils.format_text(output)``.
      number_token: optional placeholder token (``str``); every occurrence is
        replaced by ``b"53"``.
      name_token: optional placeholder token (``str``); every occurrence is
        replaced by ``b"Batman"``.

    Returns:
      Whatever the selected ``utils.format_*`` helper returns.
    """
    # Select the sentence.
    output = outputs[sentence_id, :].tolist()

    # The lookup-table outputs are bytes, so '</s>' would never match b'</s>'
    # without encoding first. Guard before encoding: the previous
    # ``bytes(None, encoding=...)`` raised TypeError when no eos was given.
    if eos:
        eos_bytes = eos.encode("utf-8")
        if eos_bytes in output:
            output = output[:output.index(eos_bytes)]

    # Placeholder substitution. The two passes run in sequence, so a number
    # replacement is never re-examined as a name token (unless name_token is
    # literally "53").
    if number_token:
        number_token_bytes = number_token.encode("utf-8")
        output = [b"53" if word == number_token_bytes else word
                  for word in output]
    if name_token:
        name_token_bytes = name_token.encode("utf-8")
        output = [b"Batman" if word == name_token_bytes else word
                  for word in output]

    if bpe_delimiter:
        return utils.format_bpe_text(output, bpe_delimiter)
    return utils.format_text(output)
def get_translation(nmt_outputs, sent_id, tgt_eos):
    """Pick row ``sent_id`` from ``nmt_outputs`` and render it as text.

    When ``tgt_eos`` is truthy it is UTF-8 encoded and the row is cut just
    before its first occurrence. The remaining tokens are joined with
    ``utils.format_text``.
    """
    # Decoder tokens are bytes, so compare against the encoded symbol.
    eos_bytes = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos
    tokens = nmt_outputs[sent_id, :].tolist()
    if eos_bytes and eos_bytes in tokens:
        cut = tokens.index(eos_bytes)
        tokens = tokens[:cut]
    return utils.format_text(tokens)
def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Pick row ``sent_id`` from ``nmt_outputs`` and render it as text.

    When ``tgt_eos`` is truthy the row is cut before its first occurrence,
    compared as UTF-8 bytes. Rendering depends on ``subword_option``:
    ``"bpe"`` uses ``utils.format_bpe_text``, ``"spm"`` uses
    ``utils.format_spm_text``, and anything else uses ``utils.format_text``.
    """
    eos_bytes = tgt_eos.encode("utf-8") if tgt_eos else None
    tokens = nmt_outputs[sent_id, :].tolist()
    if eos_bytes and eos_bytes in tokens:
        tokens = tokens[:tokens.index(eos_bytes)]

    if subword_option == "bpe":
        return utils.format_bpe_text(tokens)
    if subword_option == "spm":
        return utils.format_spm_text(tokens)
    return utils.format_text(tokens)
def get_translation(nmt_outputs, sent_id, tgt_sos, tgt_eos, bpe_delimiter):
    """Given batch decoding outputs, select a sentence and turn to text.

    Args:
      nmt_outputs: batch of decoded tokens, indexed as
        ``nmt_outputs[sent_id, :]`` and converted with ``.tolist()``. The
        elements are compared against UTF-8 ``bytes`` (assumed to be the
        decoder's vocabulary output -- confirm with caller).
      sent_id: row index of the sentence to extract.
      tgt_sos: start-of-sentence symbol (``str``). If truthy and it is the
        first token, that token is dropped.
      tgt_eos: end-of-sentence symbol (``str``). If truthy, the row is cut
        before its first occurrence.
      bpe_delimiter: if truthy, it is UTF-8 encoded and the text is rendered
        with ``utils.format_bpe_text(..., delimiter=...)``; otherwise
        ``utils.format_text`` is used.

    Returns:
      Whatever the selected ``utils.format_*`` helper returns.
    """
    # Tokens are bytes, so the str symbols must be encoded before comparing.
    if tgt_sos:
        tgt_sos = tgt_sos.encode("utf-8")
    if tgt_eos:
        tgt_eos = tgt_eos.encode("utf-8")
    if bpe_delimiter:
        bpe_delimiter = bpe_delimiter.encode("utf-8")

    # Select a sentence.
    output = nmt_outputs[sent_id, :].tolist()

    # Drop a leading sos symbol. The ``output and`` guard avoids the
    # IndexError that ``output[0]`` raised on an empty row.
    if tgt_sos and output and output[0] == tgt_sos:
        output = output[1:]

    # If there is an eos symbol in outputs, cut them at that point.
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]

    if not bpe_delimiter:
        return utils.format_text(output)
    # BPE
    return utils.format_bpe_text(output, delimiter=bpe_delimiter)