def test_close(self, argument, shape, type, batch_len, gen_len, ref_len):
    """Run ``forward``/``close`` of ``BleuCorpusMetric`` for one parameter combination.

    With an 'unequal' batch the first reference entry is dropped and
    ``forward`` is expected to raise ``ValueError``; otherwise the reported
    ``'bleu'`` value is compared against ``self.get_bleu``. In both cases the
    input dict must be left unmodified by the metric.

    Args:
        argument: 'default' or 'custom'.
        shape: 'pad' or 'jag'.
        type: 'list' or 'array'.
        batch_len: 'equal' or 'unequal'.
        gen_len: 'random', 'non-empty' or 'empty'.
        ref_len: 'random', 'non-empty' or 'empty'.
    """
    loader = FakeDataLoader()
    use_default = argument == 'default'
    if use_default:
        ref_key, hyp_key = self.default_keywords
    else:
        ref_key, hyp_key = 'rk', 'gk'

    batch = loader.get_data(
        reference_key=ref_key,
        gen_key=hyp_key,
        to_list=(type == 'list'),
        pad=(shape == 'pad'),
        gen_len=gen_len,
        ref_len=ref_len,
    )
    snapshot = copy.deepcopy(batch)

    # Custom key names are only passed when the test is not using the defaults.
    metric_kwargs = {}
    if not use_default:
        metric_kwargs = {'reference_allvocabs_key': ref_key, 'gen_key': hyp_key}
    metric = BleuCorpusMetric(loader, **metric_kwargs)

    if batch_len == 'unequal':
        # Make the reference batch one entry shorter than the generated batch.
        batch[ref_key] = batch[ref_key][1:]
        snapshot = copy.deepcopy(batch)
        with pytest.raises(ValueError, match='Batch num is not matched.'):
            metric.forward(batch)
    else:
        metric.forward(batch)
        expected_bleu = self.get_bleu(loader, batch, ref_key, hyp_key)
        assert np.isclose(metric.close()['bleu'], expected_bleu)

    # The metric must not have mutated the caller's data.
    assert same_dict(batch, snapshot)
def test_bleu_bug(self):
    """Expect ``ZeroDivisionError`` for a single-element generation against a three-element reference."""
    loader = FakeDataLoader()
    batch = {
        self.default_reference_key: [[2, 5, 3]],
        self.default_gen_key: [[5]],
    }
    metric = BleuCorpusMetric(loader)

    # Either call may be the one that raises, so both stay inside the block.
    with pytest.raises(ZeroDivisionError):
        metric.forward(batch)
        metric.close()
def test_bleu_bug(self):
    """Expect ``ZeroDivisionError`` when the generation has one element and the reference has three."""
    loader = FakeDataLoader()
    references = [[2, 1, 3]]
    hypotheses = [[1]]
    batch = dict(resp_allvocabs=references, gen=hypotheses)
    metric = BleuCorpusMetric(loader)

    # Either call may be the one that raises, so both stay inside the block.
    with pytest.raises(ZeroDivisionError):
        metric.forward(batch)
        metric.close()
