class TestVocabSize(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.vocab_size`.

    Every test runs against one cased and one uncased tokenizer which are
    created fresh in `setUp`.
    """

    @staticmethod
    def _sp_tokens_size():
        r"""Count tokens yielded by `WhitespaceDictTokenizer.special_tokens`.

        Returns:
            int: Number of special tokens reported by the tokenizer class.
        """
        return len(list(WhitespaceDictTokenizer.special_tokens()))

    def setUp(self):
        r"""Setup both cased and uncased tokenizer instances."""
        self.cased_tokenizer = WhitespaceDictTokenizer()
        self.uncased_tokenizer = WhitespaceDictTokenizer(is_uncased=True)
        self.tokenizers = [self.cased_tokenizer, self.uncased_tokenizer]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        del self.tokenizers
        del self.cased_tokenizer
        del self.uncased_tokenizer
        gc.collect()

    def test_signature(self):
        r"""Ensure signature consistency.

        `vocab_size` must be a data descriptor on the class and must be
        neither a plain function nor a method.
        """
        msg = 'Inconsistent property signature.'
        attribute = WhitespaceDictTokenizer.vocab_size

        self.assertTrue(inspect.isdatadescriptor(attribute), msg=msg)
        self.assertFalse(inspect.isfunction(attribute), msg=msg)
        self.assertFalse(inspect.ismethod(attribute), msg=msg)

    def test_return_type(self):
        r"""Return `int`"""
        msg = 'Must return `int`.'

        for tokenizer in self.tokenizers:
            self.assertIsInstance(tokenizer.vocab_size, int, msg=msg)

    def test_return_value(self):
        r"""Return vocabulary size.

        Before any `build_vocab` call the size is expected to equal the
        number of special tokens.
        """
        msg = 'Inconsistent vocabulary size.'
        # Derive the baseline from `special_tokens()` like the other tests
        # in this class do, instead of hard-coding a magic number.
        # NOTE(review): the previous literal was 4 -- confirm that
        # `special_tokens()` yields exactly the tokens of a fresh vocabulary.
        sp_tokens_size = self._sp_tokens_size()

        for tokenizer in self.tokenizers:
            self.assertEqual(tokenizer.vocab_size, sp_tokens_size, msg=msg)

    def test_increase_vocab_size(self):
        r"""Increase vocabulary size after `build_vocab`.

        The vocabulary is not reset between examples, so every expected
        size is cumulative over all batches seen so far.
        """
        msg = 'Must increase vocabulary size after `build_vocab`.'
        # (batch_sequences, cased token count, uncased token count); the
        # counts exclude special tokens.
        examples = (
            (('Hello World !', 'I am a LEGEND .', 'Hello legend !'), 9, 8),
            (('y = f(x)', ), 12, 11),
            # An empty sequence must not add any token.
            (('', ), 12, 11),
        )

        sp_tokens_size = self._sp_tokens_size()

        for batch_sequences, cased_vocab_size, uncased_vocab_size in examples:
            self.cased_tokenizer.build_vocab(batch_sequences)
            self.assertEqual(self.cased_tokenizer.vocab_size,
                             cased_vocab_size + sp_tokens_size,
                             msg=msg)
            self.uncased_tokenizer.build_vocab(batch_sequences)
            self.assertEqual(self.uncased_tokenizer.vocab_size,
                             uncased_vocab_size + sp_tokens_size,
                             msg=msg)

    def test_reset_vocab_size(self):
        r"""Reset vocabulary size after `reset_vocab`."""
        msg = 'Must reset vocabulary size after `reset_vocab`.'
        examples = (
            ('HeLlO WoRlD!', 'I aM a LeGeNd.'),
            ('y = f(x)', ),
            ('', ),
        )

        sp_tokens_size = self._sp_tokens_size()

        for batch_sequences in examples:
            for tokenizer in self.tokenizers:
                tokenizer.build_vocab(batch_sequences)
                tokenizer.reset_vocab()
                # Only the special tokens may remain after a reset.
                self.assertEqual(tokenizer.vocab_size, sp_tokens_size, msg=msg)
class TestDetokenize(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.detokenize`.

    Each test uses one cased and one uncased tokenizer whose vocabulary is
    built from the class-level `vocab_source` corpus.
    """

    @classmethod
    def setUpClass(cls):
        r"""Create the corpus shared by all tests in this class."""
        cls.vocab_source = [
            'Hello World !',
            'I am a legend .',
            'Hello legend !',
        ]

    @classmethod
    def tearDownClass(cls):
        r"""Delete the shared corpus."""
        del cls.vocab_source
        gc.collect()

    def setUp(self):
        r"""Setup both cased and uncased tokenizer instances."""
        self.cased_tokenizer = WhitespaceDictTokenizer()
        self.cased_tokenizer.build_vocab(self.__class__.vocab_source)
        self.uncased_tokenizer = WhitespaceDictTokenizer(is_uncased=True)
        self.uncased_tokenizer.build_vocab(self.__class__.vocab_source)
        self.tokenizers = [self.cased_tokenizer, self.uncased_tokenizer]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        del self.tokenizers
        del self.cased_tokenizer
        del self.uncased_tokenizer
        gc.collect()

    def test_signature(self):
        r"""Ensure signature consistency."""
        msg = 'Inconsistent method signature.'

        self.assertEqual(
            inspect.signature(WhitespaceDictTokenizer.detokenize),
            inspect.Signature(parameters=[
                inspect.Parameter(name='self',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  default=inspect.Parameter.empty),
                inspect.Parameter(name='tokens',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  annotation=Iterable[str],
                                  default=inspect.Parameter.empty)
            ],
                              return_annotation=str),
            msg=msg)

    def test_invalid_input_tokens(self):
        r"""Raise `TypeError` when input `tokens` is invalid.

        Invalid inputs are either not iterable at all, or iterable with at
        least one element which is not a `str`.
        """
        msg1 = 'Must raise `TypeError` when input `tokens` is invalid.'
        msg2 = 'Inconsistent error message.'
        numbers = (
            False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan, math.inf,
            -math.inf, 0j, 1j,
        )
        # Empty containers are only used as *elements*: passed directly
        # they are empty iterables and therefore not part of this test.
        empty_containers = (b'', (), [], {}, set())
        others = (object(), lambda x: x, type, None, NotImplemented, ...)
        non_str_tokens = (*numbers, *empty_containers, *others)
        examples = (
            # Not iterable.
            *numbers,
            *others,
            # Iterable whose only element is not a `str`.
            *([token] for token in non_str_tokens),
            # Iterable whose first element is valid but second is not.
            *(['', token] for token in non_str_tokens),
        )

        for invalid_input in examples:
            for tokenizer in self.tokenizers:
                # Report every failing example instead of stopping at the
                # first one.
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=tokenizer):
                    with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                        tokenizer.detokenize(tokens=invalid_input)

                    self.assertEqual(
                        ctx_man.exception.args[0],
                        '`tokens` must be an instance of `Iterable[str]`.',
                        msg=msg2)

    def test_return_type(self):
        r"""Return `str`."""
        msg = 'Must return `str`.'
        examples = (
            ('HeLlO', 'WoRlD', '!'),
            # Trailing comma is required: `('')` is a `str`, not a tuple.
            ('', ),
            (),
        )

        for tokens in examples:
            for tokenizer in self.tokenizers:
                self.assertIsInstance(tokenizer.detokenize(tokens),
                                      str,
                                      msg=msg)

    def test_normalize(self):
        r"""Return sequence is normalized.

        Expected values cover whitespace-only tokens at the borders and in
        between, lower casing for the uncased tokenizer, and NFKC
        normalization.
        """
        msg = 'Return sequence must be normalized.'
        # (tokens, expected cased sequence, expected uncased sequence)
        examples = (
            (
                (' ', 'HeLlO', 'WoRlD', '!'),
                'HeLlO WoRlD !',
                'hello world !',
            ),
            (
                ('HeLlO', 'WoRlD', '!', ' '),
                'HeLlO WoRlD !',
                'hello world !',
            ),
            (
                (' ', ' ', 'HeLlO', ' ', ' ', 'WoRlD', '!', ' ', ' '),
                'HeLlO WoRlD !',
                'hello world !',
            ),
            (
                # One-element tuples need the trailing comma; without it
                # the test would pass a bare `str`.
                ('0', ),
                '0',
                '0',
            ),
            (
                ('é', ),
                unicodedata.normalize('NFKC', 'é'),
                unicodedata.normalize('NFKC', 'é'),
            ),
            (
                ('0', 'é'),
                unicodedata.normalize('NFKC', '0 é'),
                unicodedata.normalize('NFKC', '0 é'),
            ),
            (
                (),
                '',
                '',
            ),
        )

        for tokens, cased_sequence, uncased_sequence in examples:
            self.assertEqual(self.cased_tokenizer.detokenize(tokens),
                             cased_sequence,
                             msg=msg)
            self.assertEqual(self.uncased_tokenizer.detokenize(tokens),
                             uncased_sequence,
                             msg=msg)


# Example #3
# 0


class TestBuildVocab(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.build_vocab`.

    Every test runs against one cased and one uncased tokenizer which are
    created fresh in `setUp`.
    """

    def setUp(self):
        r"""Setup both cased and uncased tokenizer instances."""
        self.cased_tokenizer = WhitespaceDictTokenizer()
        self.uncased_tokenizer = WhitespaceDictTokenizer(is_uncased=True)
        self.tokenizers = [self.cased_tokenizer, self.uncased_tokenizer]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        del self.tokenizers
        del self.cased_tokenizer
        del self.uncased_tokenizer
        gc.collect()

    def test_signature(self):
        r"""Ensure signature consistency."""
        msg = 'Inconsistent method signature.'

        self.assertEqual(
            inspect.signature(WhitespaceDictTokenizer.build_vocab),
            inspect.Signature(parameters=[
                inspect.Parameter(name='self',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  default=inspect.Parameter.empty),
                inspect.Parameter(name='batch_sequences',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  annotation=Iterable[str],
                                  default=inspect.Parameter.empty),
                inspect.Parameter(name='min_count',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  annotation=int,
                                  default=1),
            ],
                              return_annotation=None),
            msg=msg)

    def test_invalid_input_batch_sequences(self):
        r"""Raise `TypeError` when input `batch_sequences` is invalid.

        Invalid inputs are either not iterable at all, or iterable with at
        least one element which is not a `str`.
        """
        msg1 = (
            'Must raise `TypeError` when input `batch_sequences` is invalid.')
        msg2 = 'Inconsistent error message.'
        numbers = (
            False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan, math.inf,
            -math.inf, 0j, 1j,
        )
        # Empty containers are only used as *elements*: passed directly
        # they are empty iterables and therefore not part of this test.
        empty_containers = (b'', (), [], {}, set())
        others = (object(), lambda x: x, type, None, NotImplemented, ...)
        non_str_items = (*numbers, *empty_containers, *others)
        examples = (
            # Not iterable.
            *numbers,
            *others,
            # Iterable whose only element is not a `str`.
            *([item] for item in non_str_items),
            # Iterable whose first element is valid but second is not.
            *(['', item] for item in non_str_items),
        )

        for invalid_input in examples:
            for tokenizer in self.tokenizers:
                # Report every failing example instead of stopping at the
                # first one.
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=tokenizer):
                    with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                        tokenizer.build_vocab(batch_sequences=invalid_input)

                    self.assertEqual(
                        ctx_man.exception.args[0],
                        '`batch_sequences` must be an instance of '
                        '`Iterable[str]`.',
                        msg=msg2)

    def test_invalid_input_min_count(self):
        r"""Raise `TypeError` when input `min_count` is invalid."""
        msg1 = 'Must raise `TypeError` when input `min_count` is invalid.'
        msg2 = 'Inconsistent error message.'
        # NOTE(review): `bool` and `int` values are deliberately absent,
        # presumably because `bool` is a subclass of `int` -- confirm.
        examples = (
            0.0,
            1.0,
            math.nan,
            -math.nan,
            math.inf,
            -math.inf,
            0j,
            1j,
            '',
            b'',
            (),
            [],
            {},
            set(),
            object(),
            lambda x: x,
            type,
            None,
            NotImplemented,
            ...,
        )

        for invalid_input in examples:
            for tokenizer in self.tokenizers:
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=tokenizer):
                    with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                        tokenizer.build_vocab(batch_sequences=[],
                                              min_count=invalid_input)

                    self.assertEqual(
                        ctx_man.exception.args[0],
                        '`min_count` must be an instance of `int`.',
                        msg=msg2)

    def test_cased_sensitive(self):
        r"""Vocabulary must be case sensitive.

        The cased tokenizer must count upper and lower case variants as
        distinct tokens, the uncased tokenizer must merge them.
        """
        msg = 'Vocabulary must be case sensitive.'
        # (batch_sequences, cased token count, uncased token count); the
        # counts exclude special tokens.
        examples = (
            (('A B C D', 'a b c d'), 8, 4),
            (('e f g h i', 'E F G H I'), 10, 5),
        )

        sp_tokens_size = len(list(WhitespaceDictTokenizer.special_tokens()))

        for batch_sequences, cased_vocab_size, uncased_vocab_size in examples:
            cases = (
                (self.cased_tokenizer, cased_vocab_size),
                (self.uncased_tokenizer, uncased_vocab_size),
            )
            for tokenizer, vocab_size in cases:
                # Start from a clean vocabulary so counts are per example.
                tokenizer.reset_vocab()
                tokenizer.build_vocab(batch_sequences=batch_sequences)
                self.assertEqual(tokenizer.vocab_size,
                                 vocab_size + sp_tokens_size,
                                 msg=msg)

    def test_sort_by_token_frequency_in_descending_order(self):
        r"""Sort vocabulary by token frequency in descending order.

        Tokens are listed from most to least frequent, so the id of each
        token must not exceed the id of the token listed after it.
        """
        msg = ('Must sort vocabulary by token frequency in descending order.')
        # (batch_sequences, expected cased order, expected uncased order)
        examples = (
            (
                ('A a A a', 'b B b', 'c C', 'd'),
                ('A', 'a', 'b', 'B', 'c', 'C', 'd'),
                ('a', 'b', 'c', 'd'),
            ),
            (
                ('E e E e E', 'F f F f', 'G g G', 'H h', 'I'),
                ('E', 'e', 'F', 'f', 'G', 'g', 'H', 'h', 'I'),
                ('e', 'f', 'g', 'h', 'i'),
            ),
        )

        for (batch_sequences, cased_vocab_order,
             uncased_vocab_order) in examples:
            cases = (
                (self.cased_tokenizer, cased_vocab_order),
                (self.uncased_tokenizer, uncased_vocab_order),
            )
            for tokenizer, vocab_order in cases:
                tokenizer.reset_vocab()
                tokenizer.build_vocab(batch_sequences=batch_sequences)

                # Compare each token with its successor in expected order.
                for vocab1, vocab2 in zip(vocab_order[:-1], vocab_order[1:]):
                    self.assertLessEqual(
                        tokenizer.convert_token_to_id(vocab1),
                        tokenizer.convert_token_to_id(vocab2),
                        msg=msg)

    def test_min_count(self):
        r"""Filter out tokens whose frequency is smaller than `min_count`.

        Known tokens must survive an id round trip, filtered tokens must
        map to the id of the unknown token.
        """
        msg = ('Must filter out tokens whose frequency is smaller than '
               '`min_count`.')
        # (batch_sequences, cased known, cased unknown, uncased known,
        #  uncased unknown, min_count)
        examples = (
            (
                ('A a A a', 'b B b', 'c C', 'd'),
                ('A', 'a', 'b'),
                ('B', 'c', 'C', 'd'),
                ('a', 'b', 'c'),
                # Trailing comma is required: `('d')` is a `str`.
                ('d', ),
                2,
            ),
            (
                ('E e E e E', 'F f F f', 'G g G', 'H h', 'I'),
                ('E', ),
                ('e', 'F', 'f', 'G', 'g', 'H', 'h', 'I'),
                ('e', 'f', 'g'),
                ('h', 'i'),
                3,
            ),
            (
                ('E e E e E', 'F f F f', 'G g G', 'H h', 'I'),
                (),
                ('E', 'e', 'F', 'f', 'G', 'g', 'H', 'h', 'I'),
                (),
                ('e', 'f', 'g', 'h', 'i'),
                10,
            ),
        )

        for (batch_sequences, cased_known_token, cased_unknown_token,
             uncased_known_token, uncased_unknown_token,
             min_count) in examples:
            cases = (
                (self.cased_tokenizer, cased_known_token,
                 cased_unknown_token),
                (self.uncased_tokenizer, uncased_known_token,
                 uncased_unknown_token),
            )
            for tokenizer, known_tokens, unknown_tokens in cases:
                tokenizer.reset_vocab()
                tokenizer.build_vocab(batch_sequences=batch_sequences,
                                      min_count=min_count)

                for token in known_tokens:
                    token_id = tokenizer.convert_token_to_id(token)
                    self.assertEqual(
                        token,
                        tokenizer.convert_id_to_token(token_id),
                        msg=msg)

                unk_token_id = tokenizer.convert_token_to_id(
                    WhitespaceDictTokenizer.unk_token)
                for unk_token in unknown_tokens:
                    self.assertEqual(
                        tokenizer.convert_token_to_id(unk_token),
                        unk_token_id,
                        msg=msg)


