def test_close(self, argument, shape, type, batch_len, gen_len, ref_len):
    """Run one forward/close cycle of MultiTurnBleuCorpusMetric and verify it.

    Parametrized values:
        argument:  'default' or 'custom'
        shape:     'pad' or 'jag'
        type:      'list' or 'array'
        batch_len: 'equal' or 'unequal'
        gen_len:   'random', 'non-empty', 'empty'
        ref_len:   'random', 'non-empty', 'empty'

    With an 'unequal' batch, ``forward`` must raise ``ValueError``;
    otherwise the reported BLEU must be close to ``self.get_bleu``.
    In both cases the input dict is compared (via ``same_dict``) against
    a deep copy taken before ``forward`` was called.
    """
    use_defaults = argument == 'default'
    dataloader = FakeMultiDataloader()

    if use_defaults:
        reference_key, turn_len_key, gen_key = self.default_keywords
    else:
        reference_key, turn_len_key, gen_key = 'rk', 'tlk', 'gk'

    data = dataloader.get_data(
        reference_key=reference_key,
        turn_len_key=turn_len_key,
        gen_key=gen_key,
        to_list=(type == 'list'),
        pad=(shape == 'pad'),
        gen_len=gen_len,
        ref_len=ref_len,
    )

    # Custom key names are only passed when the defaults are not in use.
    metric_kwargs = {}
    if not use_defaults:
        metric_kwargs = {
            'multi_turn_reference_allvocabs_key': reference_key,
            'multi_turn_gen_key': gen_key,
            'turn_len_key': turn_len_key,
        }
    metric = MultiTurnBleuCorpusMetric(dataloader, **metric_kwargs)

    batches_mismatch = batch_len == 'unequal'
    if batches_mismatch:
        # Drop the first reference entry so the reference batch is shorter.
        data[reference_key] = data[reference_key][1:]

    # Snapshot of the exact input handed to forward().
    snapshot = copy.deepcopy(data)

    if batches_mismatch:
        with pytest.raises(ValueError, match='Batch num is not matched.'):
            metric.forward(data)
    else:
        metric.forward(data)
        expected_bleu = self.get_bleu(dataloader, data, reference_key, gen_key)
        assert np.isclose(metric.close()['bleu'], expected_bleu)

    assert same_dict(data, snapshot)
def test_close1(self, data_loader):
    """Chain a perplexity metric and a BLEU metric and check both results.

    The expected values come from ``TestMultiTurnPerplexityMetric.get_perplexity``
    and ``TestMultiTurnBleuCorpusMetric.get_bleu``. After the chain has run,
    the input dict is compared (via ``same_dict``) with a deep copy taken
    just before ``MetricChain.forward``.

    When ``data_loader == 'field'`` the metrics are built on the object
    returned by ``get_default_field()`` instead of the dataloader itself.
    """
    ref_key = 'reference_key'
    ref_len_key = 'reference_len_key'
    turn_key = 'turn_len_key'
    prob_key = 'gen_prob_key'
    generated_key = 'gen_key'

    dataloader = FakeMultiDataloader()
    data = dataloader.get_data(
        reference_key=ref_key,
        reference_len_key=ref_len_key,
        turn_len_key=turn_key,
        gen_prob_key=prob_key,
        gen_key=generated_key,
        context_key='context_key',
    )
    if data_loader == 'field':
        dataloader = dataloader.get_default_field()

    perplexity_metric = MultiTurnPerplexityMetric(
        dataloader, ref_key, ref_len_key, prob_key,
        generate_rare_vocab=True, full_check=True,
    )
    expected_perplexity = TestMultiTurnPerplexityMetric().get_perplexity(
        data, dataloader, True, ref_key, ref_len_key, prob_key)

    bleu_metric = MultiTurnBleuCorpusMetric(
        dataloader,
        multi_turn_reference_allvocabs_key=ref_key,
        multi_turn_gen_key=generated_key,
        turn_len_key=turn_key,
    )
    expected_bleu = TestMultiTurnBleuCorpusMetric().get_bleu(
        dataloader, data, ref_key, generated_key)

    # Snapshot is taken after the expected values are computed and
    # before the chain sees the data.
    snapshot = copy.deepcopy(data)

    chain = MetricChain()
    for metric in (perplexity_metric, bleu_metric):
        chain.add_metric(metric)
    chain.forward(data)
    result = chain.close()

    for name, expected in (('perplexity', expected_perplexity),
                           ('bleu', expected_bleu)):
        assert np.isclose(result[name], expected)
    assert same_dict(data, snapshot)
