Exemplo n.º 1
0
 def test_pad_decr(self):
     """pad_decr drops trailing zero padding and decrements remaining ids."""
     # Each entry is (expected output, input ids).
     #  - trailing zeros are treated as padding and removed;
     #  - every remaining id is shifted down by one, so a zero that is not
     #    part of the trailing padding comes out as -1;
     #  - empty inputs (including an empty NumPy array) yield [].
     cases = [
         ([2, 1, 0], [3, 2, 1]),
         ([2, 1, 0], [3, 2, 1, 0, 0, 0]),
         ([-1, 2, 1, 0], [0, 3, 2, 1, 0, 0]),
         ([], []),
         ([], np.array([])),
     ]
     for expected, ids in cases:
         self.assertEqual(expected, text_encoder.pad_decr(ids))
Exemplo n.º 2
0
    def decode(self, ids):
        """Turn a sequence of integer ids back into a text string.

        The ids are first passed through ``text_encoder.pad_decr`` and each
        resulting id is mapped to a subword via ``self._id_to_subword``.

        Subwords come in two flavours:

        * byte strings (``six.binary_type``): runs of consecutive byte
          subwords are buffered and decoded together as UTF-8, so that a
          multi-byte character split across several ids is reassembled.
          Invalid byte sequences become U+FFFD (the "replace" error handler).
        * text subwords: trimmed with ``_trim_underscore_and_tell``, which
          also says whether a space should follow the trimmed piece.

        Args:
            ids: sequence of integer ids to decode.

        Returns:
            The concatenated pieces, passed through ``tf.compat.as_text``.
        """
        pieces = []
        pending_bytes = []

        def flush_pending_bytes():
            # Decode the whole buffered byte run at once, then empty the
            # buffer in place so the enclosing loop keeps the same list.
            if pending_bytes:
                pieces.append(b"".join(pending_bytes).decode("utf-8", "replace"))
                del pending_bytes[:]

        for subword_id in text_encoder.pad_decr(ids):
            subword = self._id_to_subword(subword_id)
            if isinstance(subword, six.binary_type):
                # Byte-encoded subword: defer decoding until the run ends.
                pending_bytes.append(subword)
                continue

            # A text subword ends any byte run that preceded it.
            flush_pending_bytes()
            trimmed, add_space = _trim_underscore_and_tell(subword)
            pieces.append(trimmed)
            if add_space:
                pieces.append(" ")

        # The input may end in the middle of a byte run.
        flush_pending_bytes()

        return tf.compat.as_text("".join(pieces))