コード例 #1
0
ファイル: test_dict.py プロジェクト: bonbert81/ParlAI
 def test_byte_level_bpe_tokenize(self):
     """
     Test the byte-level BPE tokenizer inside ParlAI.

     Checks tokenization and the vec2txt / txt2vec round trip, repeats the
     same checks after saving and reloading the dictionary, and finally
     verifies the ids of the special tokens.
     """
     parser = ParlaiParser()
     parser.set_params(
         dict_tokenizer='bytelevelbpe',
         bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
         bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
         bpe_add_prefix_space=False,
     )
     opt = parser.parse_args([], print_args=False)
     agent = DictionaryAgent(opt)
     # The sample text ends with the grinning face emoji (U+1F600).
     sample = u'Hello, ParlAI! \U0001f600'

     def check_round_trip():
         """Assert tokenize / vec2txt / txt2vec agree with the reference."""
         self.assertEqual(
             agent.bytelevelbpe_tokenize(sample), BYTELEVEL_BPE_RESULT
         )
         expected_ids = [agent.tok2ind[w] for w in BYTELEVEL_BPE_RESULT]
         self.assertEqual(agent.vec2txt(expected_ids), sample)
         self.assertEqual(agent.txt2vec(sample), expected_ids)

     check_round_trip()

     # Test loading / saving: the vocabulary size and every round-trip
     # result must be unchanged after the dictionary is written and re-read.
     vocab_size = agent.byte_level_bpe.tokenizer.get_vocab_size()
     with testing_utils.tempdir() as tmpdir:
         path = os.path.join(tmpdir, 'dict-checkpoint')
         agent.save(filename=path)
         agent.load(filename=path)
     self.assertEqual(
         vocab_size, agent.byte_level_bpe.tokenizer.get_vocab_size()
     )
     check_round_trip()

     # Test special token ids are mapped correctly. ParlAI adds its 4
     # special tokens at the beginning of the dict (the original note says
     # Hugging Face adds them at the end), so they are expected at ids 0-3.
     self.assertEqual(agent.txt2vec("__null__"), [0])
     self.assertEqual(agent.txt2vec("__start__"), [1])
     self.assertEqual(agent.txt2vec("__end__"), [2])
     self.assertEqual(agent.txt2vec("__unk__"), [3])
コード例 #2
0
def get_dictionary(PATH: str) -> DictionaryAgent:
    """
    Load a dictionary.

    :param PATH: dictionary tool directory; passed unchanged to
        ``DictionaryAgent.load``
    :return: the loaded dictionary
    """
    # Build the agent from a default-constructed Opt, then populate it.
    loaded = DictionaryAgent(opt=Opt())
    loaded.load(PATH)
    return loaded