def test_bleu(self):
    """Expect ZeroDivisionError for a one-token generation against one reference.

    The batch holds a single dialog with a single turn, keyed by the
    class's default reference / gen / turn-length key attributes.
    """
    data = {}
    data[self.default_reference_key] = [[[2, 5, 3]]]  # one three-token reference
    data[self.default_gen_key] = [[[5]]]              # one one-token generation
    data[self.default_turn_len_key] = [1]

    metric = MultiTurnBleuCorpusMetric(FakeMultiDataloader())

    # Both calls stay inside the context manager: the error may come
    # from either forward() or close().
    with pytest.raises(ZeroDivisionError):
        metric.forward(data)
        metric.close()
def test_bleu(self):
    """Expect ZeroDivisionError for a one-token generation against one reference.

    Uses the literal keys 'reference_allvocabs', 'gen' and 'turn_length'
    for a batch of one dialog with one turn.
    """
    # NOTE(review): another ``test_bleu`` is visible earlier in this file;
    # if both live in the same class this definition shadows it -- confirm.
    data = dict(
        reference_allvocabs=[[[2, 1, 3]]],  # one three-token reference
        gen=[[[1]]],                        # one one-token generation
        turn_length=[1],
    )

    metric = MultiTurnBleuCorpusMetric(FakeMultiDataloader())

    # Both calls stay inside the context manager: the error may come
    # from either forward() or close().
    with pytest.raises(ZeroDivisionError):
        metric.forward(data)
        metric.close()
def test_hashvalue(self, to_list, pad):
    """Check the 'bleu hashvalue' reported by MultiTurnBleuCorpusMetric.

    Two properties are asserted:

    * feeding the data in one batch, or shuffled and split into several
      batches, yields results that ``same_dict(..., False)`` accepts;
    * every altered dataset (turn lengths reduced by one where greater
      than 1, plus each variant from ``generate_unequal_data``) yields a
      'bleu hashvalue' different from the baseline.
    """
    dataloader = FakeMultiDataloader()
    reference_key, turn_len_key, gen_key = self.default_keywords
    keys = [reference_key, turn_len_key, gen_key]
    data = dataloader.get_data(
        reference_key=reference_key,
        turn_len_key=turn_len_key,
        gen_key=gen_key,
        to_list=to_list,
        pad=pad,
        ref_len='non-empty',
        ref_vocab='non-empty',
    )

    baseline_metric = MultiTurnBleuCorpusMetric(dataloader)
    shuffled_metric = MultiTurnBleuCorpusMetric(dataloader)

    shuffled_data = shuffle_instances(data, keys)
    shuffled_batches = split_batch(
        shuffled_data, keys,
        less_pad=pad, to_list=to_list,
        reference_key=reference_key, reference_is_3D=True,
    )

    baseline_metric.forward(data)
    baseline = baseline_metric.close()

    for chunk in shuffled_batches:
        shuffled_metric.forward(chunk)
    shuffled_result = shuffled_metric.close()
    assert same_dict(baseline, shuffled_result, False)

    # Variant with every turn length above 1 reduced by one.
    shortened = copy.deepcopy(data)
    turn_lengths = shortened[turn_len_key]
    for position, length in enumerate(turn_lengths):
        if length > 1:
            turn_lengths[position] -= 1

    altered_variants = [shortened] + generate_unequal_data(
        data, keys, dataloader.pad_id,
        reference_key=reference_key, reference_is_3D=True,
    )
    for variant in altered_variants:
        # A fresh metric per variant so no state is carried over.
        variant_metric = MultiTurnBleuCorpusMetric(dataloader)
        variant_metric.forward(variant)
        variant_result = variant_metric.close()
        assert baseline['bleu hashvalue'] != variant_result['bleu hashvalue']