def test_hashvalue(self, to_list, pad):
    """Check the result of ``BleuCorpusMetric.close`` across shuffled and altered data.

    The result for the whole data set must equal (via ``same_dict``) the
    result obtained from the shuffled data fed batch by batch, and every
    variant produced by ``generate_unequal_data`` must yield a different
    ``'bleu hashvalue'``.

    Args:
        to_list: Forwarded to ``get_data`` and ``split_batch`` as ``to_list``.
        pad: Forwarded to ``get_data`` as ``pad`` and to ``split_batch`` as
            ``less_pad``.
    """
    loader = FakeDataLoader()
    ref_key, hyp_key = self.default_keywords
    keys = [ref_key, hyp_key]
    batch = loader.get_data(
        reference_key=ref_key,
        gen_key=hyp_key,
        to_list=to_list,
        pad=pad,
        gen_len='non-empty',
        ref_len='non-empty',
    )

    metric = BleuCorpusMetric(loader)
    shuffled_metric = BleuCorpusMetric(loader)

    shuffled = shuffle_instances(batch, keys)
    shuffled_batches = split_batch(
        shuffled,
        keys,
        less_pad=pad,
        to_list=to_list,
        reference_key=ref_key,
        reference_is_3D=False,
    )

    # Reference result: all instances in a single forward call.
    metric.forward(batch)
    expected = metric.close()

    # Same instances, shuffled and fed in several batches.
    for piece in shuffled_batches:
        shuffled_metric.forward(piece)
    shuffled_result = shuffled_metric.close()
    assert same_dict(expected, shuffled_result, False)

    # Every altered data set must produce a different hash value.
    altered_sets = generate_unequal_data(
        batch, keys, loader.pad_id, ref_key, reference_is_3D=False
    )
    for altered in altered_sets:
        altered_metric = BleuCorpusMetric(loader)
        altered_metric.forward(altered)
        altered_result = altered_metric.close()
        assert expected['bleu hashvalue'] != altered_result['bleu hashvalue']
def test_hashvalue(self, data_loader, to_list, pad, reference_num, data_reference_num):
    """Check the result of ``BleuCorpusMetric.close`` with a configurable reference count.

    When ``reference_num`` differs from ``data_reference_num`` the only
    expectation is that ``forward`` raises ``RuntimeError``. Otherwise the
    result for the whole data set must equal (via ``same_dict``) the result
    from the shuffled data fed batch by batch, and every variant produced by
    ``generate_unequal_data`` must yield a different ``'bleu hashvalue'``.

    Args:
        data_loader: When equal to 'field', the metric is built from
            ``get_default_field()`` of the fake loader instead of the
            loader itself.
        to_list: Forwarded to ``get_data`` and ``split_batch`` as ``to_list``.
        pad: Forwarded to ``get_data`` as ``pad`` and to ``split_batch`` as
            ``less_pad``.
        reference_num: ``reference_num`` passed to each ``BleuCorpusMetric``.
        data_reference_num: ``reference_num`` passed to ``get_data``.
    """
    loader = FakeDataLoader()
    ref_key, hyp_key = self.default_keywords
    keys = [ref_key, hyp_key]
    batch = loader.get_data(
        reference_key=ref_key,
        gen_key=hyp_key,
        to_list=to_list,
        pad=pad,
        gen_len='non-empty',
        ref_len='non-empty',
        reference_num=data_reference_num,
    )
    if data_loader == 'field':
        # From here on the metrics are built from the field, not the loader.
        loader = loader.get_default_field()

    def make_metric():
        # All metrics in this test share the same construction arguments.
        return BleuCorpusMetric(loader, reference_num=reference_num)

    metric = make_metric()
    shuffled_metric = make_metric()

    shuffled = shuffle_instances(batch, keys)
    shuffled_batches = split_batch(
        shuffled,
        keys,
        less_pad=pad,
        to_list=to_list,
        reference_key=ref_key,
        reference_is_3D=False,
    )

    # Mismatched reference counts: forward must fail and nothing else is checked.
    if reference_num != data_reference_num:
        with pytest.raises(RuntimeError):
            metric.forward(batch)
        return

    # Reference result: all instances in a single forward call.
    metric.forward(batch)
    expected = metric.close()

    # Same instances, shuffled and fed in several batches.
    for piece in shuffled_batches:
        shuffled_metric.forward(piece)
    shuffled_result = shuffled_metric.close()
    assert same_dict(expected, shuffled_result, False)

    # Every altered data set must produce a different hash value.
    altered_sets = generate_unequal_data(
        batch,
        keys,
        loader.pad_id,
        ref_key,
        reference_is_3D=(data_reference_num > 1),
    )
    for altered in altered_sets:
        altered_metric = make_metric()
        altered_metric.forward(altered)
        altered_result = altered_metric.close()
        assert expected['bleu hashvalue'] != altered_result['bleu hashvalue']