コード例 #1
0
	def test_close(self, argument, shape, type, batch_len, gen_len, ref_len):
		# 'default' or 'custom'
		# 'pad' or 'jag'
		# 'list' or 'array'
		# 'equal' or 'unequal'
		# 'random', 'non-empty', 'empty'
		# 'random', 'non-empty', 'empty'
		dataloader = FakeMultiDataloader()
		reference_key, turn_len_key, gen_key = self.default_keywords \
			if argument == 'default' else ('rk', 'tlk', 'gk')
		data = dataloader.get_data(reference_key=reference_key, turn_len_key=turn_len_key, gen_key=gen_key, \
								   to_list=(type == 'list'), pad=(shape == 'pad'), \
								   gen_len=gen_len, ref_len=ref_len)
		_data = copy.deepcopy(data)
		if argument == 'default':
			mtbcm = MultiTurnBleuCorpusMetric(dataloader)
		else:
			mtbcm = MultiTurnBleuCorpusMetric(dataloader, multi_turn_reference_allvocabs_key=reference_key, \
											  multi_turn_gen_key=gen_key, turn_len_key=turn_len_key)

		if batch_len == 'unequal':
			data[reference_key] = data[reference_key][1:]
			_data = copy.deepcopy(data)
			with pytest.raises(ValueError, match='Batch num is not matched.'):
				mtbcm.forward(data)
		else:
			mtbcm.forward(data)
			assert np.isclose(mtbcm.close()['bleu'], self.get_bleu(dataloader, data, reference_key, gen_key))
		assert same_dict(data, _data)
コード例 #2
0
ファイル: test_metric_chain.py プロジェクト: xiaoanshi/cotk
    def test_close1(self, data_loader):
        dataloader = FakeMultiDataloader()
        data = dataloader.get_data(reference_key='reference_key', reference_len_key='reference_len_key', \
                 turn_len_key='turn_len_key', gen_prob_key='gen_prob_key', \
                 gen_key='gen_key', context_key='context_key')
        if data_loader == 'field':
            dataloader = dataloader.get_default_field()
        pm = MultiTurnPerplexityMetric(dataloader, 'reference_key', 'reference_len_key', 'gen_prob_key', \
                  generate_rare_vocab=True, full_check=True)
        perplexity = TestMultiTurnPerplexityMetric().get_perplexity( \
         data, dataloader, True, 'reference_key', 'reference_len_key', 'gen_prob_key')

        bcm = MultiTurnBleuCorpusMetric(dataloader, multi_turn_reference_allvocabs_key='reference_key', \
                multi_turn_gen_key='gen_key', turn_len_key='turn_len_key')
        bleu = TestMultiTurnBleuCorpusMetric().get_bleu(
            dataloader, data, 'reference_key', 'gen_key')

        _data = copy.deepcopy(data)
        mc = MetricChain()
        mc.add_metric(pm)
        mc.add_metric(bcm)
        mc.forward(data)
        res = mc.close()

        assert np.isclose(res['perplexity'], perplexity)
        assert np.isclose(res['bleu'], bleu)
        assert same_dict(data, _data)
コード例 #3
0
	def test_bleu(self):
		dataloader = FakeMultiDataloader()
		ref = [[[2, 5, 3]]]
		gen = [[[5]]]
		turn_len = [1]
		data = {self.default_reference_key: ref, self.default_gen_key: gen, self.default_turn_len_key: turn_len}
		mtbcm = MultiTurnBleuCorpusMetric(dataloader)

		with pytest.raises(ZeroDivisionError):
			mtbcm.forward(data)
			mtbcm.close()
コード例 #4
0
	def test_bleu(self):
		dataloader = FakeMultiDataloader()
		ref = [[[2, 1, 3]]]
		gen = [[[1]]]
		turn_len = [1]
		data = {'reference_allvocabs': ref, 'gen': gen, 'turn_length': turn_len}
		mtbcm = MultiTurnBleuCorpusMetric(dataloader)

		with pytest.raises(ZeroDivisionError):
			mtbcm.forward(data)
			mtbcm.close()
コード例 #5
0
	def test_hashvalue(self, to_list, pad):
		dataloader = FakeMultiDataloader()
		reference_key, turn_len_key, gen_key = self.default_keywords
		key_list = [reference_key, turn_len_key, gen_key]
		data = dataloader.get_data(reference_key=reference_key, turn_len_key=turn_len_key, gen_key=gen_key, \
								   to_list=to_list, pad=pad, ref_len='non-empty', \
								   ref_vocab='non-empty')

		mtbcm = MultiTurnBleuCorpusMetric(dataloader)
		mtbcm_shuffle = MultiTurnBleuCorpusMetric(dataloader)

		data_shuffle = shuffle_instances(data, key_list)
		batches_shuffle = split_batch(data_shuffle, key_list, \
									  less_pad=pad, to_list=to_list, \
									  reference_key=reference_key, reference_is_3D=True)

		mtbcm.forward(data)
		res = mtbcm.close()

		for batch in batches_shuffle:
			mtbcm_shuffle.forward(batch)
		res_shuffle = mtbcm_shuffle.close()
		assert same_dict(res, res_shuffle, False)

		data_less_word = copy.deepcopy(data)
		for idx, turn_len in enumerate(data_less_word[turn_len_key]):
			if turn_len > 1:
				data_less_word[turn_len_key][idx] -= 1
		for data_unequal in [data_less_word] + generate_unequal_data(data, key_list, dataloader.pad_id, \
																	 reference_key=reference_key, reference_is_3D=True):
			mtbcm_unequal = MultiTurnBleuCorpusMetric(dataloader)

			mtbcm_unequal.forward(data_unequal)
			res_unequal = mtbcm_unequal.close()

			assert res['bleu hashvalue'] != res_unequal['bleu hashvalue']