예제 #1
0
	def base_test_init(self, dl):
		"""Check the structural invariants of a loaded ``LanguageGeneration`` dataloader.

		Verifies the special tokens and their ids, the vocabulary lists and
		the ``word2id`` mapping, and that the first sentence stored under
		every key starts with ``go_id`` and ends with ``eos_id``.

		Args:
			dl: dataloader under test; must be a ``LanguageGeneration`` instance.

		NOTE(review): the statements after the vocabulary-size checks do not
		use ``dl``; they exercise ``Dataloader`` and ``GenerationBase``
		directly and look like they belong to separate tests -- confirm.
		"""
		assert isinstance(dl, LanguageGeneration)

		# The four special tokens come first and are mapped to ids 0..3.
		assert isinstance(dl.ext_vocab, list)
		assert dl.ext_vocab[:4] == ["<pad>", "<unk>", "<go>", "<eos>"]
		assert (dl.pad_id, dl.unk_id, dl.go_id, dl.eos_id) == (0, 1, 2, 3)

		# key_name must be a non-empty list of strings.
		assert isinstance(dl.key_name, list)
		assert dl.key_name
		assert all(isinstance(name, str) for name in dl.key_name)

		# vocab_list must begin with the special tokens.
		assert isinstance(dl.all_vocab_list, list)
		special_count = len(dl.ext_vocab)
		assert dl.vocab_list[:special_count] == dl.ext_vocab

		# word2id must map every entry of all_vocab_list to its own index.
		assert isinstance(dl.word2id, dict)
		assert len(dl.word2id) == len(dl.all_vocab_list)
		assert dl.vocab_size == len(dl.vocab_list)
		for index, token in enumerate(dl.all_vocab_list):
			assert isinstance(token, str)
			assert dl.word2id[token] == index
		assert dl.all_vocab_size == len(dl.all_vocab_list)

		# Only the first sentence of each key is inspected.
		for name in dl.key_name:
			sentences = dl.data[name]['sent']
			assert isinstance(sentences, list)
			first = sentences[0]
			assert isinstance(first, list)
			assert first[0] == dl.go_id
			assert first[-1] == dl.eos_id

		# There must be valid tokens beyond the four special ones ...
		assert dl.vocab_size > 4
		# ... and the full vocabulary must be strictly larger than the valid
		# one (i.e. the data also contains invalid tokens).
		assert dl.all_vocab_size > dl.vocab_size

		# Walk the whole iterable returned by get_all_subclasses(); the items
		# themselves are not inspected.
		for _ in Dataloader().get_all_subclasses():
			pass
		Dataloader().load_class('LanguageGeneration')
		Dataloader().load_class('None')

		# Instantiating the base class is expected to raise.
		with pytest.raises(NotImplementedError):
			GenerationBase()

		# A subclass whose __init__ does nothing is still expected to raise
		# NotImplementedError somewhere in this block.
		with pytest.raises(NotImplementedError):
			class MyLanguageGeneration(GenerationBase):
				def __init__(self):
					pass
			MyLanguageGeneration().get_batch(None, None)
예제 #2
0
def base_test_version(dl_class):
    """Check that a dataloader class reproduces every recorded hash value.

    The records are read, via ``load_version_info``, from the file
    ``<class name>_v<version>.jsonl`` located under ``version_dir``. For
    each record the class is instantiated with the stored ``args`` and
    ``kwargs`` and its ``hash_value`` is compared with the stored one.

    Args:
        dl_class (type or str): subclass of Dataloader, or a name that is
            resolved with ``Dataloader.load_class``.
    """
    if isinstance(dl_class, str):
        dl_class = Dataloader.load_class(dl_class)
    assert hasattr(dl_class, '_version')

    file_name = '{}_v{}.jsonl'.format(dl_class.__name__, dl_class._version)
    records = load_version_info(str(version_dir / file_name))
    # An empty record set would make the loop below pass vacuously.
    assert records

    for record in records:
        for required in ('hash_value', 'args', 'kwargs'):
            assert required in record
        expected_hash = record['hash_value']
        init_args = record['args']
        init_kwargs = record['kwargs']
        loader = dl_class(*init_args, **init_kwargs)
        assert expected_hash == loader.hash_value