def test_return_type(self):
    r"""Check that `len()` of a dataset is an `int` for every example."""
    msg = 'Must return `int`.'
    examples = (
        ['Hello', 'World', 'Hello World'],
        [
            'Mario use Kimura Lock on Luigi, and Luigi tap out.',
            'Mario use Superman Punch.',
            'Luigi get TKO.',
            'Toad and Toadette are fightting over mushroom (weed).',
        ],
        [''],
        [],
    )

    for sequences in examples:
        dataset = BaseDataset(batch_sequences=sequences)
        size = len(dataset)
        self.assertIsInstance(size, int, msg=msg)
def test_yield_value(self):
    r"""Is an iterable which yield sequences in order.

    For each example the dataset must be an `Iterable`, must yield exactly
    as many items as were passed to the constructor, and each yielded item
    must be a `str` equal to the input sequence at the same position.
    """
    msg = 'Must be an iterable which yield sequences in order.'
    examples = (
        [
            'Hello',
            'World',
            'Hello World',
        ],
        [
            'Mario use Kimura Lock on Luigi, and Luigi tap out.',
            'Mario use Superman Punch.',
            'Luigi get TKO.',
            'Toad and Toadette are fightting over mushroom (weed).',
        ],
        [''],
        [],
    )

    for batch_sequences in examples:
        dataset = BaseDataset(batch_sequences=batch_sequences)
        self.assertIsInstance(dataset, Iterable, msg=msg)

        # `zip` stops at the shorter input, so comparing through `zip` alone
        # would let a dataset that yields too few (or no) sequences pass.
        # Materialise the iteration and check the count explicitly first.
        yielded = list(dataset)
        self.assertEqual(len(yielded), len(batch_sequences), msg=msg)

        for ans_sequence, sequence in zip(batch_sequences, yielded):
            self.assertIsInstance(sequence, str, msg=msg)
            self.assertEqual(sequence, ans_sequence, msg=msg)
def test_return_type(self):
    r"""Check `create_collate_fn` returns a function with the right signature.

    The returned object must be a plain function taking a single
    `batch_sequences: Iterable[str]` parameter and annotated as returning
    `Tuple[torch.Tensor, torch.Tensor]`.
    """
    msg = 'Must return `collate_fn`.'
    examples = (
        CharDictTokenizer,
        CharListTokenizer,
        WhitespaceDictTokenizer,
        WhitespaceListTokenizer,
    )

    for tokenizer_class in examples:
        dataset = BaseDataset([])
        collate_fn = dataset.create_collate_fn(tokenizer=tokenizer_class())

        self.assertTrue(inspect.isfunction(collate_fn))

        batch_param = inspect.Parameter(
            name='batch_sequences',
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=Iterable[str],
            default=inspect.Parameter.empty,
        )
        expected_signature = inspect.Signature(
            parameters=[batch_param],
            return_annotation=Tuple[torch.Tensor, torch.Tensor],
        )
        self.assertEqual(
            inspect.signature(collate_fn),
            expected_signature,
            msg=msg,
        )
def test_instance_attributes(self):
    r"""Check a fresh dataset exposes each required attribute with its type."""
    msg1 = 'Missing instance attribute `{}`.'
    msg2 = 'Instance attribute `{}` must be an instance of `{}`.'
    examples = (
        ('batch_sequences', list),
    )

    for attr, attr_type in examples:
        dataset = BaseDataset(batch_sequences=[])

        # Presence first, so a missing attribute reports the clearer message.
        self.assertTrue(hasattr(dataset, attr), msg=msg1.format(attr))

        value = getattr(dataset, attr)
        type_msg = msg2.format(attr, attr_type.__name__)
        self.assertIsInstance(value, attr_type, msg=type_msg)
def test_invalid_input_max_seq_len(self):
    r"""Check invalid `max_seq_len` values are rejected with a fixed message.

    Each example must make `create_collate_fn` raise `TypeError` or
    `ValueError`; the message checked depends on which of the two was raised.
    """
    msg1 = (
        'Must raise `TypeError` or `ValueError` when input `max_seq_len` '
        'is invalid.'
    )
    msg2 = 'Inconsistent error message.'
    examples = (
        False, True, 0, 1, -2, 0.0, 1.0, math.nan, -math.nan, math.inf,
        -math.inf, 0j, 1j, '', b'', (), [], {}, set(), object(),
        lambda x: x, type, None, NotImplemented, ...,
    )

    for invalid_input in examples:
        with self.assertRaises(
                (TypeError, ValueError),
                msg=msg1
        ) as ctx_man:
            BaseDataset([]).create_collate_fn(
                tokenizer=CharDictTokenizer(),
                max_seq_len=invalid_input
            )

        error = ctx_man.exception
        # Pick the expected message from the exception type actually raised.
        if isinstance(error, TypeError):
            expected = '`max_seq_len` must be an instance of `int`.'
        else:
            expected = (
                '`max_seq_len` must be greater than `1` or equal to '
                '`-1`.'
            )
        self.assertEqual(error.args[0], expected, msg=msg2)
