Пример #1
0
    def test_close1(self, data_loader):
        '''Check that a closed MetricChain reproduces standalone metric values.

        The chain combines a multi-turn perplexity metric and a multi-turn
        BLEU metric. Its ``close()`` result must match the values produced by
        the corresponding test helpers, and the input batch must compare
        equal to a deep copy taken before ``forward()`` was called.

        ``data_loader`` is the parametrized mode: when it equals ``'field'``
        the metrics are built on the fake dataloader's default field instead
        of the dataloader itself (the batch is always generated first).
        '''
        loader = FakeMultiDataloader()
        batch = loader.get_data(
            reference_key='reference_key',
            reference_len_key='reference_len_key',
            turn_len_key='turn_len_key',
            gen_prob_key='gen_prob_key',
            gen_key='gen_key',
            context_key='context_key',
        )
        if data_loader == 'field':
            loader = loader.get_default_field()

        ppl_metric = MultiTurnPerplexityMetric(
            loader, 'reference_key', 'reference_len_key', 'gen_prob_key',
            generate_rare_vocab=True, full_check=True)
        expected_ppl = TestMultiTurnPerplexityMetric().get_perplexity(
            batch, loader, True,
            'reference_key', 'reference_len_key', 'gen_prob_key')

        bleu_metric = MultiTurnBleuCorpusMetric(
            loader,
            multi_turn_reference_allvocabs_key='reference_key',
            multi_turn_gen_key='gen_key',
            turn_len_key='turn_len_key')
        expected_bleu = TestMultiTurnBleuCorpusMetric().get_bleu(
            loader, batch, 'reference_key', 'gen_key')

        # Snapshot the batch so we can assert the chain did not mutate it.
        snapshot = copy.deepcopy(batch)
        chain = MetricChain()
        for member in (ppl_metric, bleu_metric):
            chain.add_metric(member)
        chain.forward(batch)
        result = chain.close()

        assert np.isclose(result['perplexity'], expected_ppl)
        assert np.isclose(result['bleu'], expected_bleu)
        assert same_dict(batch, snapshot)
Пример #2
0
    def get_teacher_forcing_metric(self,
                                   gen_log_prob_key="gen_log_prob",
                                   invalid_vocab=False):
        '''Build the metric chain used for teacher-forcing evaluation.

        The returned chain holds a single :class:`.metric.PerplexityMetric`,
        configured with ``resp_allvocabs`` as the reference key and
        ``resp_length`` as the reference length key.

        Arguments:
            gen_log_prob_key (str): Key of the predicted log probability
                over words. See :class:`.metric.PerplexityMetric`.
                Default: ``gen_log_prob``.
            invalid_vocab (bool): Whether ``gen_log_prob`` contains invalid
                vocab. See :class:`.metric.PerplexityMetric`.
                Default: ``False``.

        Returns:
            A :class:`.metric.MetricChain` object.
        '''
        chain = MetricChain()
        perplexity = PerplexityMetric(
            self,
            reference_allvocabs_key="resp_allvocabs",
            reference_len_key="resp_length",
            gen_log_prob_key=gen_log_prob_key,
            invalid_vocab=invalid_vocab,
        )
        chain.add_metric(perplexity)
        return chain
Пример #3
0
 def get_inference_metric(self, multi_turn_gen_key="multi_turn_gen"):
     '''Build the metric chain used for multi-turn inference evaluation.

     The chain holds, in this order, a ``MyBleuMetric``, a
     ``MultiTurnDialogRecorder`` and a ``MyDistinctMetric``. All three are
     configured with ``multi_turn_gen_key`` and with ``turn_length`` as the
     turn length key; the first two also use ``sent_allvocabs`` as the
     reference key.

     Arguments:
         multi_turn_gen_key (str): Key of the generated multi-turn
             sentences. Default: ``multi_turn_gen``.

     Returns:
         A ``MetricChain`` object.
     '''
     # Keyword arguments shared by every metric in the chain.
     turn_kwargs = {
         "multi_turn_gen_key": multi_turn_gen_key,
         "turn_len_key": "turn_length",
     }
     chain = MetricChain()
     chain.add_metric(MyBleuMetric(
         self, multi_turn_reference_allvocabs_key="sent_allvocabs",
         **turn_kwargs))
     chain.add_metric(MultiTurnDialogRecorder(
         self, multi_turn_reference_allvocabs_key="sent_allvocabs",
         **turn_kwargs))
     chain.add_metric(MyDistinctMetric(self, **turn_kwargs))
     return chain
Пример #4
0
    def get_inference_metric(self, gen_key="gen"):
        '''Get metrics for inference.

        It contains:

        * :class:`.metric.BleuCorpusMetric`
        * :class:`.metric.SingleTurnDialogRecorder`
        * :class:`.SingleTurnDistinct`

        Arguments:
            gen_key (str): The key of generated sentences in index form.
                It is passed to every metric in the chain; refer to
                :class:`.metric.BleuCorpusMetric`,
                :class:`.metric.SingleTurnDialogRecorder` or
                :class:`.SingleTurnDistinct`. Default: ``gen``.

        Returns:
            A :class:`.metric.MetricChain` object.
        '''
        metric = MetricChain()
        metric.add_metric(BleuCorpusMetric(
            self, gen_key=gen_key, reference_allvocabs_key="resp_allvocabs"))
        metric.add_metric(SingleTurnDialogRecorder(self, gen_key=gen_key))
        metric.add_metric(SingleTurnDistinct(self, gen_key=gen_key))
        return metric
Пример #5
0
	def test_add_metric(self):
		'''MetricChain.add_metric must reject a non-metric argument
		(here a plain list) by raising TypeError.'''
		mc = MetricChain()
		with pytest.raises(TypeError):
			mc.add_metric([1, 2, 3])