class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""

    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics_list = set()
        optional_metrics_list = []
        metrics_arg = opt.get('metrics', 'default')
        if metrics_arg == 'default':
            optional_metrics_list = DEFAULT_METRICS
        elif metrics_arg == 'all':
            optional_metrics_list = ALL_METRICS
        else:
            optional_metrics_list = set(metrics_arg.split(','))
            optional_metrics_list.add('correct')
        for each_m in optional_metrics_list:
            if each_m.startswith('rouge'):
                if rouge is not None:
                    # only compute rouge if rouge is available
                    self.metrics_list.add('rouge')
            elif each_m == 'bleu' and nltkbleu is None:
                # only compute bleu if nltk's bleu is available
                pass
            else:
                self.metrics_list.add(each_m)
        metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                        else self.metrics_list | ROUGE_METRICS)
        for k in metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        self.metrics['hits@_cnt'] = 0
        self.flags = {
            'has_text_cands': False,
            'print_prediction_metrics': False
        }
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
            self.flags = SharedTable(self.flags)

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        representation = super().__repr__()
        return representation.replace('>', ': {}>'.format(repr(self.metrics)))

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def _update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return
        else:
            # Now loop through text candidates, assuming they are sorted.
            # If any of them is a label then score a point.
            # maintain hits@1, 5, 10, 100, etc.
            label_set = set(normalize_answer(l) for l in labels)
            cnts = {k: 0 for k in self.eval_pr}
            cnt = 0
            for c in text_cands:
                cnt += 1
                if normalize_answer(c) in label_set:
                    for k in self.eval_pr:
                        if cnt <= k:
                            cnts[k] += 1
            # hits metric is 1 if cnts[k] > 0.
            # (other metrics such as p@k and r@k take
            # the value of cnt into account.)
            with self._lock():
                self.flags['has_text_cands'] = True
                for k in self.eval_pr:
                    if cnts[k] > 0:
                        self.metrics['hits@' + str(k)] += 1
                self.metrics['hits@_cnt'] += 1

    def update(self, observation, labels):
        """Update metrics based on an observation and true labels."""
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.flags['print_prediction_metrics'] = True
                self.metrics['correct'] += correct
                self.metrics['correct_cnt'] += 1

            # F1, BLEU, and ROUGE metrics.
            if 'f1' in self.metrics_list:
                f1 = _f1_score(prediction, labels)
            if 'bleu' in self.metrics_list:
                bleu = _bleu(prediction, labels)
            if 'rouge' in self.metrics_list:
                rouge1, rouge2, rougeL = _rouge(prediction, labels)
            with self._lock():
                if 'f1' in self.metrics:
                    self.metrics['f1'] += f1
                    self.metrics['f1_cnt'] += 1
                if 'bleu' in self.metrics:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1
                if 'rouge-L' in self.metrics and rouge1 is not None:
                    self.metrics['rouge-1'] += rouge1
                    self.metrics['rouge-1_cnt'] += 1
                    self.metrics['rouge-2'] += rouge2
                    self.metrics['rouge-2_cnt'] += 1
                    self.metrics['rouge-L'] += rougeL
                    self.metrics['rouge-L_cnt'] += 1

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ALL_METRICS and k != 'rouge':
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            # no need to lock because not SharedTable
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.metrics_list.add(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        """Report the metrics over all data seen so far."""
        m = {}
        total = self.metrics['cnt']
        m['exs'] = total
        if total > 0:
            if self.flags['print_prediction_metrics']:
                if 'accuracy' in self.metrics_list:
                    m['accuracy'] = round_sigfigs(
                        self.metrics['correct'] /
                        max(1, self.metrics['correct_cnt']), 4)
                if 'f1' in self.metrics_list:
                    m['f1'] = round_sigfigs(
                        self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
            if self.flags['has_text_cands']:
                for k in self.eval_pr:
                    m['hits@' + str(k)] = round_sigfigs(
                        self.metrics['hits@' + str(k)] /
                        max(1, self.metrics['hits@_cnt']),
                        3,
                    )
            # include the per-variant rouge keys when rouge is enabled
            metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                            else self.metrics_list | ROUGE_METRICS)
            for k in metrics_list:
                if self.metrics[k + '_cnt'] > 0 and k != 'correct' and k != 'f1':
                    m[k] = round_sigfigs(
                        self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
        return m

    def clear(self):
        """Clear all the metrics."""
        # TODO: rename to reset for consistency with rest of ParlAI
        with self._lock():
            self.metrics['cnt'] = 0
            metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                            else self.metrics_list | ROUGE_METRICS)
            for k in metrics_list:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                elif isinstance(v, int):
                    self.metrics[k] = 0
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            self.metrics['hits@_cnt'] = 0
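
# The update()/report() flow above is easiest to see with a concrete call
# sequence. The helper below is an illustrative sketch only and is not part of
# the original module: the function name, the opt values, and the exact keys
# in the returned report are assumptions based on the Metrics code above.
def _example_metrics_usage():
    """Illustrative sketch: drive the Metrics class for a single example."""
    metrics = Metrics({'metrics': 'default', 'numthreads': 1})
    observation = {
        'text': 'hello world',                     # model prediction
        'text_candidates': ['hello world', 'hi'],  # ranked candidates for hits@k
        'metrics': {'my_custom_metric': 0.5},      # custom user-reported metric
    }
    metrics.update(observation, labels=['hello world'])
    report = metrics.report()  # e.g. {'exs': 1, 'accuracy': 1.0, 'hits@1': 1.0, ...}
    metrics.clear()
    return report
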
class PerplexityWorld(World):
    """
    Instead of just calling act/observe on each agent, this world just calls
    act on the teacher and then calls `next_word_probability` on the agent.

    The label for each example is parsed by the provided tokenizer, and then
    for each word in the parsed label the model is given the input and all of
    the tokens up to the current word and asked to predict the current word.

    The model must return a probability for any words it thinks are likely in
    the form of a dict mapping words to scores. If the scores do not sum to 1,
    they are normalized to do so. If the correct word is not present or has a
    probability of zero, it will be assigned a probability of 1e-8.

    The API of the `next_word_probability` function which agents must
    implement is described in the documentation for this file.
    """

    def __init__(self, opt, agents, shared=None):
        super().__init__(opt)
        if shared:
            # Create agents based on shared data.
            self.task, self.agent, self.dict = create_agents_from_shared(
                shared['agents'])
            self.metrics = shared['metrics']
        else:
            if len(agents) != 3:
                raise RuntimeError('There must be exactly three agents.')
            if opt.get('batchsize', 1) > 1:
                raise RuntimeError('This world only works with bs=1. Try '
                                   'using multiple threads instead, nt>1.')
            self.task, self.agent, self.dict = agents
            if not hasattr(self.agent, 'next_word_probability'):
                raise RuntimeError('Agent must implement function '
                                   '`next_word_probability`.')
            self.metrics = {
                'exs': 0,
                'loss': 0.0,
                'num_tokens': 0,
                'num_unk': 0
            }
            if opt.get('numthreads', 1) > 1:
                self.metrics = SharedTable(self.metrics)
        self.agents = [self.task, self.agent, self.dict]
        self.acts = [None, None]

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def parley(self):
        action = self.task.act()
        self.acts[0] = action.copy()

        # hide labels from model
        labels = action.get('eval_labels', action.get('labels', None))
        if 'label_candidates' in action:
            action.pop('label_candidates')
        if labels is None:
            # empty example, move on
            return

        parsed = self.dict.tokenize(labels[0])
        loss = 0
        num_tokens = 0
        num_unk = 0
        self.agent.observe(action)
        for i in range(len(parsed)):
            if parsed[i] in self.dict:
                # only score words which are in the dictionary
                probs = self.agent.next_word_probability(parsed[:i])
                # get probability of correct answer, divide by total prob mass
                prob_true = probs.get(parsed[i], 0)
                if prob_true > 0:
                    prob_true /= sum(
                        (probs.get(k, 0) for k in self.dict.keys()))
                    loss -= math.log(prob_true)
                else:
                    loss = float('inf')
                num_tokens += 1
            else:
                num_unk += 1
        with self._lock():
            self.metrics['exs'] += 1
            self.metrics['loss'] += loss
            self.metrics['num_tokens'] += num_tokens
            self.metrics['num_unk'] += num_unk

    def epoch_done(self):
        return self.task.epoch_done()

    def num_examples(self):
        return self.task.num_examples()

    def num_episodes(self):
        return self.task.num_episodes()

    def share(self):
        shared = super().share()
        shared['metrics'] = self.metrics
        return shared

    def reset_metrics(self):
        with self._lock():
            self.metrics['exs'] = 0
            self.metrics['loss'] = 0
            self.metrics['num_tokens'] = 0
            self.metrics['num_unk'] = 0

    def report(self):
        m = {}
        with self._lock():
            m['exs'] = self.metrics['exs']
            if m['exs'] > 0:
                # m['num_unk'] = self.metrics['num_unk']
                # m['num_tokens'] = self.metrics['num_tokens']
                m['loss'] = round_sigfigs(
                    self.metrics['loss'] / self.metrics['num_tokens'], 3)
                m['ppl'] = round_sigfigs(
                    math.exp(self.metrics['loss'] /
                             self.metrics['num_tokens']), 4)
        return m
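
# The `next_word_probability` contract that PerplexityWorld relies on is
# easiest to see with a toy agent. The sketch below is illustrative only; the
# class name and candidate words are hypothetical and not part of ParlAI.
# Given the label tokens seen so far, the agent returns a dict mapping
# candidate words to (possibly unnormalized) scores, which parley() then
# normalizes over the dictionary before taking the log-probability of the
# true word.
class _ExampleUniformAgent(object):
    """Toy agent that spreads equal score over a fixed candidate list."""

    def __init__(self):
        self.observation = None

    def observe(self, observation):
        # PerplexityWorld calls observe() once per example with labels hidden.
        self.observation = observation
        return observation

    def next_word_probability(self, partial_out):
        # `partial_out` is the list of gold label tokens preceding the word
        # currently being scored; a real model would condition on both the
        # observed context and this prefix.
        candidates = ['yes', 'no', 'maybe']
        return {w: 1.0 for w in candidates}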