コード例 #3
0
class IrBaselineAgent(Agent):
    """Information Retrieval baseline.

    Ranks candidate responses against a query built from the most recent
    dialogue history and replies with the top-ranked candidate.
    """

    @staticmethod
    def add_cmdline_args(parser):
        """Add command line args specific to this agent."""
        parser = parser.add_argument_group('IrBaseline Arguments')
        parser.add_argument('-lp',
                            '--length_penalty',
                            type=float,
                            default=0.5,
                            help='length penalty for responses')
        parser.add_argument(
            '-hsz',
            '--history_size',
            type=int,
            default=1,
            help='number of utterances from the dialogue history to take use '
            'as the query')
        parser.add_argument('--label_candidates_file',
                            type=str,
                            default=None,
                            help='file of candidate responses to choose from')

    def __init__(self, opt, shared=None):
        """Initialize agent.

        :param opt: options; ``length_penalty`` is required, and
            ``label_candidates_file`` is read when it is set.
        :param shared: accepted for interface compatibility; not used here.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt
        self.history = []
        self.episodeDone = True
        candidates_path = opt.get('label_candidates_file')
        if candidates_path:
            # Use a context manager so the file handle is always closed
            # (it was previously left open).
            with open(candidates_path) as f:
                self.label_candidates = f.read().split('\n')

    def reset(self):
        """Reset agent properties."""
        self.observation = None
        self.history = []
        self.episodeDone = True

    def observe(self, obs):
        """Store and remember incoming observation message dict.

        The history is cleared first when the previous observation ended its
        episode; the observation's text (if any) is then appended.
        """
        self.observation = obs
        self.dictionary.observe(obs)
        if self.episodeDone:
            self.history = []
        if 'text' in obs:
            self.history.append(obs.get('text', ''))
        self.episodeDone = obs.get('episode_done', False)
        return obs

    def act(self):
        """Generate a response to the previously seen observation(s).

        :returns: dict with 'id' and 'text'; 'text_candidates' (the ranked
                  candidates) is included when candidates were available.
        """
        if self.opt.get('datatype', '').startswith('train'):
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        cands = None
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            cands = obs['label_candidates']
        if hasattr(self, 'label_candidates'):
            # override label candidates with candidate file if set
            cands = self.label_candidates
        if cands:
            # The query is the last `history_size` utterances of the history.
            hist_sz = self.opt.get('history_size', 1)
            left_idx = max(0, len(self.history) - hist_sz)
            text = ' '.join(self.history[left_idx:])
            rep = self.build_query_representation(text)
            reply['text_candidates'] = rank_candidates(
                rep, cands, self.length_penalty, self.dictionary)
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save dictionary tokenizer if available.

        Falls back to the ``model_file`` option when `fname` is None; nothing
        is saved when neither is set.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load internal dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        :param query: string to represent.

        :returns: dictionary containing 'words' dictionary (token => weight)
                  and 'norm' float (square root of the number of tokens).
                  When the dictionary has frequencies, the weight is
                  ``1 / (1 + log(1 + freq))``; otherwise every token that is
                  not a stopword gets weight 1.
        """
        words = list(self.dictionary.tokenize(query.lower()))
        # Look the frequency table up once instead of on every token.
        freqs = self.dictionary.freqs()
        has_freqs = len(freqs) > 0
        weights = {}
        for w in words:
            if has_freqs:
                weights[w] = 1.0 / (1.0 + math.log(1.0 + freqs[w]))
            elif w not in stopwords:
                weights[w] = 1
        return {'words': weights, 'norm': math.sqrt(len(words))}
コード例 #4
0
class IrBaselineAgent(Agent):
    """Information retrieval baseline that ranks the given label candidates."""

    @staticmethod
    def add_cmdline_args(parser):
        """Register the dictionary arguments plus this agent's own options."""
        DictionaryAgent.add_cmdline_args(parser)
        parser.add_argument(
            '-lp', '--length_penalty', type=float, default=0.5,
            help='length penalty for responses')
        parser.add_argument(
            '-hsz', '--history_size', type=int, default=1,
            help='number of utterances from the dialogue history to take use as the query')

    def __init__(self, opt, shared=None):
        """Set up the agent from `opt` (``length_penalty`` is required).

        `shared` is accepted for interface compatibility and is not used here.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt
        self.history = []
        self.episodeDone = True

    def reset(self):
        """Forget the last observation and the dialogue history."""
        self.observation = None
        self.history = []
        self.episodeDone = True

    def observe(self, obs):
        """Record `obs`, starting a fresh history after a finished episode."""
        self.observation = obs
        self.dictionary.observe(obs)
        if self.episodeDone:
            self.history = []
        if 'text' in obs:
            self.history.append(obs.get('text', ''))
        self.episodeDone = obs.get('episode_done', False)
        return obs

    def act(self):
        """Reply with the best-ranked label candidate of the last observation.

        Returns a dict with 'id' and 'text'; 'text_candidates' holds the full
        ranking when the observation carried label candidates.
        """
        if self.opt.get('datatype', '').startswith('train'):
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            # Query on the last `history_size` utterances of the history.
            history_size = self.opt.get('history_size', 1)
            start = max(0, len(self.history) - history_size)
            text = ' '.join(self.history[start:])
            rep = self.build_query_representation(text)
            reply['text_candidates'] = rank_candidates(
                rep, obs['label_candidates'],
                self.length_penalty, self.dictionary)
            reply['text'] = reply['text_candidates'][0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save the dictionary to ``fname + '.dict'``.

        `fname` defaults to the ``model_file`` option; nothing is written when
        neither is set.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load the dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        Returns a dict with 'words' (token => weight) and 'norm' (square root
        of the number of tokens). Tokens are weighted by
        ``1 / (1 + log(1 + freq))`` when the dictionary has frequencies;
        otherwise each non-stopword token gets weight 1.
        """
        words = list(self.dictionary.tokenize(query.lower()))
        # Fetch the frequency table once rather than for every token.
        freqs = self.dictionary.freqs()
        use_freqs = len(freqs) > 0
        rw = {}
        for w in words:
            if use_freqs:
                rw[w] = 1.0 / (1.0 + math.log(1.0 + freqs[w]))
            elif w not in stopwords:
                rw[w] = 1
        return {'words': rw, 'norm': math.sqrt(len(words))}
