def test_pad_decr(self):
  """Checks text_encoder.pad_decr against a table of (expected, input) pairs."""
  # Each entry pairs the expected output with the ids handed to pad_decr.
  # The cases cover: no padding, trailing zeros, a leading zero (which
  # comes back as -1), and empty inputs given as a list and a numpy array.
  cases = (
      ([2, 1, 0], [3, 2, 1]),
      ([2, 1, 0], [3, 2, 1, 0, 0, 0]),
      ([-1, 2, 1, 0], [0, 3, 2, 1, 0, 0]),
      ([], []),
      ([], np.array([])),
  )
  for expected, padded_ids in cases:
    self.assertEqual(expected, text_encoder.pad_decr(padded_ids))
Example #2
0
  def decode(self, ids):
    """Decodes a list of integers into text.

    The ids are first passed through `text_encoder.pad_decr`; each resulting
    id is then mapped to a subword via `self._id_to_subword`.

    Args:
      ids: sequence of integer ids to decode.

    Returns:
      The decoded text, as produced by `tf.compat.as_text` on the joined
      pieces.
    """
    pieces = []

    # A single unicode character may be spread over several byte-level ids,
    # so runs of consecutive byte subwords are buffered and decoded in one
    # go. Any invalid byte sequence becomes U+FFFD (the unicode replacement
    # character) instead of raising.
    pending_bytes = []

    def flush_pending_bytes():
      """Decodes the buffered byte run (if any) and empties the buffer."""
      if not pending_bytes:
        return
      raw = b"".join(pending_bytes)
      pieces.append(raw.decode("utf-8", "replace"))
      del pending_bytes[:]

    for subword_id in text_encoder.pad_decr(ids):
      piece = self._id_to_subword(subword_id)
      if isinstance(piece, six.binary_type):
        # Byte-encoded subword: hold it back until the run ends.
        pending_bytes.append(piece)
        continue

      # A text subword ends any byte run that preceded it.
      flush_pending_bytes()
      text, needs_space = _trim_underscore_and_tell(piece)
      pieces.append(text)
      if needs_space:
        pieces.append(" ")

    # Emit whatever byte run was still open at the end of the input.
    flush_pending_bytes()

    return tf.compat.as_text("".join(pieces))