Example #1
 def __init__(self, opt, agents, shared=None):
     super().__init__(opt)
     if shared:
         # Create agents based on shared data.
         self.task, self.agent, self.dict = create_agents_from_shared(
             shared['agents'])
         self.metrics = shared['metrics']
     else:
         if len(agents) != 3:
             raise RuntimeError('There must be exactly three agents.')
         if opt.get('batchsize', 1) > 1:
             raise RuntimeError('This world only works with bs=1. Try '
                                'using multiple threads instead, nt>1.')
         self.task, self.agent, self.dict = agents
         if not hasattr(self.agent, 'next_word_probability'):
             raise RuntimeError('Agent must implement function '
                                '`next_word_probability`.')
         self.metrics = {
             'exs': 0,
             'loss': 0.0,
             'num_tokens': 0,
             'num_unk': 0
         }
         if opt.get('numthreads', 1) > 1:
             self.metrics = SharedTable(self.metrics)
     self.agents = [self.task, self.agent, self.dict]
     self.acts = [None, None]
Example #2
    def test_get_set_del(self):
        st = SharedTable()
        try:
            st['key']
            assert False, 'did not fail on nonexistent key'
        except KeyError:
            pass

        st['key'] = 1
        assert st['key'] == 1

        st['key'] += 1
        assert st['key'] == 2

        try:
            st['key'] = 2.1
            assert False, 'cannot change type of value for set keys'
        except TypeError:
            pass

        del st['key']
        assert 'key' not in st, 'key should have been removed from table'

        st['key'] = 'hello'
        assert st['key'] == 'hello'

        st['key'] += ' world'
        assert st['key'] == 'hello world'

        st['ctr'] = 0
        keyset1 = set(iter(st))
        keyset2 = set(st.keys())
        assert keyset1 == keyset2, 'iterating should return keys'
Example #3
    def test_torch(self):
        try:
            import torch
        except ImportError:
            # pass by default if no torch available
            return

        st = SharedTable({'a': torch.FloatTensor([1]), 'b': torch.LongTensor(2)})
        assert st['a'][0] == 1.0
        assert len(st) == 2
        assert 'b' in st
        del st['b']
        assert 'b' not in st
        assert len(st) == 1

        if torch.cuda.is_available():
            st = SharedTable(
                {'a': torch.cuda.FloatTensor([1]), 'b': torch.cuda.LongTensor(2)}
            )
            assert st['a'][0] == 1.0
            assert len(st) == 2
            assert 'b' in st
            del st['b']
            assert 'b' not in st
            assert len(st) == 1
Example #4
File: metrics.py  Project: kvthr/ParlAI
 def __init__(self, opt):
     self.metrics = {}
     self.metrics['cnt'] = 0
     self.metrics_list = ['mean_rank', 'loss', 'correct', 'f1', 'ppl']
     if nltkbleu is not None:
         # only compute bleu if we can
         self.metrics_list.append('bleu')
     if rouge is not None:
         # only compute rouge if we can
         self.metrics_list.append('rouge-1')
         self.metrics_list.append('rouge-2')
         self.metrics_list.append('rouge-L')
     for k in self.metrics_list:
         self.metrics[k] = 0.0
         self.metrics[k + '_cnt'] = 0
     self.eval_pr = [1, 5, 10, 100]
     for k in self.eval_pr:
         self.metrics['hits@' + str(k)] = 0
     self.metrics['hits@_cnt'] = 0
     self.flags = {
         'has_text_cands': False,
         'print_prediction_metrics': False
     }
     if opt.get('numthreads', 1) > 1:
         self.metrics = SharedTable(self.metrics)
         self.flags = SharedTable(self.flags)
Example #5
    def test_get_set_del(self):
        st = SharedTable()
        try:
            st['key']
            assert False, 'did not fail on nonexistent key'
        except KeyError:
            pass

        st['key'] = 1
        assert st['key'] == 1

        st['key'] += 1
        assert st['key'] == 2

        try:
            st['key'] = 2.1
            assert False, 'cannot change type of value for set keys'
        except TypeError:
            pass

        del st['key']
        assert 'key' not in st, 'key should have been removed from table'

        st['key'] = 'hello'
        assert st['key'] == 'hello'

        st['key'] += ' world'
        assert st['key'] == 'hello world'

        st['ctr'] = 0
        keyset1 = set(iter(st))
        keyset2 = set(st.keys())
        assert keyset1 == keyset2, 'iterating should return keys'
Example #6
 def test_iter_keys(self):
     st = SharedTable({'key': 0, 'ctr': 0.0, 'val': False, 'other': 1})
     assert len(st) == 4
     del st['key']
     assert len(st) == 3, 'length should decrease after deleting key'
     keyset1 = set(iter(st))
     keyset2 = set(st.keys())
     assert keyset1 == keyset2, 'iterating should return keys'
     assert len(keyset1) == 3, 'should have 3 keys after deleting one'
Example #7
 def test_iter_keys(self):
     st = SharedTable({'key': 0, 'ctr': 0.0, 'val': False, 'other': 1})
     assert len(st) == 4
     del st['key']
     assert len(st) == 3, 'length should decrease after deleting key'
     keyset1 = set(iter(st))
     keyset2 = set(st.keys())
     assert keyset1 == keyset2, 'iterating should return keys'
     assert len(keyset1) == 3, 'should have 3 keys after deleting one'
Example #8
File: metrics.py  Project: wodole/ParlAI
 def __init__(self, opt):
     self.metrics = {}
     self.metrics['cnt'] = 0
     self.metrics['correct'] = 0
     self.metrics['f1'] = 0.0
     self.eval_pr = [1, 5, 10, 50, 100]
     for k in self.eval_pr:
         self.metrics['hits@' + str(k)] = 0
     if opt.get('numthreads', 1) > 1:
         self.metrics = SharedTable(self.metrics)
     self.datatype = opt.get('datatype', 'train')
Example #9
File: metrics.py  Project: zhaimq/ParlAI
 def __init__(self, opt):
     self.metrics = {}
     self.metrics['cnt'] = 0
     self.metrics_list = ['mean_rank', 'loss', 'correct', 'f1', 'ppl']
     for k in self.metrics_list:
         self.metrics[k] = 0.0
         self.metrics[k + '_cnt'] = 0
     self.eval_pr = [1, 5, 10, 100]
     for k in self.eval_pr:
         self.metrics['hits@' + str(k)] = 0
     self.metrics['hits@_cnt'] = 0
     self.flags = {'has_text_cands': False, 'print_prediction_metrics': False}
     if opt.get('numthreads', 1) > 1:
         self.metrics = SharedTable(self.metrics)
         self.flags = SharedTable(self.flags)
Example #10
    def test_get_set_del(self):
        st = SharedTable({'key': 0})
        try:
            st['none']
            self.fail('did not fail on nonexistent key')
        except KeyError:
            pass

        st['key'] = 1
        assert st['key'] == 1

        st['key'] += 1
        assert st['key'] == 2

        try:
            st['key'] = 2.1
            self.fail('cannot change type of value for set keys')
        except TypeError:
            pass

        del st['key']
        assert 'key' not in st, 'key should have been removed from table'

        try:
            st['key'] = True
            self.fail('cannot change removed key')
        except KeyError:
            pass
Example #11
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.custom_metrics = ['mean_rank', 'loss']
        for k in self.custom_metrics:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)

        self.custom_keys = []
        self.datatype = opt.get('datatype', 'train')
Example #12
File: metrics.py  Project: jojonki/ParlAI
 def __init__(self, opt):
     self.metrics = {}
     self.metrics['cnt'] = 0
     self.metrics['correct'] = 0
     self.metrics['f1'] = 0.0
     self.eval_pr = [1, 5, 10, 50, 100]
     for k in self.eval_pr:
         self.metrics['hits@' + str(k)] = 0
     if opt.get('numthreads', 1) > 1:
         self.metrics = SharedTable(self.metrics)
     self.datatype = opt.get('datatype', 'train')
Example #13
 def share(self):
     """Share model parameters."""
     shared = super().share()
     shared['model'] = self.model
     if self.opt.get('numthreads', 1) > 1 and isinstance(self.metrics, dict):
         torch.set_num_threads(1)
         # move metrics and model to shared memory
         self.metrics = SharedTable(self.metrics)
         self.model.share_memory()
     shared['metrics'] = self.metrics
     return shared
Example #14
 def test_init_from_dict(self):
     d = {
         'a': 0,
         'b': 1,
         'c': 1.0,
         'd': True,
         1: False,
         2: 2.0
     }
     st = SharedTable(d)
     for k, v in d.items():
         assert st[k] == v
Example #15
 def __init__(self, opt):
     self.metrics = {}
     self.metrics['cnt'] = 0
     self.metrics_list = set()
     optional_metrics_list = []
     metrics_arg = opt.get('metrics', 'default')
     if metrics_arg == 'default':
         optional_metrics_list = DEFAULT_METRICS
     elif metrics_arg == 'all':
         optional_metrics_list = ALL_METRICS
     else:
         optional_metrics_list = set(metrics_arg.split(','))
         optional_metrics_list.add('correct')
     for each_m in optional_metrics_list:
         if each_m.startswith('rouge'):
             if rouge is not None:
                 # only compute rouge if rouge is available
                 self.metrics_list.add('rouge')
         elif each_m == 'bleu' and nltkbleu is None:
             # only compute bleu if bleu is available
             pass
         else:
             self.metrics_list.add(each_m)
     metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                     else self.metrics_list | ROUGE_METRICS)
     for k in metrics_list:
         self.metrics[k] = 0.0
         self.metrics[k + '_cnt'] = 0
     self.eval_pr = [1, 5, 10, 100]
     for k in self.eval_pr:
         self.metrics['hits@' + str(k)] = 0
     self.metrics['hits@_cnt'] = 0
     self.flags = {
         'has_text_cands': False,
         'print_prediction_metrics': False
     }
     if opt.get('numthreads', 1) > 1:
         self.metrics = SharedTable(self.metrics)
         self.flags = SharedTable(self.flags)
Example #16
 def share(self):
     """Share internal states between parent and child instances."""
     shared = super().share()
     shared['model'] = self.model
     if self.opt.get('numthreads', 1) > 1:
         # we're doing hogwild so share the model too
         if type(self.metrics) == dict:
             # move metrics and model to shared memory
             self.metrics = SharedTable(self.metrics)
             self.model.share_memory()
         shared['states'] = {  # don't share optimizer states
             'optimizer_type': self.opt['optimizer'],
         }
     shared['metrics'] = self.metrics  # do after numthreads check
     return shared
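All of these share() methods follow the same hogwild handoff: when numthreads > 1 the parent promotes its plain metrics dict to a SharedTable once, and every child instance then receives that same table through the shared dict and updates it under its lock. Below is a minimal, self-contained sketch of that pattern, not ParlAI's actual agent classes: ToyAgent, run_worker and the opt values are hypothetical, the SharedTable import path is assumed to be parlai.core.thread_utils, and (like the concurrency test further down) it assumes the fork start method.

from multiprocessing import Process

from parlai.core.thread_utils import SharedTable  # assumed import path


class ToyAgent(object):
    """Hypothetical agent showing the parent/child metrics handoff."""

    def __init__(self, opt, shared=None):
        self.opt = opt
        if shared:
            # child instance: reuse the parent's shared metrics table
            self.metrics = shared['metrics']
        else:
            # parent instance: build metrics, promote them to shared memory
            self.metrics = {'cnt': 0}
            if opt.get('numthreads', 1) > 1:
                self.metrics = SharedTable(self.metrics)

    def share(self):
        # mirror the share() methods above: hand the (Shared)Table to children
        return {'metrics': self.metrics}

    def count_one(self):
        if hasattr(self.metrics, 'get_lock'):
            with self.metrics.get_lock():
                self.metrics['cnt'] += 1
        else:
            self.metrics['cnt'] += 1


def run_worker(shared, n):
    child = ToyAgent({'numthreads': 2}, shared)
    for _ in range(n):
        child.count_one()


if __name__ == '__main__':
    parent = ToyAgent({'numthreads': 2})
    shared = parent.share()
    workers = [Process(target=run_worker, args=(shared, 100)) for _ in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    # both workers incremented the same shared counter
    print(parent.metrics['cnt'])  # expected: 200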
Example #17
 def share(self):
     """Share model parameters."""
     shared = super().share()
     shared['model'] = self.model
     if self.opt.get('numthreads', 1) > 1 and isinstance(self.metrics, dict):
         torch.set_num_threads(1)
         # move metrics and model to shared memory
         self.metrics = SharedTable(self.metrics)
         self.model.share_memory()
     shared['metrics'] = self.metrics
     shared['fixed_candidates'] = self.fixed_candidates
     shared['fixed_candidate_vecs'] = self.fixed_candidate_vecs
     shared['vocab_candidates'] = self.vocab_candidates
     shared['vocab_candidate_vecs'] = self.vocab_candidate_vecs
     shared['optimizer'] = self.optimizer
     return shared
Example #18
File: metrics.py  Project: ahiroto/ParlAI
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.custom_metrics = ['mean_rank', 'loss', 'lmloss', 'ppl']
        for k in self.custom_metrics:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)

        self.datatype = opt.get('datatype', 'train')
        self.print_prediction_metrics = False
Example #19
 def share(self):
     """Share internal states between parent and child instances."""
     shared = super().share()
     shared['criterion'] = self.criterion
     if self.opt.get('numthreads', 1) > 1:
         # we're doing hogwild so share the model too
         if isinstance(self.metrics, dict):
             # move metrics and model to shared memory
             self.metrics = SharedTable(self.metrics)
             self.model.share_memory()
         shared['states'] = {  # don't share optimizer states
             'optimizer_type': self.opt['optimizer']
         }
     shared['metrics'] = self.metrics  # do after numthreads check
     if self.beam_dot_log is True:
         shared['beam_dot_dir'] = self.beam_dot_dir
     return shared
Example #20
    def test_concurrent_access(self):
        st = SharedTable({'cnt': 0})

        def inc():
            for _ in range(50):
                with st.get_lock():
                    st['cnt'] += 1
                time.sleep(random.randint(1, 5) / 10000)

        threads = []
        for _ in range(5):  # numthreads
            threads.append(Process(target=inc))
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert st['cnt'] == 250
Example #21
 def share(self):
     """Share internal states between parent and child instances."""
     shared = super().share()
     shared['opt'] = self.opt
     shared['dict'] = self.dict
     shared['NULL_IDX'] = self.NULL_IDX
     shared['END_IDX'] = self.END_IDX
     shared['model'] = self.model
     if self.opt.get('numthreads', 1) > 1:
         if type(self.metrics) == dict:
             # move metrics and model to shared memory
             self.metrics = SharedTable(self.metrics)
             self.model.share_memory()
         shared['states'] = {  # only need to pass optimizer states
             'optimizer': self.optimizer.state_dict(),
         }
     shared['metrics'] = self.metrics
     return shared
Example #22
 def share(self):
     """Share internal states between parent and child instances."""
     shared = super().share()
     shared['opt'] = self.opt
     shared['answers'] = self.answers
     shared['dict'] = self.dict
     shared['START_IDX'] = self.START_IDX
     shared['END_IDX'] = self.END_IDX
     shared['NULL_IDX'] = self.NULL_IDX
     if self.opt.get('numthreads', 1) > 1:
         if type(self.metrics) == dict:
             self.metrics = SharedTable(self.metrics)
             self.model.share_memory()
         shared['metrics'] = self.metrics
         shared['model'] = self.model
         shared['states'] = { # only need to pass optimizer states
             'optimizer': self.optimizer.state_dict(),
             'optimizer_type': self.opt['optimizer'],
         }
     return shared
Example #23
 def share(self):
     """Share internal states between parent and child instances."""
     shared = super().share()
     shared['opt'] = self.opt
     shared['answers'] = self.answers
     shared['dict'] = self.dict
     shared['START_IDX'] = self.START_IDX
     shared['END_IDX'] = self.END_IDX
     shared['NULL_IDX'] = self.NULL_IDX
     if self.opt.get('numthreads', 1) > 1:
         # we're doing hogwild so share the model too
         if type(self.metrics) == dict:
             # move metrics and model to shared memory
             self.metrics = SharedTable(self.metrics)
             self.model.share_memory()
         shared['model'] = self.model
         shared['metrics'] = self.metrics
         shared['states'] = {  # don't share optimizer states
             'optimizer_type': self.opt['optimizer'],
         }
     return shared
Example #24
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics_list = set()
        optional_metrics_list = []
        metrics_arg = opt.get('metrics', 'default')
        if metrics_arg == 'default':
            optional_metrics_list = DEFAULT_METRICS
        elif metrics_arg == 'all':
            optional_metrics_list = ALL_METRICS
        else:
            optional_metrics_list = set(metrics_arg.split(','))
            optional_metrics_list.add('correct')
        for each_m in optional_metrics_list:
            if each_m.startswith('rouge') and rouge is not None:
                self.metrics_list.add('rouge')
            elif each_m == 'bleu' and nltkbleu is None:
                # only compute bleu if we can
                pass
            else:
                self.metrics_list.add(each_m)
        metrics_list = (self.metrics_list if 'rouge' not in self.metrics_list
                        else self.metrics_list | ROUGE_METRICS)
        for k in metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        self.metrics['hits@_cnt'] = 0
        self.flags = {
            'has_text_cands': False,
            'print_prediction_metrics': False
        }
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
            self.flags = SharedTable(self.flags)

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        representation = super().__repr__()
        return representation.replace('>', ': {}>'.format(repr(self.metrics)))

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def _update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return
        else:
            # Now loop through text candidates, assuming they are sorted.
            # If any of them is a label then score a point.
            # maintain hits@1, 5, 10, 50, 100,  etc.
            label_set = set(normalize_answer(l) for l in labels)
            cnts = {k: 0 for k in self.eval_pr}
            cnt = 0
            for c in text_cands:
                cnt += 1
                if normalize_answer(c) in label_set:
                    for k in self.eval_pr:
                        if cnt <= k:
                            cnts[k] += 1
            # hits metric is 1 if cnts[k] > 0.
            # (other metrics such as p@k and r@k take
            # the value of cnt into account.)
            with self._lock():
                self.flags['has_text_cands'] = True
                for k in self.eval_pr:
                    if cnts[k] > 0:
                        self.metrics['hits@' + str(k)] += 1
                self.metrics['hits@_cnt'] += 1

    def update(self, observation, labels):
        """Update metrics based on an observation and true labels."""
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.flags['print_prediction_metrics'] = True
                self.metrics['correct'] += correct
                self.metrics['correct_cnt'] += 1

            # F1 and BLEU metrics.
            if 'f1' in self.metrics_list:
                f1 = _f1_score(prediction, labels)
            if 'bleu' in self.metrics_list:
                bleu = _bleu(prediction, labels)
            if 'rouge' in self.metrics_list:
                rouge1, rouge2, rougeL = _rouge(prediction, labels)

            with self._lock():
                if 'f1' in self.metrics:
                    self.metrics['f1'] += f1
                    self.metrics['f1_cnt'] += 1
                if 'bleu' in self.metrics:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1
                if 'rouge-L' in self.metrics:
                    self.metrics['rouge-1'] += rouge1
                    self.metrics['rouge-1_cnt'] += 1
                    self.metrics['rouge-2'] += rouge2
                    self.metrics['rouge-2_cnt'] += 1
                    self.metrics['rouge-L'] += rougeL
                    self.metrics['rouge-L_cnt'] += 1

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ALL_METRICS and k != 'rouge':
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            # no need to lock because not SharedTable
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.metrics_list.add(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data is stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        """Report the metrics over all data seen so far."""
        m = {}
        total = self.metrics['cnt']
        m['exs'] = total
        if total > 0:
            if self.flags['print_prediction_metrics']:
                if 'accuracy' in self.metrics_list:
                    m['accuracy'] = round_sigfigs(
                        self.metrics['correct'] /
                        max(1, self.metrics['correct_cnt']), 4)
                if 'f1' in self.metrics_list:
                    m['f1'] = round_sigfigs(
                        self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
            if self.flags['has_text_cands']:
                for k in self.eval_pr:
                    m['hits@' + str(k)] = round_sigfigs(
                        self.metrics['hits@' + str(k)] /
                        max(1, self.metrics['hits@_cnt']),
                        3,
                    )
            for k in self.metrics_list:
                if (self.metrics[k + '_cnt'] > 0
                        and k != 'correct' and k != 'f1'):
                    m[k] = round_sigfigs(
                        self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
        return m

    def clear(self):
        """Clear all the metrics."""
        # TODO: rename to reset for consistency with rest of ParlAI
        with self._lock():
            self.metrics['cnt'] = 0
            metrics_list = (self.metrics_list
                            if 'rouge' not in self.metrics_list
                            else self.metrics_list | ROUGE_METRICS)
            for k in metrics_list:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                elif isinstance(v, int):
                    self.metrics[k] = 0
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            self.metrics['hits@_cnt'] = 0
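As the comments inside update() note, each call scores one observation against its labels and returns only that example's result; the running totals stay inside the class and come back from report(). A hedged usage sketch of that flow follows; the import path, opt values, and observation contents are assumptions based on the snippets in this listing.

from parlai.core.metrics import Metrics  # assumed location of the class above

opt = {'numthreads': 1, 'metrics': 'default', 'datatype': 'valid'}
metrics = Metrics(opt)

# one model response plus its ranked candidates, scored against the gold labels
observation = {
    'text': 'paris',
    'text_candidates': ['paris', 'rome', 'london'],
}
labels = ['Paris']

per_example = metrics.update(observation, labels)
print(per_example)       # per-example result, e.g. {'correct': 1}
print(metrics.report())  # aggregates: exs, accuracy, f1, hits@k, ...
metrics.clear()          # reset all counters before the next evaluation run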
Example #25
File: metrics.py  Project: ahiroto/ParlAI
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""

    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.custom_metrics = ['mean_rank', 'loss', 'lmloss', 'ppl']
        for k in self.custom_metrics:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)

        self.datatype = opt.get('datatype', 'train')
        self.print_prediction_metrics = False

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 50, 100,  etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            self.print_prediction_metrics = True
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.metrics['correct'] += correct

            # F1 metric.
            f1 = _f1_score(prediction, labels)
            with self._lock():
                self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k']:
                    if k in self.custom_metrics:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.custom_metrics.append(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data is stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['total'] = total
        if total > 0:
            if self.print_prediction_metrics:
                m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
                m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
                m['hits@k'] = {}
                for k in self.eval_pr:
                    m['hits@k'][k] = round_sigfigs(
                        self.metrics['hits@' + str(k)] / total, 3)
            for k in self.custom_metrics:
                if self.metrics[k + '_cnt'] > 0:
                    m[k] = round_sigfigs(self.metrics[k] / self.metrics[k + '_cnt'], 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.custom_metrics:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
Example #26
 def test_init_from_dict(self):
     d = {'a': 0, 'b': 1, 'c': 1.0, 'd': 'hello', 1: 'world', 2: 2.0}
     st = SharedTable(d)
     for k, v in d.items():
         assert st[k] == v
Example #27
class PerplexityWorld(World):
    """Instead of just calling act/observe on each agent, this world just calls
    act on the teacher and then calls `next_word_probability` on the agent.

    The label for each example is parsed by the provided tokenizer, and then
    for each word in the parsed label the model is given the input and all of
    the tokens up to the current word and asked to predict the current word.

    The model must return a probability of any words it thinks are likely in
    the form of a dict mapping words to scores. If the scores do not sum to 1,
    they are normalized to do so. If the correct word is not present or has a
    probability of zero, it will be assigned a probability of 1e-8.

    The API of the next_word_probability function which agents must implement
    is mentioned in the documentation for this file.
    """
    def __init__(self, opt, agents, shared=None):
        super().__init__(opt)
        if shared:
            # Create agents based on shared data.
            self.task, self.agent, self.dict = create_agents_from_shared(
                shared['agents']
            )
            self.metrics = shared['metrics']
        else:
            if len(agents) != 3:
                raise RuntimeError('There must be exactly three agents.')
            if opt.get('batchsize', 1) > 1:
                raise RuntimeError('This world only works with bs=1. Try '
                                   'using multiple threads instead, nt>1.')
            self.task, self.agent, self.dict = agents
            if not hasattr(self.agent, 'next_word_probability'):
                raise RuntimeError('Agent must implement function '
                                   '`next_word_probability`.')
            self.metrics = {'exs': 0, 'loss': 0.0, 'num_tokens': 0, 'num_unk': 0}
            if opt.get('numthreads', 1) > 1:
                self.metrics = SharedTable(self.metrics)
        self.agents = [self.task, self.agent, self.dict]
        self.acts = [None, None]

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def parley(self):
        action = self.task.act()
        self.acts[0] = action.copy()

        # hide labels from model
        labels = action.get('eval_labels', action.get('labels', None))
        if 'label_candidates' in action:
            action.pop('label_candidates')
        if labels is None:
            # empty example, move on
            return

        parsed = self.dict.tokenize(labels[0])
        loss = 0
        num_tokens = 0
        num_unk = 0
        self.agent.observe(action)
        for i in range(len(parsed)):
            if parsed[i] in self.dict:
                # only score words which are in the dictionary
                probs = self.agent.next_word_probability(parsed[:i])
                # get probability of correct answer, divide by total prob mass
                prob_true = probs.get(parsed[i], 0)
                if prob_true > 0:
                    prob_true /= sum((probs.get(k, 0) for k in self.dict.keys()))
                    loss -= math.log(prob_true)
                else:
                    loss = float('inf')
                num_tokens += 1
            else:
                num_unk += 1
        with self._lock():
            self.metrics['exs'] += 1
            self.metrics['loss'] += loss
            self.metrics['num_tokens'] += num_tokens
            self.metrics['num_unk'] += num_unk

    def epoch_done(self):
        return self.task.epoch_done()

    def num_examples(self):
        return self.task.num_examples()

    def num_episodes(self):
        return self.task.num_episodes()

    def share(self):
        shared = super().share()
        shared['metrics'] = self.metrics
        return shared

    def reset_metrics(self):
        with self._lock():
            self.metrics['exs'] = 0
            self.metrics['loss'] = 0.0
            self.metrics['num_tokens'] = 0
            self.metrics['num_unk'] = 0

    def report(self):
        m = {}
        with self._lock():
            m['exs'] = self.metrics['exs']
            if m['exs'] > 0:
                # m['num_unk'] = self.metrics['num_unk']
                # m['num_tokens'] = self.metrics['num_tokens']
                m['loss'] = round_sigfigs(
                    self.metrics['loss'] / self.metrics['num_tokens'],
                    3
                )
                m['ppl'] = round_sigfigs(
                    math.exp(self.metrics['loss'] / self.metrics['num_tokens']),
                    4
                )
        return m
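The docstring above is the whole contract an agent has to meet: after observe(action) the world calls next_word_probability(partial_out) with the gold label tokens seen so far and expects a dict mapping words to scores, which it renormalizes itself. Below is a minimal sketch of a conforming agent, not a real ParlAI model; the Agent base-class import and the uniform scoring over a toy vocabulary are assumptions for illustration.

from parlai.core.agents import Agent  # assumed base-class location


class UniformLMAgent(Agent):
    """Hypothetical agent that satisfies the PerplexityWorld contract."""

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.id = 'UniformLMAgent'
        # tiny illustrative vocabulary; a real model would use its dictionary
        self.vocab = ['hello', 'world', 'how', 'are', 'you']

    def observe(self, observation):
        # PerplexityWorld calls observe(action) once per example
        self.observation = observation
        return observation

    def act(self):
        # never called by PerplexityWorld, but required of any Agent
        return {'id': self.id}

    def next_word_probability(self, partial_out):
        # partial_out is the list of gold label tokens before the word being
        # scored; return unnormalized scores and let the world renormalize
        return {word: 1.0 for word in self.vocab}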
Example #28
File: metrics.py  Project: jojonki/ParlAI
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""

    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.eval_pr = [1, 5, 10, 50, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
        self.datatype = opt.get('datatype', 'train')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return self

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 50, 100,  etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if _exact_match(prediction, labels):
            correct = 1
        with self._lock():
            self.metrics['correct'] += correct

        # F1 metric.
        f1 = _f1_score(prediction, labels)
        with self._lock():
            self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data is stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        m['total'] = self.metrics['cnt']
        if self.metrics['cnt'] > 0:
            m['accuracy'] = round_sigfigs(
                self.metrics['correct'] / self.metrics['cnt'], 4)
            m['f1'] = round_sigfigs(
                self.metrics['f1'] / self.metrics['cnt'], 4)
            m['hits@k'] = {}
            for k in self.eval_pr:
                m['hits@k'][k] = round_sigfigs(
                    self.metrics['hits@' + str(k)] / self.metrics['cnt'], 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
Example #29
File: metrics.py  Project: zwcdp/KBRD
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics_list = ['mean_rank', 'loss', 'correct', 'f1', 'ppl']
        if nltkbleu is not None:
            # only compute bleu if we can
            self.metrics_list.append('bleu')
        for k in self.metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        self.metrics['hits@_cnt'] = 0
        self.flags = {
            'has_text_cands': False,
            'print_prediction_metrics': False
        }
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
            self.flags = SharedTable(self.flags)

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        representation = super().__repr__()
        return representation.replace('>', ': {}>'.format(repr(self.metrics)))

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return no_lock()

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return
        else:
            # Now loop through text candidates, assuming they are sorted.
            # If any of them is a label then score a point.
            # maintain hits@1, 5, 10, 50, 100,  etc.
            label_set = set(normalize_answer(l) for l in labels)
            cnts = {k: 0 for k in self.eval_pr}
            cnt = 0
            for c in text_cands:
                cnt += 1
                if normalize_answer(c) in label_set:
                    for k in self.eval_pr:
                        if cnt <= k:
                            cnts[k] += 1
            # hits metric is 1 if cnts[k] > 0.
            # (other metrics such as p@k and r@k take
            # the value of cnt into account.)
            with self._lock():
                self.flags['has_text_cands'] = True
                for k in self.eval_pr:
                    if cnts[k] > 0:
                        self.metrics['hits@' + str(k)] += 1
                self.metrics['hits@_cnt'] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if prediction is not None:
            if _exact_match(prediction, labels):
                correct = 1
            with self._lock():
                self.flags['print_prediction_metrics'] = True
                self.metrics['correct'] += correct
                self.metrics['correct_cnt'] += 1

            # F1 and BLEU metrics.
            f1 = _f1_score(prediction, labels)
            bleu = _bleu(prediction, labels)
            with self._lock():
                self.metrics['f1'] += f1
                self.metrics['f1_cnt'] += 1
                if bleu is not None:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
                            self.metrics[k + '_cnt'] += 1
                    else:
                        if type(self.metrics) is SharedTable:
                            # can't share custom metrics during hogwild
                            pass
                        else:
                            # no need to lock because not SharedTable
                            if k not in self.metrics:
                                self.metrics[k] = v
                                self.metrics_list.append(k)
                                self.metrics[k + '_cnt'] = 1.0
                            else:
                                self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data is stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['exs'] = total
        if total > 0:
            if self.flags['print_prediction_metrics']:
                m['accuracy'] = round_sigfigs(
                    self.metrics['correct'] /
                    max(1, self.metrics['correct_cnt']), 4)
                m['f1'] = round_sigfigs(
                    self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
            if self.flags['has_text_cands']:
                for k in self.eval_pr:
                    m['hits@' + str(k)] = round_sigfigs(
                        self.metrics['hits@' + str(k)] /
                        max(1, self.metrics['hits@_cnt']), 3)
            for k in self.metrics_list:
                if (self.metrics[k + '_cnt'] > 0
                        and k != 'correct' and k != 'f1'):
                    m[k] = round_sigfigs(
                        self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            for k in self.metrics_list:
                v = self.metrics[k]
                v_typ = type(v)
                if 'Tensor' in str(v_typ):
                    self.metrics[k].zero_()
                else:
                    self.metrics[k] = 0.0
                self.metrics[k + '_cnt'] = 0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            self.metrics['hits@_cnt'] = 0
Example #30
class Metrics(object):
    """Class that maintains evaluation metrics over dialog."""
    def __init__(self, opt):
        self.metrics = {}
        self.metrics['cnt'] = 0
        self.metrics['correct'] = 0
        self.metrics['f1'] = 0.0
        self.eval_pr = [1, 5, 10, 100]
        for k in self.eval_pr:
            self.metrics['hits@' + str(k)] = 0
        if opt.get('numthreads', 1) > 1:
            self.metrics = SharedTable(self.metrics)
        self.datatype = opt.get('datatype', 'train')
        self.custom_keys = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

    def __str__(self):
        return str(self.metrics)

    def __repr__(self):
        return repr(self.metrics)

    def _lock(self):
        if hasattr(self.metrics, 'get_lock'):
            # use the shared_table's lock
            return self.metrics.get_lock()
        else:
            # otherwise do nothing
            return self

    def update_ranking_metrics(self, observation, labels):
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            text = observation.get('text', None)
            if text is None:
                return
            else:
                text_cands = [text]
        # Now loop through text candidates, assuming they are sorted.
        # If any of them is a label then score a point.
        # maintain hits@1, 5, 10, 50, 100,  etc.
        label_set = set(_normalize_answer(l) for l in labels)
        cnts = {k: 0 for k in self.eval_pr}
        cnt = 0
        for c in text_cands:
            cnt += 1
            if _normalize_answer(c) in label_set:
                for k in self.eval_pr:
                    if cnt <= k:
                        cnts[k] += 1
        # hits metric is 1 if cnts[k] > 0.
        # (other metrics such as p@k and r@k take
        # the value of cnt into account.)
        with self._lock():
            for k in self.eval_pr:
                if cnts[k] > 0:
                    self.metrics['hits@' + str(k)] += 1

    def update(self, observation, labels):
        with self._lock():
            self.metrics['cnt'] += 1

        # Exact match metric.
        correct = 0
        prediction = observation.get('text', None)
        if _exact_match(prediction, labels):
            correct = 1
        with self._lock():
            self.metrics['correct'] += correct

        # F1 metric.
        f1 = _f1_score(prediction, labels)
        with self._lock():
            self.metrics['f1'] += f1

        # Ranking metrics.
        self.update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
                if k not in ['correct', 'f1', 'hits@k']:
                    with self._lock():
                        if k not in self.metrics:
                            self.custom_keys.append(k)
                            self.metrics[k] = v
                        else:
                            self.metrics[k] += v

        # Return a dict containing the metrics for this specific example.
        # Metrics across all data is stored internally in the class, and
        # can be accessed with the report method.
        loss = {}
        loss['correct'] = correct
        return loss

    def report(self):
        # Report the metrics over all data seen so far.
        m = {}
        total = self.metrics['cnt']
        m['total'] = total
        if total > 0:
            m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
            m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
            m['hits@k'] = {}
            for k in self.eval_pr:
                m['hits@k'][k] = round_sigfigs(
                    self.metrics['hits@' + str(k)] / total, 3)
            for k in self.custom_keys:
                if k in self.metrics:
                    m[k] = round_sigfigs(self.metrics[k] / total, 3)
        return m

    def clear(self):
        with self._lock():
            self.metrics['cnt'] = 0
            self.metrics['correct'] = 0
            self.metrics['f1'] = 0.0
            for k in self.eval_pr:
                self.metrics['hits@' + str(k)] = 0
            for k in self.custom_keys:
                self.metrics.pop(k, None)  # safer than casting to zero