Example 1
    def base_test_get_batch(self, lp: LanguageProcessing):
        # An unknown set name must raise ValueError.
        with pytest.raises(ValueError):
            lp.get_batch("unknown set", [0, 1])
        for set_name in lp.data.keys():
            # Indexing past the end of the set must raise IndexError.
            with pytest.raises(IndexError):
                length = len(lp.index[set_name])
                lp.get_batch(set_name, [length - 1, length])
            assert len(lp.index[set_name]) >= 2
            batch = lp.get_batch(set_name, [0, 1])
            # Every field in the batch covers exactly the two requested samples.
            for field_name, content in batch.items():
                assert len(content) == 2
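
This and the following snippets are excerpts from a shared test base class for cotk-style `LanguageProcessing` dataloaders. They assume imports along these lines (the `cotk.dataloader` paths are an assumption based on the names used, not something shown in the excerpts):

    import copy
    from collections import OrderedDict

    import numpy as np
    import pytest

    # Assumed import paths for the classes referenced in the snippets.
    from cotk.dataloader import (LanguageProcessing, Field, Vocab, Tokenizer,
                                 FieldContext, VocabContext)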
Example 2
    def base_test_restart(self, lp: LanguageProcessing):
        # An unknown set name must raise ValueError.
        with pytest.raises(ValueError):
            lp.restart("unknown set")
        for set_name in lp.data.keys():
            # Restarting without a batch size (none set before) must raise ValueError.
            with pytest.raises(ValueError):
                lp.restart(set_name)
            record_index = copy.copy(lp.index[set_name])
            lp.restart(set_name, batch_size=3, shuffle=False)
            # Without shuffling, the sample order is unchanged and the batch pointer resets.
            assert record_index == lp.index[set_name]
            assert lp.batch_id[set_name] == 0
            assert lp.batch_size[set_name] == 3
            #rng_state_st = random.getstate()
            lp.restart(set_name, shuffle=True)
            #rng_state_ed = random.getstate()
            #assert operator.eq(rng_state_st, rng_state_ed)
            assert lp.batch_id[set_name] == 0
            record_index = copy.copy(lp.index[set_name])
            lp.restart(set_name, shuffle=False)
            assert record_index == lp.index[set_name]
            assert lp.batch_id[set_name] == 0
Example 3
    def base_test_set_default_field(self, lp: LanguageProcessing):
        for set_name, data in lp.data.items():
            # Unknown set or field names must raise KeyError.
            with pytest.raises(KeyError):
                lp.set_default_field('unknown_set', 'unknown_field')
            for field_name, _ in data.items():
                with pytest.raises(KeyError):
                    lp.set_default_field(set_name, 'unknown_field')
                lp.set_default_field(set_name, field_name)
                assert lp.default_field_set_name == set_name
                assert lp.default_field_name == field_name
Example 4
    def base_test_get_batches(self, lp: LanguageProcessing):
        lp_cp = copy.deepcopy(lp)
        for set_name in lp.data.keys():
            #rng_state = random.getstate()
            lp_batches = iter(lp.get_batches(set_name, 3, False))
            #random.setstate(rng_state)
            lp_cp.restart(set_name, 3, False)
            # get_batches must yield exactly the batches that repeated
            # get_next_batch calls produce on an identical copy.
            while True:
                res_cp = lp_cp.get_next_batch(set_name)
                if res_cp is None:
                    break
                res = next(lp_batches)
                assert sorted(res_cp.keys()) == sorted(res.keys())
                for key in res_cp.keys():
                    if isinstance(res_cp[key], np.ndarray):
                        assert (res_cp[key] == res[key]).all()
                    else:
                        assert res_cp[key] == res[key]
Example 5
    def base_test_get_next_batch(self, lp: LanguageProcessing):
        with pytest.raises(ValueError):
            lp.get_next_batch("unknown set")
        for set_name in lp.data.keys():
            # Calling get_next_batch before restart must raise RuntimeError.
            with pytest.raises(RuntimeError):
                lp.get_next_batch(set_name)

            lp.restart(set_name, 7)
            sample_num = 0
            while True:
                batch = lp.get_next_batch(set_name, ignore_left_samples=True)
                if not batch:
                    break
                # With ignore_left_samples=True every yielded batch is full-sized.
                for field_name, content in batch.items():
                    assert len(content) == 7
                sample_num += 7
            # The dropped remainder is smaller than one full batch; compare
            # against each per-sample list, not the dict of sub-fields.
            for field_name, content in lp.data[set_name].items():
                assert isinstance(content, dict)
                for _, each_content in content.items():
                    assert sample_num + 7 >= len(each_content)
Example 6
    def base_test_init(self, lp: LanguageProcessing):
        # An empty field list must raise RuntimeError.
        with pytest.raises(RuntimeError):
            file_id = './tests/dataloader/dummy_languageprocessing'
            fields = []
            LanguageProcessing.simple_create(file_id,
                                             fields,
                                             tokenizer='space',
                                             min_frequent_vocab_times=3)
        with pytest.raises(RuntimeError):
            LanguageProcessing('./tests/dataloader/dummy_languageprocessing',
                               [])

        # A field spec that is not a Field (sub)class or name must raise TypeError.
        with pytest.raises(TypeError):
            file_id = './tests/dataloader/dummy_languageprocessing'
            fields = OrderedDict({'sent': 0})
            LanguageProcessing.simple_create(file_id,
                                             fields,
                                             tokenizer='space',
                                             min_frequent_vocab_times=3)
        with pytest.raises(TypeError):
            LanguageProcessing('./tests/dataloader/dummy_languageprocessing',
                               OrderedDict({'sent': 0}))

        # Field names that do not match the data files must raise RuntimeError.
        with pytest.raises(RuntimeError):
            file_id = './tests/dataloader/dummy_languageprocessing'
            fields = OrderedDict({
                'post': 'SentenceDefault',
                'resp': 'SentenceDefault'
            })
            LanguageProcessing.simple_create(file_id,
                                             fields,
                                             tokenizer='space',
                                             min_frequent_vocab_times=3)

        # A correctly constructed loader exposes typed attributes.
        assert isinstance(lp.file_id, str)
        assert isinstance(lp.file_path, str)
        for set_name, fields in lp.fields.items():
            assert isinstance(set_name, str)
            assert isinstance(fields, dict)
            for field_name, field in fields.items():
                assert isinstance(field_name, str)
                assert isinstance(field, Field)

        assert isinstance(lp.vocabs, list)
        for vocab in lp.vocabs:
            assert isinstance(vocab, Vocab)
        assert isinstance(lp.tokenizers, list)
        for toker in lp.tokenizers:
            assert isinstance(toker, Tokenizer)

        # Each per-set index is as long as every per-field content list.
        for (_, data), (_, index) in zip(lp.data.items(), lp.index.items()):
            assert isinstance(data, dict)
            assert isinstance(index, list)
            for field_name, content in data.items():
                assert isinstance(content, dict)
                for _, each_content in content.items():
                    assert isinstance(each_content, list)
                    assert len(index) == len(each_content)
        # Batch pointers start at zero; batch sizes are unset before restart.
        for _, batch_id in lp.batch_id.items():
            assert batch_id == 0
        for _, batch_size in lp.batch_size.items():
            assert batch_size is None
Example 7
def _simple_create_LanguageProcessing():
    return LanguageProcessing.simple_create(file_id,
                                            fields,
                                            tokenizer='space',
                                            min_frequent_vocab_times=3)
Example 8
def _load_LanguageProcessing():
    # Default vocab/tokenizer parameters are supplied through context managers.
    with VocabContext.set_parameters(min_frequent_vocab_times=3):
        with FieldContext.set_parameters(tokenizer='space'):
            return LanguageProcessing(file_id, fields)
Example 9
def _simple_create_LanguageProcessing():
    return LanguageProcessing.simple_create(file_id, fields)
Example 10
def _load_LanguageProcessing():
    return LanguageProcessing(file_id, fields)
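
Examples 7-10 read module-level `file_id` and `fields` that the excerpts never define. A minimal sketch of plausible values, modeled on the dummy-data path and field spec visible in Example 6 (both values are assumptions, not part of the original suite):

    # Hypothetical module-level fixtures shared by the loader helpers above.
    file_id = './tests/dataloader/dummy_languageprocessing'
    fields = OrderedDict({'sent': 'SentenceDefault'})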
Example 11
    def base_test_convert(self, lp: LanguageProcessing):
        # Round-trip between ids and tokens for the special tokens.
        sent_id = [0, 1, 2]
        sent = ["<pad>", "<unk>", "<go>"]
        assert sent == lp.convert_ids_to_tokens(sent_id)
        assert sent_id == lp.convert_tokens_to_ids(sent)

        # Out-of-vocabulary tokens (the deliberately misspelled
        # "<unkownword>") map to <unk>, id 1.
        sent = ["<unk>", "<go>", "<pad>", "<unkownword>", "<pad>", "<go>"]
        sent_id = [1, 2, 0, 1, 0, 2]
        assert sent_id == lp.convert_tokens_to_ids(sent)
        assert sent_id == lp.convert_tokens_to_ids(sent,
                                                   only_frequent_word=True)

        # The first rare word maps to <unk> only when only_frequent_word is set.
        sent = [lp.all_vocab_list[lp.frequent_vocab_size]]
        assert [1] == lp.convert_tokens_to_ids(sent, only_frequent_word=True)
        assert [lp.frequent_vocab_size] == lp.convert_tokens_to_ids(sent)

        # trim=False keeps everything after <eos>; the default trims it.
        sent_id = [0, 1, 2, 0, 0, 3, 1, 0, 0]
        sent = [
            "<pad>", "<unk>", "<go>", "<pad>", "<pad>", "<eos>", "<unk>",
            "<pad>", "<pad>"
        ]
        assert sent == lp.convert_ids_to_tokens(sent_id, trim=False)
        sent = ["<pad>", "<unk>", "<go>"]
        assert sent == lp.convert_ids_to_tokens(sent_id)

        # Trimming removes trailing <pad> and <eos>, leaving nothing here.
        sent_id = [0, 0, 3]
        sent = ["<pad>", "<pad>", "<eos>"]
        assert sent == lp.convert_ids_to_tokens(sent_id,
                                                remove_special=False,
                                                trim=False)
        assert not lp.convert_ids_to_tokens(sent_id)

        sent_id = [3, 3, 3]
        sent = ["<eos>", "<eos>", "<eos>"]
        assert sent == lp.convert_ids_to_tokens(sent_id,
                                                remove_special=False,
                                                trim=False)
        assert not lp.convert_ids_to_tokens(sent_id)

        sent_id = [0, 0, 0]
        sent = ["<pad>", "<pad>", "<pad>"]
        assert sent == lp.convert_ids_to_tokens(sent_id, trim=False)
        assert not lp.convert_ids_to_tokens(sent_id)
Example 12
    def base_test_get_field(self, lp: LanguageProcessing):
        # get_field returns the same Field object stored in lp.fields.
        for set_name, data in lp.data.items():
            for field_name, _ in data.items():
                assert lp.get_field(set_name, field_name) == \
                    lp.fields[set_name][field_name]
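
The `base_test_*` methods all take an already-built loader, so a concrete suite only needs to construct one and delegate. A minimal sketch of that wiring, assuming the base methods live on a class named `BaseTestLanguageProcessing` (the class and test names here are hypothetical):

    class TestLanguageProcessing(BaseTestLanguageProcessing):
        def test_init(self):
            self.base_test_init(_load_LanguageProcessing())

        def test_get_batch(self):
            self.base_test_get_batch(_simple_create_LanguageProcessing())

        def test_convert(self):
            self.base_test_convert(_load_LanguageProcessing())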