# Example #4
# 0


class TestBatchDecode(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.batch_decode`.

    Each test uses one cased and one uncased tokenizer whose vocabulary is
    built from the class-level `vocab_source` corpus.
    """

    @classmethod
    def setUpClass(cls):
        r"""Create the corpus shared by all tests in this class."""
        cls.vocab_source = [
            'Hello World !',
            'I am a legend .',
            'Hello legend !',
        ]

    @classmethod
    def tearDownClass(cls):
        r"""Delete the shared corpus."""
        del cls.vocab_source
        gc.collect()

    def setUp(self):
        r"""Setup both cased and uncased tokenizer instances."""
        self.cased_tokenizer = WhitespaceDictTokenizer()
        self.cased_tokenizer.build_vocab(self.__class__.vocab_source)
        self.uncased_tokenizer = WhitespaceDictTokenizer(is_uncased=True)
        self.uncased_tokenizer.build_vocab(self.__class__.vocab_source)
        self.tokenizers = [self.cased_tokenizer, self.uncased_tokenizer]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        del self.tokenizers
        del self.cased_tokenizer
        del self.uncased_tokenizer
        gc.collect()

    def test_signature(self):
        r"""Ensure signature consistency."""
        msg = 'Inconsistent method signature.'

        self.assertEqual(
            inspect.signature(WhitespaceDictTokenizer.batch_decode),
            inspect.Signature(parameters=[
                inspect.Parameter(
                    name='self',
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                ),
                inspect.Parameter(name='batch_token_ids',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  annotation=Iterable[Iterable[int]],
                                  default=inspect.Parameter.empty),
                inspect.Parameter(name='remove_special_tokens',
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  annotation=bool,
                                  default=False)
            ],
                              return_annotation=List[str]),
            msg=msg)

    def test_invalid_input_batch_token_ids(self):
        r"""Raise `TypeError` when input `batch_token_ids` is invalid.

        Invalid inputs fail at one of three levels: the batch itself is
        not iterable, a sequence inside the batch is not iterable, or a
        token id inside a sequence is not an `int`.
        """
        msg1 = (
            'Must raise `TypeError` when input `batch_token_ids` is invalid.')
        msg2 = 'Inconsistent error message.'
        others = (object(), lambda x: x, type, None, NotImplemented, ...)
        non_iterables = (
            False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan, math.inf,
            -math.inf, 0j, 1j,
            *others,
        )
        non_ints = (
            0.0, 1.0, math.nan, -math.nan, math.inf, -math.inf, 0j, 1j,
            '', b'', (), [], {}, set(),
            *others,
        )
        examples = (
            # The batch itself is not iterable.
            *non_iterables,
            # The only sequence of the batch is not iterable.
            *([item] for item in non_iterables),
            # The first sequence is valid but the second is not iterable.
            *([[], item] for item in non_iterables),
            # The only token id of the sequence is not an `int`.
            *([[item]] for item in non_ints),
            # The first token id is valid but the second is not an `int`.
            *([[0, item]] for item in non_ints),
        )

        for invalid_input in examples:
            for tokenizer in self.tokenizers:
                # Report every failing example instead of stopping at the
                # first one.
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=tokenizer):
                    with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                        tokenizer.batch_decode(batch_token_ids=invalid_input)

                    self.assertEqual(
                        ctx_man.exception.args[0],
                        '`batch_token_ids` must be an instance of '
                        '`Iterable[Iterable[int]]`.',
                        msg=msg2)

    def test_invalid_input_remove_special_tokens(self):
        r"""Raise `TypeError` when input `remove_special_tokens` is invalid."""
        msg1 = ('Must raise `TypeError` when input `remove_special_tokens` is '
                'invalid.')
        msg2 = 'Inconsistent error message.'
        # Anything which is not a `bool`, each value listed once.
        examples = (
            0,
            1,
            -1,
            0.0,
            1.0,
            math.nan,
            -math.nan,
            math.inf,
            -math.inf,
            0j,
            1j,
            '',
            b'',
            (),
            [],
            {},
            set(),
            object(),
            lambda x: x,
            type,
            None,
            NotImplemented,
            ...,
        )

        for invalid_input in examples:
            for tokenizer in self.tokenizers:
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=tokenizer):
                    with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                        tokenizer.batch_decode(
                            batch_token_ids=[[]],
                            remove_special_tokens=invalid_input)

                    self.assertEqual(
                        ctx_man.exception.args[0],
                        '`remove_special_tokens` must be an instance of '
                        '`bool`.',
                        msg=msg2)

    def test_return_type(self):
        r"""Return `List[str]`."""
        msg = 'Must return `List[str]`.'
        examples = (
            [[0, 1, 2, 3], [4, 5, 6, 7, 8]],
            [[9, 10, 11, 12, 13], []],
            [[], [14, 15, 16, 17]],
            [[], []],
            [],
        )

        for batch_token_ids in examples:
            for tokenizer in self.tokenizers:
                batch_sequences = tokenizer.batch_decode(
                    batch_token_ids=batch_token_ids)
                self.assertIsInstance(batch_sequences, list, msg=msg)
                for sequence in batch_sequences:
                    self.assertIsInstance(sequence, str, msg=msg)

    def test_remove_special_tokens(self):
        r"""Remove special tokens.

        The same ids are decoded with and without `remove_special_tokens`.
        In the expected output `[bos]`, `[eos]` and `[pad]` are dropped
        when the flag is set while `[unk]` is kept.
        """
        msg = 'Must remove special tokens.'
        # (remove_special_tokens, batch_token_ids, expected cased
        #  sequences, expected uncased sequences)
        examples = (
            (
                False,
                [
                    [0, 4, 7, 5, 1, 2],
                    [0, 8, 9, 10, 3, 1, 2, 2],
                    [0, 3, 6, 11, 1],
                ],
                [
                    '[bos] Hello World ! [eos] [pad]',
                    '[bos] I am a [unk] [eos] [pad] [pad]',
                    '[bos] [unk] legend . [eos]',
                ],
                [
                    '[bos] hello world ! [eos] [pad]',
                    '[bos] i am a [unk] [eos] [pad] [pad]',
                    '[bos] [unk] legend . [eos]',
                ],
            ),
            (
                True,
                [
                    [0, 4, 7, 5, 1, 2],
                    [0, 8, 9, 10, 3, 1, 2, 2],
                    [0, 3, 6, 11, 1],
                ],
                [
                    'Hello World !',
                    'I am a [unk]',
                    '[unk] legend .',
                ],
                [
                    'hello world !',
                    'i am a [unk]',
                    '[unk] legend .',
                ],
            ),
        )

        for (remove_special_tokens, batch_token_ids, cased_batch_sequence,
             uncased_batch_sequence) in examples:
            self.assertEqual(self.cased_tokenizer.batch_decode(
                batch_token_ids=batch_token_ids,
                remove_special_tokens=remove_special_tokens),
                             cased_batch_sequence,
                             msg=msg)
            self.assertEqual(self.uncased_tokenizer.batch_decode(
                batch_token_ids=batch_token_ids,
                remove_special_tokens=remove_special_tokens),
                             uncased_batch_sequence,
                             msg=msg)
class TestTokenize(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.tokenize`.

    Each test uses one cased and one uncased tokenizer whose vocabulary is
    built from the class-level `vocab_source` corpus.
    """

    @classmethod
    def setUpClass(cls):
        r"""Create the corpus shared by all tests in this class."""
        corpus = (
            'Hello World !',
            'I am a legend .',
            'Hello legend !',
        )
        cls.vocab_source = list(corpus)

    @classmethod
    def tearDownClass(cls):
        r"""Delete the shared corpus and run a garbage collection."""
        delattr(cls, 'vocab_source')
        gc.collect()

    def setUp(self):
        r"""Setup both cased and uncased tokenizer instances."""
        corpus = self.__class__.vocab_source

        cased = WhitespaceDictTokenizer()
        cased.build_vocab(corpus)
        uncased = WhitespaceDictTokenizer(is_uncased=True)
        uncased.build_vocab(corpus)

        self.cased_tokenizer = cased
        self.uncased_tokenizer = uncased
        self.tokenizers = [cased, uncased]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        for name in ('tokenizers', 'cased_tokenizer', 'uncased_tokenizer'):
            delattr(self, name)
        gc.collect()

    def test_signature(self):
        r"""Ensure signature consistency."""
        msg = 'Inconsistent method signature.'

        actual = inspect.signature(WhitespaceDictTokenizer.tokenize)
        positional_or_keyword = inspect.Parameter.POSITIONAL_OR_KEYWORD
        self_param = inspect.Parameter(
            'self',
            positional_or_keyword,
            default=inspect.Parameter.empty,
        )
        sequence_param = inspect.Parameter(
            'sequence',
            positional_or_keyword,
            default=inspect.Parameter.empty,
            annotation=str,
        )
        expected = inspect.Signature(
            [self_param, sequence_param],
            return_annotation=List[str],
        )

        self.assertEqual(actual, expected, msg=msg)

    def test_invalid_input_sequence(self):
        r"""Raise `TypeError` when input `sequence` is invalid."""
        msg1 = 'Must raise `TypeError` when input `sequence` is invalid.'
        msg2 = 'Inconsistent error message.'
        expected_message = '`sequence` must be an instance of `str`.'
        # Anything which is not a `str`.
        examples = (
            False,
            True,
            0,
            1,
            -1,
            0.0,
            1.0,
            math.nan,
            -math.nan,
            math.inf,
            -math.inf,
            b'',
            0j,
            1j,
            (),
            [],
            {},
            set(),
            object(),
            lambda x: x,
            type,
            None,
            NotImplemented,
            ...,
        )

        for bad_sequence in examples:
            for tokenizer in self.tokenizers:
                with self.assertRaises(TypeError, msg=msg1) as raised:
                    tokenizer.tokenize(bad_sequence)

                self.assertEqual(raised.exception.args[0],
                                 expected_message,
                                 msg=msg2)

    def test_return_type(self):
        r"""Return `List[str]`."""
        msg = 'Must return `List[str]`.'
        examples = ('Hello World !', 'H', '')

        for text in examples:
            for tokenizer in self.tokenizers:
                result = tokenizer.tokenize(text)
                self.assertIsInstance(result, list, msg=msg)
                for item in result:
                    self.assertIsInstance(item, str, msg=msg)

    def test_normalize(self):
        r"""Return sequence is normalized.

        Expected values cover leading, trailing and repeated whitespace,
        lower casing for the uncased tokenizer, and NFKC normalization.
        """
        msg = 'Return sequence must be normalized.'
        nfkc_digit = unicodedata.normalize('NFKC', '0')
        nfkc_accent = unicodedata.normalize('NFKC', 'é')
        # (sequence, expected cased tokens, expected uncased tokens)
        examples = [
            (
                ' HeLlO WoRlD !',
                ['HeLlO', 'WoRlD', '!'],
                ['hello', 'world', '!'],
            ),
            (
                'HeLlO WoRlD ! ',
                ['HeLlO', 'WoRlD', '!'],
                ['hello', 'world', '!'],
            ),
            (
                '  HeLlO  WoRlD !  ',
                ['HeLlO', 'WoRlD', '!'],
                ['hello', 'world', '!'],
            ),
            ('0', ['0'], ['0']),
            ('é', [nfkc_accent], [nfkc_accent]),
            (
                '0 é',
                [nfkc_digit, nfkc_accent],
                [nfkc_digit, nfkc_accent],
            ),
            ('', [], []),
        ]

        for text, cased_expected, uncased_expected in examples:
            checks = (
                (self.cased_tokenizer, cased_expected),
                (self.uncased_tokenizer, uncased_expected),
            )
            for tokenizer, expected in checks:
                self.assertEqual(tokenizer.tokenize(text), expected, msg=msg)
class TestBatchEncode(unittest.TestCase):
    r"""Test case for `lmp.tokenizer.WhitespaceDictTokenizer.batch_encode`."""
    @classmethod
    def setUpClass(cls):
        cls.vocab_source = [
            'Hello World !',
            'I am a legend .',
            'Hello legend !',
        ]

    @classmethod
    def tearDownClass(cls):
        del cls.vocab_source
        gc.collect()

    def setUp(self):
        r"""Create a cased and an uncased tokenizer with a built vocabulary.

        Both tokenizers build their vocabulary from the class-level
        `vocab_source` and are also exposed together as `self.tokenizers`
        (cased first, uncased second).
        """
        cased = WhitespaceDictTokenizer()
        cased.build_vocab(type(self).vocab_source)
        self.cased_tokenizer = cased

        uncased = WhitespaceDictTokenizer(is_uncased=True)
        uncased.build_vocab(type(self).vocab_source)
        self.uncased_tokenizer = uncased

        self.tokenizers = [cased, uncased]

    def tearDown(self):
        r"""Delete both cased and uncased tokenizer instances."""
        del self.tokenizers
        del self.cased_tokenizer
        del self.uncased_tokenizer
        gc.collect()

    def test_signature(self):
        r"""`batch_encode` must expose the expected signature.

        Expected: `(self, batch_sequences: Iterable[str],
        max_seq_len: int = -1) -> List[List[int]]`.
        """
        msg = 'Inconsistent method signature.'
        actual_signature = inspect.signature(
            WhitespaceDictTokenizer.batch_encode)

        param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
        expected_parameters = [
            inspect.Parameter(name='self', kind=param_kind),
            inspect.Parameter(
                name='batch_sequences',
                kind=param_kind,
                annotation=Iterable[str],
                default=inspect.Parameter.empty,
            ),
            inspect.Parameter(
                name='max_seq_len',
                kind=param_kind,
                annotation=int,
                default=-1,
            ),
        ]
        expected_signature = inspect.Signature(
            parameters=expected_parameters,
            return_annotation=List[List[int]],
        )

        self.assertEqual(actual_signature, expected_signature, msg=msg)

    def test_invalid_input_batch_sequences(self):
        r"""Raise `TypeError` when input `batch_sequences` is invalid.

        The examples are either not a list at all, or a list containing at
        least one element that is not a `str`.  For each of them
        `batch_encode` must raise `TypeError` carrying the exact message
        checked below.

        Every (input, tokenizer) pair runs inside its own `subTest`, so a
        failure report names the offending input and, on runners that
        support sub-tests, the remaining pairs are still exercised.
        """
        msg1 = (
            'Must raise `TypeError` when input `batch_sequences` is invalid.')
        msg2 = 'Inconsistent error message.'
        expected_message = (
            '`batch_sequences` must be an instance of `Iterable[str]`.')
        examples = (
            False,
            True,
            0,
            1,
            -1,
            0.0,
            1.0,
            math.nan,
            -math.nan,
            math.inf,
            -math.inf,
            0j,
            1j,
            object(),
            lambda x: x,
            type,
            None,
            NotImplemented,
            ...,
            [False],
            [True],
            [0],
            [1],
            [-1],
            [0.0],
            [1.0],
            [math.nan],
            [-math.nan],
            [math.inf],
            [-math.inf],
            [0j],
            [1j],
            [b''],
            [()],
            [[]],
            [{}],
            [set()],
            [object()],
            [lambda x: x],
            [type],
            [None],
            [NotImplemented],
            [...],
            ['', False],
            ['', True],
            ['', 0],
            ['', 1],
            ['', -1],
            ['', 0.0],
            ['', 1.0],
            ['', math.nan],
            ['', -math.nan],
            ['', math.inf],
            ['', -math.inf],
            ['', 0j],
            ['', 1j],
            ['', b''],
            ['', ()],
            ['', []],
            ['', {}],
            ['', set()],
            ['', object()],
            ['', lambda x: x],
            ['', type],
            ['', None],
            ['', NotImplemented],
            ['', ...],
        )

        # Labels follow the order in which `setUp` fills `self.tokenizers`
        # (cased first, uncased second); they only serve failure reports.
        labelled_tokenizers = tuple(
            zip(('cased', 'uncased'), self.tokenizers))

        for invalid_input in examples:
            for label, tokenizer in labelled_tokenizers:
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=label):
                    with self.assertRaises(TypeError, msg=msg1) as cxt_man:
                        tokenizer.batch_encode(batch_sequences=invalid_input)

                    self.assertEqual(
                        cxt_man.exception.args[0],
                        expected_message,
                        msg=msg2)

    def test_invalid_input_max_seq_len(self):
        r"""Raise exception when input `max_seq_len` is invalid.

        `batch_encode` must raise either `TypeError` or `ValueError` for
        every example.  The test accepts both exception types and then
        checks the message that belongs to the type actually raised.

        Every (input, tokenizer) pair runs inside its own `subTest`, so a
        failure report names the offending input and, on runners that
        support sub-tests, the remaining pairs are still exercised.
        """
        msg1 = (
            'Must raise `TypeError` or `ValueError` when input `max_seq_len` '
            'is invalid.')
        msg2 = 'Inconsistent error message.'
        type_error_message = '`max_seq_len` must be an instance of `int`.'
        value_error_message = (
            '`max_seq_len` must be greater than `1` or equal to `-1`.')
        examples = (
            False,
            True,
            0,
            1,
            -2,
            0.0,
            1.0,
            math.nan,
            -math.nan,
            math.inf,
            -math.inf,
            0j,
            1j,
            '',
            b'',
            (),
            [],
            {},
            set(),
            object(),
            lambda x: x,
            type,
            None,
            NotImplemented,
            ...,
        )

        # Labels follow the order in which `setUp` fills `self.tokenizers`
        # (cased first, uncased second); they only serve failure reports.
        labelled_tokenizers = tuple(
            zip(('cased', 'uncased'), self.tokenizers))

        for invalid_input in examples:
            for label, tokenizer in labelled_tokenizers:
                with self.subTest(invalid_input=invalid_input,
                                  tokenizer=label):
                    with self.assertRaises((TypeError, ValueError),
                                           msg=msg1) as cxt_man:
                        tokenizer.batch_encode(batch_sequences=[''],
                                               max_seq_len=invalid_input)

                    # The expected message depends on which error was raised.
                    if isinstance(cxt_man.exception, TypeError):
                        expected_message = type_error_message
                    else:
                        expected_message = value_error_message

                    self.assertEqual(
                        cxt_man.exception.args[0],
                        expected_message,
                        msg=msg2)

    def test_return_type(self):
        r"""`batch_encode` must return a `list` of `list`s of `int`.

        The batches mix non-empty and empty sequences, and include the
        empty batch.
        """
        msg = 'Must return `List[List[int]]`.'
        greeting = 'Hello World!'
        legend = 'I am a legend.'
        formula = 'y = f(x)'
        examples = (
            [greeting, legend, formula],
            [greeting, '', ''],
            ['', legend, ''],
            ['', '', formula],
            ['', '', ''],
            [],
        )

        for batch in examples:
            for tokenizer in self.tokenizers:
                encoded_batch = tokenizer.batch_encode(batch_sequences=batch)

                # Check the outer container, then each row, then each id.
                self.assertIsInstance(encoded_batch, list, msg=msg)
                for encoded_row in encoded_batch:
                    self.assertIsInstance(encoded_row, list, msg=msg)
                    for encoded_id in encoded_row:
                        self.assertIsInstance(encoded_id, int, msg=msg)

    def test_encode_format(self):
        r"""Encoded rows must follow the documented layout.

        Each example pairs a batch of sequences with the exact token ids
        expected from `batch_encode` when `max_seq_len` is left at its
        default.  In the expected rows every sequence starts with `0`, its
        tokens are followed by `1`, and shorter rows end with `2`s so that
        all rows of a batch share one length.
        """
        msg = (
            'Must follow encode format: '
            '[bos] t1 t2 ... tn [eos] [pad] ... [pad].'
        )
        greeting = 'Hello World !'
        legend = 'I am a legend .'
        formula = 'y = f(x)'
        examples = (
            (
                [greeting, legend, formula],
                [
                    [0, 4, 7, 5, 1, 2, 2],
                    [0, 8, 9, 10, 6, 11, 1],
                    [0, 3, 3, 3, 1, 2, 2],
                ],
            ),
            (
                [greeting, '', ''],
                [
                    [0, 4, 7, 5, 1],
                    [0, 1, 2, 2, 2],
                    [0, 1, 2, 2, 2],
                ],
            ),
            (
                ['', legend, ''],
                [
                    [0, 1, 2, 2, 2, 2, 2],
                    [0, 8, 9, 10, 6, 11, 1],
                    [0, 1, 2, 2, 2, 2, 2],
                ],
            ),
            (
                ['', '', formula],
                [
                    [0, 1, 2, 2, 2],
                    [0, 1, 2, 2, 2],
                    [0, 3, 3, 3, 1],
                ],
            ),
            (
                ['', '', ''],
                [[0, 1], [0, 1], [0, 1]],
            ),
            # The empty batch encodes to an empty list.
            ([], []),
        )

        for batch, expected_ids in examples:
            for tokenizer in self.tokenizers:
                encoded_batch = tokenizer.batch_encode(batch_sequences=batch)
                self.assertEqual(encoded_batch, expected_ids, msg=msg)

    def test_truncate(self):
        r"""Encoded rows must be cut down to `max_seq_len`.

        Each example is a triple of the batch, the expected token ids and
        the `max_seq_len` passed to `batch_encode`.  The expected rows keep
        the leading `0` and the trailing `1` at every length.
        """
        msg = "Token ids' length must not exceed `max_seq_len`."
        greeting = 'Hello World !'
        legend = 'I am a legend .'
        formula = 'y = f(x)'
        examples = (
            (
                [greeting, legend, formula],
                [
                    [0, 4, 7, 5, 1],
                    [0, 8, 9, 10, 1],
                    [0, 3, 3, 3, 1],
                ],
                5,
            ),
            (
                [greeting, legend, formula],
                [
                    [0, 4, 7, 1],
                    [0, 8, 9, 1],
                    [0, 3, 3, 1],
                ],
                4,
            ),
            (
                [greeting, legend, formula],
                [[0, 1], [0, 1], [0, 1]],
                2,
            ),
            (
                ['', '', ''],
                [[0, 1], [0, 1], [0, 1]],
                2,
            ),
            # The empty batch stays empty whatever the length limit is.
            ([], [], 2),
        )

        for batch, expected_ids, length_limit in examples:
            for tokenizer in self.tokenizers:
                encoded_batch = tokenizer.batch_encode(
                    batch_sequences=batch,
                    max_seq_len=length_limit,
                )
                self.assertEqual(encoded_batch, expected_ids, msg=msg)

    def test_padding(self):
        r"""Encoded rows must be filled up to `max_seq_len`.

        Each example is a triple of the batch, the expected token ids and
        the `max_seq_len` passed to `batch_encode`.  Every expected row has
        exactly `max_seq_len` ids, with trailing `2`s making up the
        difference.
        """
        msg = "Token ids' length must pad to `max_seq_len`."
        greeting = 'Hello World !'
        legend = 'I am a legend .'
        formula = 'y = f(x)'
        examples = (
            (
                [greeting, legend, formula],
                [
                    [0, 4, 7, 5, 1, 2, 2, 2],
                    [0, 8, 9, 10, 6, 11, 1, 2],
                    [0, 3, 3, 3, 1, 2, 2, 2],
                ],
                8,
            ),
            (
                [greeting, '', ''],
                [
                    [0, 4, 7, 5, 1],
                    [0, 1, 2, 2, 2],
                    [0, 1, 2, 2, 2],
                ],
                5,
            ),
            (
                ['', legend, ''],
                [
                    [0, 1, 2, 2, 2, 2, 2],
                    [0, 8, 9, 10, 6, 11, 1],
                    [0, 1, 2, 2, 2, 2, 2],
                ],
                7,
            ),
            (
                ['', '', formula],
                [
                    [0, 1, 2, 2, 2, 2],
                    [0, 1, 2, 2, 2, 2],
                    [0, 3, 3, 3, 1, 2],
                ],
                6,
            ),
            (
                ['', '', ''],
                [
                    [0, 1, 2, 2, 2, 2, 2, 2, 2, 2],
                    [0, 1, 2, 2, 2, 2, 2, 2, 2, 2],
                    [0, 1, 2, 2, 2, 2, 2, 2, 2, 2],
                ],
                10,
            ),
            # The empty batch stays empty whatever the target length is.
            ([], [], 100),
        )

        for batch, expected_ids, target_length in examples:
            for tokenizer in self.tokenizers:
                encoded_batch = tokenizer.batch_encode(
                    batch_sequences=batch,
                    max_seq_len=target_length,
                )
                self.assertEqual(encoded_batch, expected_ids, msg=msg)