def base_test_init(self, dl):
    """Check the structural invariants of a freshly built dataloader.

    Verifies the special-token layout, the vocabulary lists and the
    ``word2id`` mapping of ``dl``, and that the first sentence of every
    data split is wrapped in ``go_id`` / ``eos_id``. Afterwards it
    exercises a few ``Dataloader`` / ``GenerationBase`` entry points.

    Args:
        dl: the dataloader under test; must be a ``LanguageGeneration``
            instance.
    """
    assert isinstance(dl, LanguageGeneration)

    # Special tokens: fixed order at the head of the vocab, fixed ids.
    assert isinstance(dl.ext_vocab, list)
    assert dl.ext_vocab[:4] == ["<pad>", "<unk>", "<go>", "<eos>"]
    assert [dl.pad_id, dl.unk_id, dl.go_id, dl.eos_id] == [0, 1, 2, 3]

    # Split names: a non-empty list of strings.
    assert isinstance(dl.key_name, list)
    assert dl.key_name
    assert all(isinstance(name, str) for name in dl.key_name)

    # Vocabulary: vocab_list starts with the special tokens, and word2id
    # is the exact inverse of all_vocab_list.
    assert isinstance(dl.all_vocab_list, list)
    assert dl.vocab_list[:len(dl.ext_vocab)] == dl.ext_vocab
    assert isinstance(dl.word2id, dict)
    assert len(dl.word2id) == len(dl.all_vocab_list)
    assert dl.vocab_size == len(dl.vocab_list)
    for index, token in enumerate(dl.all_vocab_list):
        assert isinstance(token, str)
        assert dl.word2id[token] == index
    assert dl.all_vocab_size == len(dl.all_vocab_list)

    # Data: each split holds a list of id lists; the first sentence must
    # start with <go> and end with <eos>.
    for split in dl.key_name:
        sentences = dl.data[split]['sent']
        assert isinstance(sentences, list)
        first = sentences[0]
        assert isinstance(first, list)
        assert first[0] == dl.go_id
        assert first[-1] == dl.eos_id

    # The data must contain valid tokens beyond the four special ones ...
    assert dl.vocab_size > 4
    # ... and also invalid tokens that exist only in all_vocab_list.
    assert dl.all_vocab_size > dl.vocab_size

    # Exhaust the subclass generator; only that it iterates cleanly matters.
    for _ in Dataloader().get_all_subclasses():
        pass
    # Look up one known and one unknown class name; results are ignored.
    Dataloader().load_class('LanguageGeneration')
    Dataloader().load_class('None')

    # The base class is expected to refuse direct instantiation.
    with pytest.raises(NotImplementedError):
        GenerationBase()

    # A subclass that bypasses the base __init__ is still expected to hit
    # NotImplementedError when get_batch is called.
    with pytest.raises(NotImplementedError):
        class MyLanguageGeneration(GenerationBase):
            def __init__(self):
                pass

        MyLanguageGeneration().get_batch(None, None)
def base_test_version(dl_class):
    """Check that a dataloader class reproduces its recorded hash values.

    Loads the version file ``<ClassName>_v<version>.jsonl`` from
    ``version_dir``, rebuilds the dataloader from each recorded
    ``args`` / ``kwargs`` pair, and asserts that the resulting
    ``hash_value`` equals the recorded one.

    Args:
        dl_class (type or str): subclass of Dataloader, or a class name
            that is resolved through ``Dataloader.load_class``.
    """
    if isinstance(dl_class, str):
        dl_class = Dataloader.load_class(dl_class)
    assert hasattr(dl_class, '_version')

    file_name = '{}_v{}.jsonl'.format(dl_class.__name__, dl_class._version)
    records = load_version_info(str(version_dir / file_name))
    # An empty or missing record set would make the loop below vacuous.
    assert records

    for record in records:
        # Every record must carry the expected hash and the constructor
        # arguments needed to rebuild the dataloader.
        for required in ('hash_value', 'args', 'kwargs'):
            assert required in record
        expected_hash = record['hash_value']
        loader = dl_class(*record['args'], **record['kwargs'])
        assert expected_hash == loader.hash_value