def test_base_class(self):
    with pytest.raises(NotImplementedError):
        dataloader = FakeMultiDataloader()
        gen = []
        reference = []
        bprm = BleuPrecisionRecallMetric(dataloader, 1, 3)
        super(BleuPrecisionRecallMetric, bprm)._score(gen, reference)
def test_close(self, argument, shape, type, batch_len, ref_len, gen_len, ngram):
    dataloader = FakeMultiDataloader()

    if ngram not in range(1, 5):
        with pytest.raises(ValueError, match=r"ngram should belong to \[1, 4\]"):
            bprm = BleuPrecisionRecallMetric(dataloader, ngram, 3)
        return

    if argument == 'default':
        reference_key, gen_key = self.default_keywords
        bprm = BleuPrecisionRecallMetric(dataloader, ngram, 3)
    else:
        reference_key, gen_key = ('rk', 'gk')
        bprm = BleuPrecisionRecallMetric(dataloader, ngram, 3, reference_key, gen_key)

    # TODO: might need adaptation of dataloader.get_data for test_prec_rec;
    # turn_length is not generated_num_per_context conceptually
    data = dataloader.get_data(reference_key=reference_key, gen_key=gen_key,
                               to_list=(type == 'list'), pad=(shape == 'pad'),
                               ref_len=ref_len, gen_len=gen_len, test_prec_rec=True)
    _data = copy.deepcopy(data)

    if batch_len == 'unequal':
        data[reference_key] = data[reference_key][1:]
        _data = copy.deepcopy(data)
        with pytest.raises(ValueError, match="Batch num is not matched."):
            bprm.forward(data)
    else:
        bprm.forward(data)
        ans = bprm.close()
        prefix = 'BLEU-' + str(ngram)
        assert sorted(ans.keys()) == [
            prefix + ' hashvalue', prefix + ' precision', prefix + ' recall'
        ]

    assert same_dict(data, _data)
def test_hashvalue(self):
    dataloader = FakeMultiDataloader()
    reference_key, gen_key = self.default_keywords
    data = dataloader.get_data(reference_key=reference_key, gen_key=gen_key,
                               to_list=True, pad=False,
                               ref_len='non-empty', gen_len='non-empty', test_prec_rec=True)

    bprm = BleuPrecisionRecallMetric(dataloader, 4, 3)
    assert bprm.candidate_allvocabs_key == reference_key
    bprm_shuffle = BleuPrecisionRecallMetric(dataloader, 4, 3)

    data_shuffle = shuffle_instances(data, self.default_keywords)
    for idx in range(len(data_shuffle[reference_key])):
        np.random.shuffle(data_shuffle[reference_key][idx])
    batches_shuffle = split_batch(data_shuffle, self.default_keywords)

    bprm.forward(data)
    res = bprm.close()

    for batch in batches_shuffle:
        bprm_shuffle.forward(batch)
    res_shuffle = bprm_shuffle.close()

    assert same_dict(res, res_shuffle, False)

    data_less_word = copy.deepcopy(data)
    data_less_word[reference_key][0][0] = data_less_word[reference_key][0][0][:-2]
    for data_unequal in [data_less_word] + generate_unequal_data(data, self.default_keywords,
                                                                 dataloader.pad_id,
                                                                 reference_key, reference_is_3D=True):
        bprm_unequal = BleuPrecisionRecallMetric(dataloader, 4, 3)

        bprm_unequal.forward(data_unequal)
        res_unequal = bprm_unequal.close()

        assert res['BLEU-4 hashvalue'] != res_unequal['BLEU-4 hashvalue']
def test_close(self, argument, shape, type, batch_len, ref_len, gen_len, ngram):
    dataloader = FakeMultiDataloader()

    if ngram not in range(1, 5):
        with pytest.raises(ValueError, match=r"ngram should belong to \[1, 4\]"):
            bprm = BleuPrecisionRecallMetric(dataloader, ngram)
        return

    if argument == 'default':
        reference_key, gen_key = ('resp_allvocabs', 'gen')
        bprm = BleuPrecisionRecallMetric(dataloader, ngram)
    else:
        reference_key, gen_key = ('rk', 'gk')
        bprm = BleuPrecisionRecallMetric(dataloader, ngram, reference_key, gen_key)

    data = dataloader.get_data(reference_key=reference_key, gen_key=gen_key,
                               to_list=(type == 'list'), pad=(shape == 'pad'),
                               ref_len=ref_len, gen_len=gen_len)
    _data = copy.deepcopy(data)

    if batch_len == 'unequal':
        data[reference_key] = data[reference_key][1:]
        _data = copy.deepcopy(data)
        with pytest.raises(ValueError, match="Batch num is not matched."):
            bprm.forward(data)
    else:
        bprm.forward(data)
        ans = bprm.close()
        prefix = 'BLEU-' + str(ngram)
        assert sorted(ans.keys()) == [prefix + ' precision', prefix + ' recall']

    assert same_dict(data, _data)
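# A minimal sketch of the forward/close protocol these tests exercise, shown as
# comments so it stays out of the test class. Illustration only: the batch
# layout (per-context reference lists plus multiple generated candidates as
# token-id sequences) is an assumption, not the library's documented default.
#
#     dataloader = FakeMultiDataloader()
#     metric = BleuPrecisionRecallMetric(dataloader, 4, 3)  # ngram=4, 3 candidates per context
#     for batch in batches:        # each batch is a dict holding the reference and gen keys
#         metric.forward(batch)
#     result = metric.close()      # {'BLEU-4 precision': ..., 'BLEU-4 recall': ...,
#                                  #  'BLEU-4 hashvalue': ...}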