def test_return_type(self):
        r"""`len()` of a dataset must be an `int`."""
        failure_msg = 'Must return `int`.'
        short_batch = ['Hello', 'World', 'Hello World']
        long_batch = [
            'Mario use Kimura Lock on Luigi, and Luigi tap out.',
            'Mario use Superman Punch.',
            'Luigi get TKO.',
            'Toad and Toadette are fightting over mushroom (weed).',
        ]

        # Also cover a single empty sequence and a dataset with no sequences.
        for sequences in (short_batch, long_batch, [''], []):
            dataset_size = len(BaseDataset(batch_sequences=sequences))
            self.assertIsInstance(dataset_size, int, msg=failure_msg)
    def test_yield_value(self):
        r"""Is an iterable which yield sequences in order.

        Each dataset must be an `Iterable` and must yield exactly the
        sequences it was constructed from: same count, same order, and
        every yielded item a `str`.
        """
        msg = 'Must be an iterable which yield sequences in order.'
        examples = (
            [
                'Hello',
                'World',
                'Hello World',
            ],
            [
                'Mario use Kimura Lock on Luigi, and Luigi tap out.',
                'Mario use Superman Punch.',
                'Luigi get TKO.',
                'Toad and Toadette are fightting over mushroom (weed).',
            ],
            [''],
            [],
        )

        for batch_sequences in examples:
            dataset = BaseDataset(batch_sequences=batch_sequences)
            self.assertIsInstance(dataset, Iterable, msg=msg)

            # `zip` stops at the shorter input, so zipping the dataset
            # directly would let a dataset that yields too few (or no)
            # sequences pass. Materialize it and compare counts first.
            yielded_sequences = list(dataset)
            self.assertEqual(
                len(yielded_sequences),
                len(batch_sequences),
                msg=msg
            )

            for ans_sequence, sequence in zip(
                    batch_sequences, yielded_sequences):
                self.assertIsInstance(sequence, str, msg=msg)
                self.assertEqual(sequence, ans_sequence, msg=msg)
    def test_return_type(self):
        r"""Return `collate_fn`.

        For every tokenizer class, `create_collate_fn` must return a plain
        function whose signature is
        `(batch_sequences: Iterable[str]) -> Tuple[torch.Tensor, torch.Tensor]`.
        """
        msg = 'Must return `collate_fn`.'
        examples = (
            CharDictTokenizer,
            CharListTokenizer,
            WhitespaceDictTokenizer,
            WhitespaceListTokenizer,
        )

        # The expected signature does not depend on the tokenizer, so build
        # it once outside the loop.
        expected_signature = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    name='batch_sequences',
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=Iterable[str],
                    default=inspect.Parameter.empty),
            ],
            return_annotation=Tuple[torch.Tensor, torch.Tensor])

        for tokenizer_class in examples:
            collate_fn = BaseDataset(
                []).create_collate_fn(tokenizer=tokenizer_class())
            # Pass `msg` here as well so a failure is reported with the same
            # message as the signature check below.
            self.assertTrue(inspect.isfunction(collate_fn), msg=msg)
            self.assertEqual(
                inspect.signature(collate_fn),
                expected_signature,
                msg=msg)
    def test_instance_attributes(self):
        r"""Instances must expose the required attributes with proper types."""
        missing_msg = 'Missing instance attribute `{}`.'
        wrong_type_msg = 'Instance attribute `{}` must be an instance of `{}`.'
        # Maps each required attribute name to the type it must have.
        required_attributes = {'batch_sequences': list}

        for name, expected_type in required_attributes.items():
            dataset = BaseDataset(batch_sequences=[])
            self.assertTrue(
                hasattr(dataset, name),
                msg=missing_msg.format(name))

            value = getattr(dataset, name)
            self.assertIsInstance(
                value,
                expected_type,
                msg=wrong_type_msg.format(name, expected_type.__name__))
    def test_invalid_input_max_seq_len(self):
        r"""Invalid `max_seq_len` must raise `TypeError` or `ValueError`.

        Whichever of the two exceptions is raised, its first argument must
        be the matching fixed error message.
        """
        msg1 = (
            'Must raise `TypeError` or `ValueError` when input `max_seq_len` '
            'is invalid.')
        msg2 = 'Inconsistent error message.'
        type_error_text = '`max_seq_len` must be an instance of `int`.'
        value_error_text = (
            '`max_seq_len` must be greater than `1` or equal to '
            '`-1`.')
        examples = (
            False, True, 0, 1, -2, 0.0, 1.0,
            math.nan, -math.nan, math.inf, -math.inf,
            0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for invalid_input in examples:
            with self.assertRaises(
                    (TypeError, ValueError), msg=msg1) as ctx_man:
                BaseDataset([]).create_collate_fn(
                    tokenizer=CharDictTokenizer(), max_seq_len=invalid_input)

            raised = ctx_man.exception
            # Pick the expected message from the kind of exception raised.
            if isinstance(raised, TypeError):
                expected_text = type_error_text
            else:
                expected_text = value_error_text
            self.assertEqual(raised.args[0], expected_text, msg=msg2)
    def test_invalid_input_tokenizer(self):
        r"""Invalid `tokenizer` must raise `TypeError` with a fixed message."""
        msg1 = 'Must raise `TypeError` when input `tokenizer` is invalid.'
        msg2 = 'Inconsistent error message.'
        expected_text = (
            '`tokenizer` must be an instance of `lmp.tokenizer.BaseTokenizer`.'
        )
        examples = (
            False, True, 0, 1, -1, 0.0, 1.0,
            math.nan, -math.nan, math.inf, -math.inf,
            0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for invalid_input in examples:
            with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                BaseDataset([]).create_collate_fn(tokenizer=invalid_input)

            raised_text = ctx_man.exception.args[0]
            self.assertEqual(raised_text, expected_text, msg=msg2)
# Example #7
# 0
    def setUp(self):
        r"""Build one `collate_fn` per parameter combination.

        Combinations are taken from the class attributes `is_uncased_range`,
        `max_seq_len_range` and `tokenizer_class_range`, iterated in that
        nesting order. Each entry stores the function together with the
        parameters used to create it.
        """
        test_cls = self.__class__
        self.collate_fn_objs = []

        # NOTE(review): the key 'tokeizer_class' looks like a misspelling of
        # 'tokenizer_class'; it is kept as-is because other tests may read it.
        self.collate_fn_objs.extend(
            {
                'collate_fn': BaseDataset.create_collate_fn(
                    tokenizer=tokenizer_class(is_uncased=is_uncased),
                    max_seq_len=max_seq_len
                ),
                'is_uncased': is_uncased,
                'max_seq_len': max_seq_len,
                'tokeizer_class': tokenizer_class,
            }
            for is_uncased in test_cls.is_uncased_range
            for max_seq_len in test_cls.max_seq_len_range
            for tokenizer_class in test_cls.tokenizer_class_range
        )
# Example #8
# 0
    def test_invalid_input_index(self):
        r"""Invalid `index` must raise `IndexError` or `TypeError`.

        A `TypeError` must carry the fixed error message; any other raised
        exception must be an `IndexError`.
        """
        msg1 = (
            'Must raise `IndexError` or `TypeError` when input `index` is invalid.'
        )
        msg2 = 'Inconsistent error message.'
        examples = (
            -1, 0.0, 1.0,
            math.nan, -math.nan, math.inf, -math.inf,
            0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for invalid_input in examples:
            with self.assertRaises(
                    (IndexError, TypeError), msg=msg1) as ctx_man:
                # Index an empty dataset; the lookup itself is what is tested.
                BaseDataset([])[invalid_input]

            raised = ctx_man.exception
            if not isinstance(raised, TypeError):
                self.assertIsInstance(raised, IndexError)
                continue

            self.assertEqual(
                raised.args[0],
                '`index` must be an instance of `int`.',
                msg=msg2)
# Example #9
# 0
    def test_return_value(self):
        r"""Sample single sequence using index.

        Indexing the dataset at every position of the input must return the
        sequence found at that position in the input.
        """
        msg = 'Must sample single sequence using index.'
        examples = (
            [
                'Hello',
                'World',
                'Hello World',
            ],
            [
                'Mario use Kimura Lock on Luigi, and Luigi tap out.',
                'Mario use Superman Punch.',
                'Luigi get TKO.',
                'Toad and Toadette are fightting over mushroom (weed).',
            ],
            [''],
            [],
        )

        for batch_sequences in examples:
            dataset = BaseDataset(batch_sequences=batch_sequences)
            # Drive the loop from the expected sequences, not from
            # `len(dataset)`: a dataset that under-reports its length would
            # otherwise skip indices and pass without being checked.
            for index, ans_sequence in enumerate(batch_sequences):
                self.assertEqual(dataset[index], ans_sequence, msg=msg)
    def test_return_dataset_size(self):
        r"""`len()` of a dataset must equal the number of input sequences."""
        msg = 'Must return dataset size.'
        batches = (
            ['Hello', 'World', 'Hello World'],
            [
                'Mario use Kimura Lock on Luigi, and Luigi tap out.',
                'Mario use Superman Punch.',
                'Luigi get TKO.',
                'Toad and Toadette are fightting over mushroom (weed).',
            ],
            [''],
            [],
        )
        # Expected size of each batch above, in the same order.
        expected_sizes = (3, 4, 1, 0)

        for sequences, expected_size in zip(batches, expected_sizes):
            dataset = BaseDataset(batch_sequences=sequences)
            self.assertEqual(len(dataset), expected_size, msg=msg)
    def test_invalid_input_batch_sequences(self):
        r"""Invalid `batch_sequences` must raise `TypeError`.

        Three groups of inputs are checked, in this order: bare invalid
        values, one-element lists holding an invalid element, and
        two-element lists where a valid `''` precedes an invalid element.
        """
        msg1 = (
            'Must raise `TypeError` when input `batch_sequences` is invalid.')
        msg2 = 'Inconsistent error message.'
        expected_text = (
            '`batch_sequences` must be an instance of `Iterable[str]`.')

        invalid_numbers = (
            False, True, 0, 1, -1, 0.0, 1.0,
            math.nan, -math.nan, math.inf, -math.inf,
            0j, 1j,
        )
        invalid_objects = (
            object(), lambda x: x, type, None, NotImplemented, ...,
        )
        # These values are only used as list elements, never as the bare
        # `batch_sequences` argument.
        invalid_only_as_element = (b'', (), [], {}, set())
        invalid_elements = (
            invalid_numbers + invalid_only_as_element + invalid_objects)

        examples = (
            invalid_numbers
            + invalid_objects
            + tuple([element] for element in invalid_elements)
            + tuple(['', element] for element in invalid_elements)
        )

        for invalid_input in examples:
            with self.assertRaises(TypeError, msg=msg1) as ctx_man:
                BaseDataset(batch_sequences=invalid_input)

            raised_text = ctx_man.exception.args[0]
            self.assertEqual(raised_text, expected_text, msg=msg2)