コード例 #5
0
ファイル: ir_baseline.py プロジェクト: ahiroto/ParlAI
class IrBaselineAgent(Agent):
    """Information retrieval baseline over an observation's label candidates."""

    @staticmethod
    def add_cmdline_args(parser):
        """Register the dictionary arguments and the length-penalty option."""
        DictionaryAgent.add_cmdline_args(parser)
        # type=float makes command-line values parse as numbers, matching the
        # float default; __init__ still applies float() to the option.
        parser.add_argument(
            '-lp', '--length_penalty', type=float, default=0.5,
            help='length penalty for responses')

    def __init__(self, opt, shared=None):
        """Set up the agent from `opt` (``length_penalty`` is required).

        `shared` is accepted for interface compatibility and is not used here.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt

    def observe(self, obs):
        """Remember `obs` and pass it on to the dictionary."""
        self.observation = obs
        self.dictionary.observe(obs)
        return obs

    def act(self):
        """Reply with the best-ranked label candidate of the last observation.

        Returns a dict with 'id' and 'text'; 'text_candidates' holds the full
        ranking when the observation carried label candidates.
        """
        if self.opt.get('datatype', '').startswith('train'):
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        if 'label_candidates' in obs and len(obs['label_candidates']) > 0:
            rep = self.build_query_representation(obs['text'])
            ranked = rank_candidates(
                rep, obs['label_candidates'],
                self.length_penalty, self.dictionary)
            reply['text_candidates'] = ranked
            reply['text'] = ranked[0]
        else:
            reply['text'] = "I don't know."
        return reply

    def save(self, fname=None):
        """Save the dictionary to ``fname + '.dict'``.

        `fname` defaults to the ``model_file`` option; nothing is written when
        neither is set.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Load the dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        Returns a dict with 'words' (token => weight) and 'norm' (square root
        of the number of tokens). Tokens are weighted by
        ``1 / (1 + log(1 + freq))`` when the dictionary has frequencies;
        otherwise each non-stopword token gets weight 1.
        """
        words = list(self.dictionary.tokenize(query.lower()))
        # Fetch the frequency table once rather than for every token.
        freqs = self.dictionary.freqs()
        use_freqs = len(freqs) > 0
        rw = {}
        for w in words:
            if use_freqs:
                rw[w] = 1.0 / (1.0 + math.log(1.0 + freqs[w]))
            elif w not in stopwords:
                rw[w] = 1
        return {'words': rw, 'norm': math.sqrt(len(words))}
コード例 #6
0
ファイル: ir_baseline.py プロジェクト: rikima/ParlAI
class IrBaselineAgent(Agent):
    """Baseline agent that answers with the top-ranked label candidate."""

    @staticmethod
    def add_cmdline_args(parser):
        """Add the dictionary arguments and the length-penalty option."""
        DictionaryAgent.add_cmdline_args(parser)
        # Parse the value as a float so it matches the numeric default;
        # __init__ still converts the option with float().
        parser.add_argument('-lp',
                            '--length_penalty',
                            type=float,
                            default=0.5,
                            help='length penalty for responses')

    def __init__(self, opt, shared=None):
        """Create the agent and its dictionary from `opt`.

        ``opt['length_penalty']`` must be present. `shared` is accepted for
        interface compatibility and is not used here.
        """
        super().__init__(opt)
        self.id = 'IRBaselineAgent'
        self.length_penalty = float(opt['length_penalty'])
        self.dictionary = DictionaryAgent(opt)
        self.opt = opt

    def observe(self, obs):
        """Store the observation and forward it to the dictionary."""
        self.observation = obs
        self.dictionary.observe(obs)
        return obs

    def act(self):
        """Produce a reply for the stored observation.

        The reply always has 'id' and 'text'. When the observation has a
        non-empty 'label_candidates', the candidates are ranked against the
        observation text, stored under 'text_candidates', and the first one
        becomes 'text'.
        """
        if self.opt.get('datatype', '').startswith('train'):
            self.dictionary.act()

        obs = self.observation
        reply = {'id': self.getID()}

        # Rank candidates
        has_candidates = ('label_candidates' in obs
                          and len(obs['label_candidates']) > 0)
        if not has_candidates:
            reply['text'] = "I don't know."
            return reply

        rep = self.build_query_representation(obs['text'])
        reply['text_candidates'] = rank_candidates(
            rep, obs['label_candidates'], self.length_penalty,
            self.dictionary)
        reply['text'] = reply['text_candidates'][0]
        return reply

    def save(self, fname=None):
        """Write the dictionary to ``fname + '.dict'``.

        When `fname` is None the ``model_file`` option is used instead; if
        that is unset too, nothing is saved.
        """
        fname = self.opt.get('model_file', None) if fname is None else fname
        if fname:
            self.dictionary.save(fname + '.dict')

    def load(self, fname):
        """Read the dictionary from ``fname + '.dict'``."""
        self.dictionary.load(fname + '.dict')

    def build_query_representation(self, query):
        """Build representation of query, e.g. words or n-grams.

        The result maps 'words' to a token => weight dict and 'norm' to the
        square root of the token count. With a non-empty frequency table the
        weight is ``1 / (1 + log(1 + freq))``; with an empty one, tokens that
        are not stopwords get weight 1 and stopwords are left out.
        """
        tokens = list(self.dictionary.tokenize(query.lower()))
        # One lookup of the frequency table is enough for the whole query.
        freqs = self.dictionary.freqs()
        weighted_by_freq = len(freqs) > 0
        token_weights = {}
        for tok in tokens:
            if weighted_by_freq:
                token_weights[tok] = 1.0 / (1.0 + math.log(1.0 + freqs[tok]))
            elif tok not in stopwords:
                token_weights[tok] = 1
        return {'words': token_weights, 'norm': math.sqrt(len(tokens))}
コード例 #7
0
def eval_wordstat(opt, print_parser=None):
    """Evaluate a model and log word statistics of its predictions.

    Runs the model on the task until the epoch is done (or until
    ``opt['num_examples']`` parleys when that is positive), collecting word
    statistics of each predicted text via ``get_word_stats``.

    Arguments:
    opt -- tells the evaluation function how to run; reads 'external_dict',
        'freq_bins' (comma-separated integers), 'num_examples' and the
        optional 'log_every_n_secs'
    print_parser -- if provided, prints the options that are set within the
        model after loading the model

    Returns the world's final report.
    """
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt['external_dict'] is not None:
        print('[ Using external dictionary from: {} ]'.format(
            opt['external_dict']))
        dictionary = DictionaryAgent(opt)
        dictionary.load(opt['external_dict'])
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # A non-positive interval disables the periodic logging.
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    mean_wlength = []
    mean_clength = []
    freqs_cnt = Counter()
    word_cnt = 0
    bins = [int(i) for i in opt['freq_bins'].split(',')]
    max_examples = opt['num_examples']

    while not world.epoch_done():
        cnt += 1
        world.parley()
        prediction = world.acts[-1]['text']
        freqs, _cnt, wlength, clength = get_word_stats(prediction,
                                                       dictionary,
                                                       bins=bins)
        word_cnt += _cnt

        mean_wlength.append(wlength)
        mean_clength.append(clength)

        freqs_cnt += Counter(freqs)

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(),
                                        report)
            print(text)
            bin_parts = []
            for b in bins:
                bin_count = freqs_cnt.get(b, 0)
                # Avoid ZeroDivisionError when no words were counted yet.
                pct = (bin_count / word_cnt) * 100 if word_cnt else 0.0
                bin_parts.append(
                    '<{}:{} ({:.{prec}f}%)'.format(b, bin_count, pct, prec=2))
            stat_str = ('total_words: {}, '.format(word_cnt)
                        + ', '.join(bin_parts))
            print(
                "Word statistics: {}, avg_word_length: {:.{prec}f}, avg_char_length: {:.{prec}f}"
                .format(stat_str,
                        numpy.array(mean_wlength).mean(),
                        numpy.array(mean_clength).mean(),
                        prec=2))
        if max_examples > 0 and cnt >= max_examples:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    report = world.report()
    print(report)
    return report