def test_close1(self, data_loader):
    '''Check that a chained perplexity + BLEU evaluation matches the reference values.

    A ``MultiTurnPerplexityMetric`` and a ``MultiTurnBleuCorpusMetric`` are
    registered on one ``MetricChain``. After a single ``forward`` / ``close``
    cycle, the ``perplexity`` and ``bleu`` entries of the result must be close
    to the values computed by the standalone test helpers, and the input data
    must be left unmodified.

    Arguments:
        data_loader (str): When equal to ``'field'``, the metrics are built
            from ``dataloader.get_default_field()`` instead of the dataloader
            itself. Presumably supplied by a pytest parametrization; confirm
            against the decorator outside this view.
    '''
    loader = FakeMultiDataloader()
    batch = loader.get_data(
        reference_key='reference_key',
        reference_len_key='reference_len_key',
        turn_len_key='turn_len_key',
        gen_prob_key='gen_prob_key',
        gen_key='gen_key',
        context_key='context_key')
    # The data is generated from the dataloader before it is (optionally)
    # swapped for its default field.
    if data_loader == 'field':
        loader = loader.get_default_field()

    perplexity_metric = MultiTurnPerplexityMetric(
        loader, 'reference_key', 'reference_len_key', 'gen_prob_key',
        generate_rare_vocab=True, full_check=True)
    expected_perplexity = TestMultiTurnPerplexityMetric().get_perplexity(
        batch, loader, True,
        'reference_key', 'reference_len_key', 'gen_prob_key')

    bleu_metric = MultiTurnBleuCorpusMetric(
        loader,
        multi_turn_reference_allvocabs_key='reference_key',
        multi_turn_gen_key='gen_key',
        turn_len_key='turn_len_key')
    expected_bleu = TestMultiTurnBleuCorpusMetric().get_bleu(
        loader, batch, 'reference_key', 'gen_key')

    # Snapshot taken right before the chain runs, to detect mutation of the input.
    snapshot = copy.deepcopy(batch)

    chain = MetricChain()
    chain.add_metric(perplexity_metric)
    chain.add_metric(bleu_metric)
    chain.forward(batch)
    result = chain.close()

    assert np.isclose(result['perplexity'], expected_perplexity)
    assert np.isclose(result['bleu'], expected_bleu)
    assert same_dict(batch, snapshot)
def get_teacher_forcing_metric(self, gen_log_prob_key="gen_log_prob", invalid_vocab=False):
    '''Build the metric chain used for teacher-forcing evaluation.

    The chain contains a single metric:

    * :class:`.metric.PerplexityMetric`

    Arguments:
        gen_log_prob_key (str): The key of predicted log probability over words.
            Refer to :class:`.metric.PerplexityMetric`. Default: ``gen_log_prob``.
        invalid_vocab (bool): Whether ``gen_log_prob`` contains invalid vocab.
            Refer to :class:`.metric.PerplexityMetric`. Default: ``False``.

    Returns:
        A :class:`.metric.MetricChain` object.
    '''
    chain = MetricChain()
    # References are read from the fixed keys "resp_allvocabs" / "resp_length".
    perplexity = PerplexityMetric(
        self,
        reference_allvocabs_key="resp_allvocabs",
        reference_len_key="resp_length",
        gen_log_prob_key=gen_log_prob_key,
        invalid_vocab=invalid_vocab)
    chain.add_metric(perplexity)
    return chain
def get_inference_metric(self, multi_turn_gen_key="multi_turn_gen"):
    '''Build the metric chain used for multi-turn inference.

    The chain contains, in this order:

    * ``MyBleuMetric``
    * ``MultiTurnDialogRecorder``
    * ``MyDistinctMetric``

    Arguments:
        multi_turn_gen_key (str): Forwarded to every metric as its
            ``multi_turn_gen_key`` argument (presumably the key of the
            generated multi-turn sentences; confirm against the metric
            classes). Default: ``multi_turn_gen``.

    Returns:
        A ``MetricChain`` object holding the three metrics above.
    '''
    # Keyword arguments shared by all three metrics.
    gen_kwargs = dict(multi_turn_gen_key=multi_turn_gen_key,
                      turn_len_key="turn_length")
    # The BLEU metric and the recorder additionally need the reference key.
    ref_kwargs = dict(gen_kwargs,
                      multi_turn_reference_allvocabs_key="sent_allvocabs")

    chain = MetricChain()
    chain.add_metric(MyBleuMetric(self, **ref_kwargs))
    chain.add_metric(MultiTurnDialogRecorder(self, **ref_kwargs))
    chain.add_metric(MyDistinctMetric(self, **gen_kwargs))
    return chain
def get_inference_metric(self, gen_key="gen"):
    '''Build the metric chain used for single-turn inference.

    The chain contains, in this order:

    * :class:`.metric.BleuCorpusMetric`
    * :class:`.metric.SingleTurnDialogRecorder`
    * ``SingleTurnDistinct``

    Arguments:
        gen_key (str): The key of generated sentences in index form.
            Refer to :class:`.metric.BleuCorpusMetric` or
            :class:`.metric.SingleTurnDialogRecorder`. Default: ``gen``.

    Returns:
        A :class:`.metric.MetricChain` object.
    '''
    chain = MetricChain()

    # Only the BLEU metric is given an explicit reference key here.
    bleu = BleuCorpusMetric(self, gen_key=gen_key,
                            reference_allvocabs_key="resp_allvocabs")
    chain.add_metric(bleu)

    recorder = SingleTurnDialogRecorder(self, gen_key=gen_key)
    chain.add_metric(recorder)

    distinct = SingleTurnDistinct(self, gen_key=gen_key)
    chain.add_metric(distinct)

    return chain
def test_add_metric(self):
    '''Check that ``MetricChain.add_metric`` rejects an invalid argument.

    A plain list is passed instead of a metric, and a ``TypeError`` is expected.
    '''
    chain = MetricChain()
    not_a_metric = [1, 2, 3]
    with pytest.raises(TypeError):
        chain.add_metric(not_a_metric)