Example #1
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics_list = ['mean_rank', 'loss', 'correct', 'f1', 'ppl']
        if nltkbleu is not None:
            # only compute bleu if we can
            self.metrics_list.append('bleu')
        for k in self.metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        self.metrics['hits@_cnt'] = 0
        self.flags = {
            'has_text_cands': False,
            'print_prediction_metrics': False
        }
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
            self.flags = SharedTable(self.flags)

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        representation = super().__repr__()
        return representation.replace('>', ': {}>'.format(repr(self.metrics)))

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return
        else:
            # Now loop through text candidates, assuming they are sorted.
            # If any of them is a label then score a point.
            # maintain hits@1, 5, 10, 100, etc.
            label_set = set(normalize_answer(l) for l in labels)
            cnts = {k: 0 for k in self.eval_pr}
            cnt = 0
            for c in text_cands:
                cnt += 1
                if normalize_answer(c) in label_set:
                    for k in self.eval_pr:
                        if cnt <= k:
                            cnts[k] += 1
            # hits metric is 1 if cnts[k] > 0.
            # (other metrics such as p@k and r@k take
            # the value of cnt into account.)
            with self._lock():
                self.flags['has_text_cands'] = True
                for k in self.eval_pr:
                    if cnts[k] > 0:
                        self.metrics['hits@' + str(k)] += 1
                self.metrics['hits@_cnt'] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.flags['print_prediction_metrics'] = True
                self.metrics['correct'] += correct
                self.metrics['correct_cnt'] += 1

            # F1 and BLEU metrics.
            f1 = _f1_score(prediction, labels)
            bleu = _bleu(prediction, labels)
            with self._lock():
                self.metrics['f1'] += f1
                self.metrics['f1_cnt'] += 1
                if bleu is not None:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            # no need to lock because not SharedTable
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.metrics_list.append(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['exs'] = total
        if total > 0:
            if self.flags['print_prediction_metrics']:
                m['accuracy'] = round_sigfigs(
                    self.metrics['correct'] /
                    max(1, self.metrics['correct_cnt']), 4)
                m['f1'] = round_sigfigs(
                    self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
            if self.flags['has_text_cands']:
                for k in self.eval_pr:
                    m['hits@' + str(k)] = round_sigfigs(
                        self.metrics['hits@' + str(k)] /
                        max(1, self.metrics['hits@_cnt']), 3)
            for k in self.metrics_list:
                if (self.metrics[k + '_cnt'] > 0
                        and k != 'correct' and k != 'f1'):
                    m[k] = round_sigfigs(
                        self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            for k in self.metrics_list:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            self.metrics['hits@_cnt'] = 0
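
A minimal usage sketch for the Metrics class above, assuming the ParlAI helpers it references (SharedTable, no_lock, normalize_answer, _exact_match, _f1_score, _bleu, round_sigfigs) are available in the module; the opt keys and observation fields are taken from the code above, but the values shown are illustrative only.

# Hypothetical single-threaded usage: 'numthreads' stays at 1, so plain dicts
# (not SharedTable) back the metrics and no real lock is taken.
opt = {'numthreads': 1}
metrics = Metrics(opt)

observation = {
    'text': 'paris',                       # model prediction
    'text_candidates': ['paris', 'rome'],  # ranked candidates for hits@k
}
labels = ['Paris']

per_example = metrics.update(observation, labels)
print(per_example)       # e.g. {'correct': 1}
print(metrics.report())  # accumulated accuracy, f1, hits@k, ...
metrics.clear()          # reset counters between validation passes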
Example #2
class PerplexityWorld(World):
    """Instead of just calling act/observe on each agent, this world just calls
    act on the teacher and then calls `next_word_probability` on the agent.

    The label for each example is parsed by the provided tokenizer, and then
    for each word in the parsed label the model is given the input and all of
    the tokens up to the current word and asked to predict the current word.

    The model must return a probability of any words it thinks are likely in
    the form of a dict mapping words to scores. If the scores do not sum to 1,
    they are normalized to do so. If the correct word is not present or has a
    probability of zero, the loss for that example is treated as infinite.

    The API of the next_word_probability function which agents must implement
    is mentioned in the documentation for this file.
    """
    def __init__(self, opt, agents, shared=None):
        super().__init__(opt)
        if shared:
            # Create agents based on shared data.
            self.task, self.agent, self.dict = create_agents_from_shared(
                shared['agents']
            )
            self.metrics = shared['metrics']
        else:
            if len(agents) != 3:
                raise RuntimeError('There must be exactly three agents.')
            if opt.get('batchsize', 1) > 1:
                raise RuntimeError('This world only works with bs=1. Try '
                                   'using multiple threads instead, nt>1.')
            self.task, self.agent, self.dict = agents
            if not hasattr(self.agent, 'next_word_probability'):
                raise RuntimeError('Agent must implement function '
                                   '`next_word_probability`.')
            self.metrics = {'exs': 0, 'loss': 0.0, 'num_tokens': 0, 'num_unk': 0}
            if opt.get('numthreads', 1) > 1:
                self.metrics = SharedTable(self.metrics)
        self.agents = [self.task, self.agent, self.dict]
        self.acts = [None, None]

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def parley(self):
        action = self.task.act()
        self.acts[0] = action.copy()

        # hide labels from model
        labels = action.get('eval_labels', action.get('labels', None))
        if 'label_candidates' in action:
            action.pop('label_candidates')
        if labels is None:
            # empty example, move on
            return

        parsed = self.dict.tokenize(labels[0])
        loss = 0
        num_tokens = 0
        num_unk = 0
        self.agent.observe(action)
        for i in range(len(parsed)):
            if parsed[i] in self.dict:
                # only score words which are in the dictionary
                probs = self.agent.next_word_probability(parsed[:i])
                # get probability of correct answer, divide by total prob mass
                prob_true = probs.get(parsed[i], 0)
                if prob_true > 0:
                    prob_true /= sum(probs.get(k, 0) for k in self.dict.keys())
                    loss -= math.log(prob_true)
                else:
                    loss = float('inf')
                num_tokens += 1
            else:
                num_unk += 1
        with self._lock():
            self.metrics['exs'] += 1
            self.metrics['loss'] += loss
            self.metrics['num_tokens'] += num_tokens
            self.metrics['num_unk'] += num_unk

    def epoch_done(self):
        return self.task.epoch_done()

    def num_examples(self):
        return self.task.num_examples()

    def num_episodes(self):
        return self.task.num_episodes()

    def share(self):
        shared = super().share()
        shared['metrics'] = self.metrics
        return shared

    def reset_metrics(self):
        with self._lock():
            self.metrics['exs'] = 0
            self.metrics['loss'] = 0
            self.metrics['num_tokens'] = 0
            self.metrics['num_unk'] = 0

    def report(self):
        m = {}
        with self._lock():
            m['exs'] = self.metrics['exs']
            if m['exs'] > 0:
                # m['num_unk'] = self.metrics['num_unk']
                # m['num_tokens'] = self.metrics['num_tokens']
                m['loss'] = round_sigfigs(
                    self.metrics['loss'] / self.metrics['num_tokens'],
                    3
                )
                m['ppl'] = round_sigfigs(
                    math.exp(self.metrics['loss'] / self.metrics['num_tokens']),
                    4
                )
        return m
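
The `next_word_probability` contract described in the PerplexityWorld docstring can be satisfied by something as simple as the following hypothetical agent; the class name, constructor argument, and uniform scoring are assumptions for illustration, and the world renormalizes whatever scores are returned over its dictionary.

class UniformProbabilityAgent(object):
    """Toy agent that scores every known word equally, ignoring context.

    PerplexityWorld only needs observe() and next_word_probability(); the
    returned dict maps candidate words to (possibly unnormalized) scores.
    """

    def __init__(self, dictionary):
        self.dictionary = dictionary
        self.observation = None

    def observe(self, observation):
        # The world passes the teacher's action (with labels stripped).
        self.observation = observation

    def next_word_probability(self, partial_out):
        # partial_out is the list of gold tokens preceding the target word.
        return {word: 1.0 for word in self.dictionary.keys()}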
Example #3
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics_list = set()
        optional_metrics_list = []
        metrics_arg = opt.get('metrics', 'default')
        if metrics_arg == 'default':
            optional_metrics_list = DEFAULT_METRICS
        elif metrics_arg == 'all':
            optional_metrics_list = ALL_METRICS
        else:
            optional_metrics_list = set(metrics_arg.split(','))
            optional_metrics_list.add('correct')
        for each_m in optional_metrics_list:
            if each_m.startswith('rouge') and rouge is not None:
                self.metrics_list.add('rouge')
            elif each_m == 'bleu' and nltkbleu is None:
                # only compute bleu if we can
                pass
            else:
                self.metrics_list.add(each_m)
        metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                        else self.metrics_list | ROUGE_METRICS)
        for k in metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        self.metrics['hits@_cnt'] = 0
        self.flags = {
            'has_text_cands': False,
            'print_prediction_metrics': False
        }
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
            self.flags = SharedTable(self.flags)

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        representation = super().__repr__()
        return representation.replace('>', ': {}>'.format(repr(self.metrics)))

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def _update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return
        else:
            # Now loop through text candidates, assuming they are sorted.
            # If any of them is a label then score a point.
            # maintain hits@1, 5, 10, 100, etc.
            label_set = set(normalize_answer(l) for l in labels)
            cnts = {k: 0 for k in self.eval_pr}
            cnt = 0
            for c in text_cands:
                cnt += 1
                if normalize_answer(c) in label_set:
                    for k in self.eval_pr:
                        if cnt <= k:
                            cnts[k] += 1
            # hits metric is 1 if cnts[k] > 0.
            # (other metrics such as p@k and r@k take
            # the value of cnt into account.)
            with self._lock():
                self.flags['has_text_cands'] = True
                for k in self.eval_pr:
                    if cnts[k] > 0:
                        self.metrics['hits@' + str(k)] += 1
                self.metrics['hits@_cnt'] += 1

    def update(self, observation, labels):
        """Update metrics based on an observation and true labels."""
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.flags['print_prediction_metrics'] = True
                self.metrics['correct'] += correct
                self.metrics['correct_cnt'] += 1

            # F1 and BLEU metrics.
            if 'f1' in self.metrics_list:
                f1 = _f1_score(prediction, labels)
            if 'bleu' in self.metrics_list:
                bleu = _bleu(prediction, labels)
            if 'rouge' in self.metrics_list:
                rouge1, rouge2, rougeL = _rouge(prediction, labels)

            with self._lock():
                if 'f1' in self.metrics:
                    self.metrics['f1'] += f1
                    self.metrics['f1_cnt'] += 1
                if 'bleu' in self.metrics:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1
                if 'rouge-L' in self.metrics:
                    self.metrics['rouge-1'] += rouge1
                    self.metrics['rouge-1_cnt'] += 1
                    self.metrics['rouge-2'] += rouge2
                    self.metrics['rouge-2_cnt'] += 1
                    self.metrics['rouge-L'] += rougeL
                    self.metrics['rouge-L_cnt'] += 1

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ALL_METRICS and k != 'rouge':
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            # no need to lock because not SharedTable
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.metrics_list.add(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        """Report the metrics over all data seen so far."""
        m = {}
        total = self.metrics['cnt']
        m['exs'] = total
        if total > 0:
            if self.flags['print_prediction_metrics']:
                if 'accuracy' in self.metrics_list:
                    m['accuracy'] = round_sigfigs(
                        self.metrics['correct'] /
                        max(1, self.metrics['correct_cnt']), 4)
                if 'f1' in self.metrics_list:
                    m['f1'] = round_sigfigs(
                        self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
            if self.flags['has_text_cands']:
                for k in self.eval_pr:
                    m['hits@' + str(k)] = round_sigfigs(
                        self.metrics['hits@' + str(k)] /
                        max(1, self.metrics['hits@_cnt']),
                        3,
                    )
            for k in self.metrics_list:
                if (self.metrics[k + '_cnt'] > 0
                        and k != 'correct' and k != 'f1'):
                    m[k] = round_sigfigs(
                        self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
        return m

    def clear(self):
        """Clear all the metrics."""
        # TODO: rename to reset for consistency with rest of ParlAI
        with self._lock():
            self.metrics['cnt'] = 0
            metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                            else self.metrics_list | ROUGE_METRICS)
            for k in metrics_list:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                elif isinstance(v, int):
                    self.metrics[k] = 0
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            self.metrics['hits@_cnt'] = 0
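
A short configuration sketch for the variant above: opt['metrics'] selects which metrics are tracked. The module-level constants (DEFAULT_METRICS, ALL_METRICS, ROUGE_METRICS) are referenced but not shown in this snippet, so the comments below only describe the observable behavior of __init__.

opt_default = {'metrics': 'default', 'numthreads': 1}  # tracks DEFAULT_METRICS
opt_all = {'metrics': 'all', 'numthreads': 1}          # ALL_METRICS (bleu/rouge only if the libs are installed)
opt_custom = {'metrics': 'f1,ppl', 'numthreads': 1}    # comma-separated list; 'correct' is always added

m = Metrics(opt_custom)
print(sorted(m.metrics_list))  # expected: ['correct', 'f1', 'ppl']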
Example #4
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""

    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.eval_pr = [1, 5, 10, 50, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
        self.datatype = opt.get('datatype', 'train')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return self

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 50, 100,  etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if _exact_match(prediction, labels):
            correct = 1
        with self._lock():
            self.metrics['correct'] += correct

        # F1 metric.
        f1 = _f1_score(prediction, labels)
        with self._lock():
            self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        m['total'] = self.metrics['cnt']
        if self.metrics['cnt'] > 0:
            m['accuracy'] = round_sigfigs(
                self.metrics['correct'] / self.metrics['cnt'], 4)
            m['f1'] = round_sigfigs(
                self.metrics['f1'] / self.metrics['cnt'], 4)
            m['hits@k'] = {}
            for k in self.eval_pr:
                m['hits@k'][k] = round_sigfigs(
                    self.metrics['hits@' + str(k)] / self.metrics['cnt'], 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
Example #5
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
        self.datatype = opt.get('datatype', 'train')
        self.custom_keys = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return self

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 100, etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if _exact_match(prediction, labels):
            correct = 1
        with self._lock():
            self.metrics['correct'] += correct

        # F1 metric.
        f1 = _f1_score(prediction, labels)
        with self._lock():
            self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k']:
                    with self._lock():
                        if k not in self.metrics:
                            self.custom_keys.append(k)
                            self.metrics[k] = v
                        else:
                            self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['total'] = total
        if total > 0:
            m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
            m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
            m['hits@k'] = {}
            for k in self.eval_pr:
                m['hits@k'][k] = round_sigfigs(
                    self.metrics['hits@' + str(k)] / total, 3)
            for k in self.custom_keys:
                if k in self.metrics:
                    m[k] = round_sigfigs(self.metrics[k] / total, 3)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            for k in self.custom_keys:
                self.metrics.pop(k, None)  # safer than resetting to zero
Example #6
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""

    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.custom_metrics = ['mean_rank', 'loss', 'lmloss', 'ppl']
        for k in self.custom_metrics:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)

        self.datatype = opt.get('datatype', 'train')
        self.print_prediction_metrics = False

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 100, etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            self.print_prediction_metrics = True
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.metrics['correct'] += correct

            # F1 metric.
            f1 = _f1_score(prediction, labels)
            with self._lock():
                self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k']:
                    if k in self.custom_metrics:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.custom_metrics.append(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data are stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['total'] = total
        if total > 0:
            if self.print_prediction_metrics:
                m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
                m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
                m['hits@k'] = {}
                for k in self.eval_pr:
                    m['hits@k'][k] = round_sigfigs(
                        self.metrics['hits@' + str(k)] / total, 3)
            for k in self.custom_metrics:
                if self.metrics[k + '_cnt'] > 0:
                    m[k] = round_sigfigs(self.metrics[k] / self.metrics[k + '_cnt'], 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.custom_metrics:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
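
A hypothetical example of the user-reported metrics path in this last variant: an observation may carry its own 'metrics' dict (the keys below are illustrative), update() accumulates each value alongside a '<key>_cnt' counter, and report() averages by that counter. Under hogwild (numthreads > 1, SharedTable), unknown custom keys are silently skipped, as the code notes.

metrics = Metrics({'numthreads': 1, 'datatype': 'valid'})
observation = {
    'text': 'a model reply',
    'metrics': {'loss': 2.3, 'my_custom_score': 0.5},  # supplied by the agent/teacher
}
metrics.update(observation, labels=['a gold reply'])
report = metrics.report()
# 'loss' is a known custom metric; 'my_custom_score' was registered on the fly.
print(report.get('loss'), report.get('my_custom_score'))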