def test_invalid_input_tokenizer(self):
    r"""Check invalid `tokenizer` values raise `TypeError` with a fixed message."""
    msg1 = 'Must raise `TypeError` when input `tokenizer` is invalid.'
    msg2 = 'Inconsistent error message.'
    examples = (
        False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan, math.inf,
        -math.inf, 0j, 1j, '', b'', (), [], {}, set(), object(),
        lambda x: x, type, None, NotImplemented, ...,
    )
    expected = (
        '`tokenizer` must be an instance of `lmp.tokenizer.BaseTokenizer`.'
    )

    for invalid_input in examples:
        with self.assertRaises(TypeError, msg=msg1) as ctx_man:
            BaseDataset([]).create_collate_fn(tokenizer=invalid_input)

        raised_message = ctx_man.exception.args[0]
        self.assertEqual(raised_message, expected, msg=msg2)
def setUp(self):
    r"""Build one `collate_fn` record per parameter combination.

    Iterates the class-level `is_uncased_range`, `max_seq_len_range` and
    `tokenizer_class_range` (outermost to innermost) and stores, for each
    combination, the created `collate_fn` together with the parameters
    used to create it in `self.collate_fn_objs`.
    """
    test_cls = self.__class__
    self.collate_fn_objs = []

    for uncased in test_cls.is_uncased_range:
        for seq_len in test_cls.max_seq_len_range:
            for tok_cls in test_cls.tokenizer_class_range:
                collate_fn = BaseDataset.create_collate_fn(
                    tokenizer=tok_cls(is_uncased=uncased),
                    max_seq_len=seq_len
                )
                record = {
                    'collate_fn': collate_fn,
                    'is_uncased': uncased,
                    'max_seq_len': seq_len,
                    # Key spelling is kept exactly as it was; code outside
                    # this method may look it up under this name.
                    'tokeizer_class': tok_cls,
                }
                self.collate_fn_objs.append(record)
def test_invalid_input_index(self):
    r"""Check invalid `index` values raise `IndexError` or `TypeError`.

    Indexing an empty dataset with each example must raise one of the two
    exceptions; for `TypeError` the message is checked as well.
    """
    msg1 = (
        'Must raise `IndexError` or `TypeError` when input `index` is invalid.'
    )
    msg2 = 'Inconsistent error message.'
    examples = (
        -1, 0.0, 1.0, math.nan, -math.nan, math.inf, -math.inf, 0j, 1j,
        '', b'', (), [], {}, set(), object(), lambda x: x, type, None,
        NotImplemented, ...,
    )

    for invalid_input in examples:
        with self.assertRaises((IndexError, TypeError), msg=msg1) as ctx_man:
            BaseDataset([])[invalid_input]

        error = ctx_man.exception
        if not isinstance(error, TypeError):
            # Only the exception type is checked for the non-`TypeError` case.
            self.assertIsInstance(error, IndexError)
            continue

        self.assertEqual(
            error.args[0],
            '`index` must be an instance of `int`.',
            msg=msg2
        )
def test_return_value(self):
    r"""Sample single sequence using index.

    For each example, indexing the dataset at every valid position of the
    input list must return the input sequence stored at that position.
    """
    msg = 'Must sample single sequence using index.'
    examples = (
        [
            'Hello',
            'World',
            'Hello World',
        ],
        [
            'Mario use Kimura Lock on Luigi, and Luigi tap out.',
            'Mario use Superman Punch.',
            'Luigi get TKO.',
            'Toad and Toadette are fightting over mushroom (weed).',
        ],
        [''],
        [],
    )

    for batch_sequences in examples:
        dataset = BaseDataset(batch_sequences=batch_sequences)

        # Drive the loop from the expected data, not from `len(dataset)`:
        # bounding it by the object under test would silently skip every
        # assertion if its `__len__` wrongly reported fewer items.
        for index, ans_sequence in enumerate(batch_sequences):
            self.assertEqual(dataset[index], ans_sequence, msg=msg)
def test_return_dataset_size(self):
    r"""Check `len()` of a dataset equals the number of input sequences."""
    msg = 'Must return dataset size.'
    examples = (
        (['Hello', 'World', 'Hello World'], 3),
        (
            [
                'Mario use Kimura Lock on Luigi, and Luigi tap out.',
                'Mario use Superman Punch.',
                'Luigi get TKO.',
                'Toad and Toadette are fightting over mushroom (weed).',
            ],
            4,
        ),
        ([''], 1),
        ([], 0),
    )

    for sequences, expected_size in examples:
        dataset = BaseDataset(batch_sequences=sequences)
        self.assertEqual(len(dataset), expected_size, msg=msg)
def test_invalid_input_batch_sequences(self):
    r"""Check invalid `batch_sequences` raise `TypeError` with a fixed message.

    Three families of input are tried, in this order: bare values, lists
    holding a single non-`str` item, and lists holding an empty string
    followed by a non-`str` item.
    """
    msg1 = (
        'Must raise `TypeError` when input `batch_sequences` is invalid.')
    msg2 = 'Inconsistent error message.'

    numbers = (
        False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan, math.inf,
        -math.inf, 0j, 1j,
    )
    empty_iterables = (b'', (), [], {}, set())
    others = (object(), lambda x: x, type, None, NotImplemented, ...)

    # The empty iterables are only used as list items, never passed bare.
    non_str_items = numbers + empty_iterables + others

    examples = (
        numbers
        + others
        + tuple([item] for item in non_str_items)
        + tuple(['', item] for item in non_str_items)
    )
    expected = '`batch_sequences` must be an instance of `Iterable[str]`.'

    for invalid_input in examples:
        with self.assertRaises(TypeError, msg=msg1) as ctx_man:
            BaseDataset(batch_sequences=invalid_input)

        self.assertEqual(ctx_man.exception.args[0], expected, msg=msg2)