Example #1
0
 def test_base_class(self):
     """The parent implementation of ``_score`` must be abstract.

     Calling the parent class's ``_score`` on a ``BleuPrecisionRecallMetric``
     instance is expected to raise ``NotImplementedError``. Only that call is
     placed inside ``pytest.raises`` so that a ``NotImplementedError`` coming
     from fixture or metric construction cannot make the test pass
     spuriously.
     """
     dataloader = FakeMultiDataloader()
     gen = []
     reference = []
     bprm = BleuPrecisionRecallMetric(dataloader, 1, 3)
     with pytest.raises(NotImplementedError):
         # Skip the subclass override and invoke the parent's _score directly.
         super(BleuPrecisionRecallMetric, bprm)._score(gen, reference)
Example #2
0
    def test_close(self, argument, shape, type, batch_len, ref_len, gen_len,
                   ngram):
        """Exercise ``forward``/``close`` of ``BleuPrecisionRecallMetric``.

        Covers rejection of an out-of-range ``ngram``, default versus custom
        data keys, a reference batch whose length disagrees with the
        generated batch, the key set returned by ``close``, and that the
        input dict is left unchanged by the metric.
        """
        dataloader = FakeMultiDataloader()

        # The constructor only accepts ngram values in 1..4.
        if ngram not in range(1, 5):
            with pytest.raises(ValueError,
                               match=r"ngram should belong to \[1, 4\]"):
                BleuPrecisionRecallMetric(dataloader, ngram, 3)
            return

        if argument != 'default':
            reference_key, gen_key = 'rk', 'gk'
            bprm = BleuPrecisionRecallMetric(dataloader, ngram, 3,
                                             reference_key, gen_key)
        else:
            reference_key, gen_key = self.default_keywords
            bprm = BleuPrecisionRecallMetric(dataloader, ngram, 3)

        # TODO: might need adaptation of dataloader.get_data for test_prec_rec
        # turn_length is not generated_num_per_context conceptually
        data = dataloader.get_data(
            reference_key=reference_key, gen_key=gen_key,
            to_list=(type == 'list'), pad=(shape == 'pad'),
            ref_len=ref_len, gen_len=gen_len, test_prec_rec=True)

        if batch_len != 'unequal':
            snapshot = copy.deepcopy(data)
            bprm.forward(data)
            result = bprm.close()
            prefix = 'BLEU-{}'.format(ngram)
            expected_keys = [prefix + suffix
                             for suffix in (' hashvalue', ' precision',
                                            ' recall')]
            assert sorted(result.keys()) == expected_keys
        else:
            # Drop the first reference entry so the batch sizes disagree.
            data[reference_key] = data[reference_key][1:]
            snapshot = copy.deepcopy(data)
            with pytest.raises(ValueError, match="Batch num is not matched."):
                bprm.forward(data)

        # The metric calls above must not have altered the caller's data.
        assert same_dict(data, snapshot)
Example #3
0
    def test_hashvalue(self):
        """The ``BLEU-4 hashvalue`` should track the input data.

        Shuffling the instances (and the references inside each instance)
        and feeding them batch by batch must give a result dict equal to the
        unshuffled run. Data with a shortened reference, or any variant from
        ``generate_unequal_data``, must yield a different hash value.
        """
        dataloader = FakeMultiDataloader()
        reference_key, gen_key = self.default_keywords
        data = dataloader.get_data(
            reference_key=reference_key, gen_key=gen_key,
            to_list=True, pad=False,
            ref_len='non-empty', gen_len='non-empty', test_prec_rec=True)

        bprm = BleuPrecisionRecallMetric(dataloader, 4, 3)
        assert bprm.candidate_allvocabs_key == reference_key
        bprm_shuffle = BleuPrecisionRecallMetric(dataloader, 4, 3)

        # Reorder the instances, then shuffle each instance's references in place.
        data_shuffle = shuffle_instances(data, self.default_keywords)
        for references in data_shuffle[reference_key]:
            np.random.shuffle(references)
        batches_shuffle = split_batch(data_shuffle, self.default_keywords)

        bprm.forward(data)
        res = bprm.close()

        for batch in batches_shuffle:
            bprm_shuffle.forward(batch)
        res_shuffle = bprm_shuffle.close()
        assert same_dict(res, res_shuffle, False)

        # Truncate the first reference of the first instance by two items.
        data_less_word = copy.deepcopy(data)
        first_references = data_less_word[reference_key][0]
        first_references[0] = first_references[0][:-2]

        unequal_variants = [data_less_word] + generate_unequal_data(
            data, self.default_keywords, dataloader.pad_id,
            reference_key, reference_is_3D=True)
        for data_unequal in unequal_variants:
            bprm_unequal = BleuPrecisionRecallMetric(dataloader, 4, 3)
            bprm_unequal.forward(data_unequal)
            res_unequal = bprm_unequal.close()
            assert res['BLEU-4 hashvalue'] != res_unequal['BLEU-4 hashvalue']
Example #4
0
	def test_close(self, argument, shape, type, batch_len, ref_len, gen_len, ngram):
		"""Exercise ``forward``/``close`` of ``BleuPrecisionRecallMetric``.

		Covers rejection of an out-of-range ``ngram``, default versus custom
		data keys, a reference batch whose length disagrees with the
		generated batch, the key set returned by ``close``, and that the
		input dict is left unchanged by the metric.
		"""
		dataloader = FakeMultiDataloader()

		# The constructor only accepts ngram values in 1..4. The pattern is a
		# raw string: "\[" in a plain literal is an invalid escape sequence
		# (DeprecationWarning, SyntaxWarning on Python 3.12+).
		if ngram not in range(1, 5):
			with pytest.raises(ValueError, match=r"ngram should belong to \[1, 4\]"):
				BleuPrecisionRecallMetric(dataloader, ngram)
			return

		if argument == 'default':
			reference_key, gen_key = ('resp_allvocabs', 'gen')
			bprm = BleuPrecisionRecallMetric(dataloader, ngram)
		else:
			reference_key, gen_key = ('rk', 'gk')
			bprm = BleuPrecisionRecallMetric(dataloader, ngram, reference_key, gen_key)

		data = dataloader.get_data(
			reference_key=reference_key, gen_key=gen_key,
			to_list=(type == 'list'), pad=(shape == 'pad'),
			ref_len=ref_len, gen_len=gen_len)

		if batch_len == 'unequal':
			# Drop the first reference entry so the batch sizes disagree.
			data[reference_key] = data[reference_key][1:]
			_data = copy.deepcopy(data)
			with pytest.raises(ValueError, match="Batch num is not matched."):
				bprm.forward(data)
		else:
			_data = copy.deepcopy(data)
			bprm.forward(data)
			ans = bprm.close()
			prefix = 'BLEU-' + str(ngram)
			assert sorted(ans.keys()) == [prefix + ' precision', prefix + ' recall']

		# The metric calls above must not have altered the caller's data.
		assert same_dict(data, _data)