Beispiel #1
0
    def build_vocabs(self, dataset, logger=None, transformer=False):
        """Build, lock and summarize the vocabularies, then record their sizes.

        Args:
            dataset: Sized iterable of samples. It is walked once in full; this
                method adds nothing to the vocabularies itself, so they are
                presumably filled as a side effect of that iteration
                (NOTE(review): confirm against the dataset's transforms).
            logger: Forwarded to ``self.vocabs.summary``.
            transformer: When truthy, no token vocabulary is created and the
                token-related config entries (``n_words``, ``pad_index``,
                ``unk_index``) are left untouched.
        """
        rel_vocab = self.vocabs.get('rel', None)
        if rel_vocab is None:
            rel_vocab = Vocab(unk_token=None,
                              pad_token=self.config.get('pad_rel', None))
            self.vocabs.put(rel=rel_vocab)

        timer = CountdownTimer(len(dataset))
        if transformer:
            token_vocab = None
        else:
            self.vocabs.token = token_vocab = VocabCounter(
                unk_token=self.config.get('unk', UNK))
        # The loop body only reports progress; the samples themselves are unused here.
        for _ in dataset:
            timer.log('Building vocab [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=True)
        min_freq = self.config.get('min_freq', None)
        # There is no token vocabulary to trim in transformer mode; without
        # this guard a configured ``min_freq`` raised AttributeError on None.
        if min_freq and token_vocab is not None:
            token_vocab.trim(min_freq)
        rel_vocab.set_unk_as_safe_unk()  # Some relations in the dev set are OOV
        self.vocabs.lock()
        self.vocabs.summary(logger=logger)
        if token_vocab:
            self.config.n_words = len(self.vocabs['token'])
        self.config.n_rels = len(self.vocabs['rel'])
        if token_vocab:
            self.config.pad_index = self.vocabs['token'].pad_idx
            self.config.unk_index = self.vocabs['token'].unk_idx
Beispiel #2
0
    def build_vocabs(self, dataset, logger=None, transformer=None):
        """Set up the relation, optional POS and optional token vocabularies.

        A ``'pos'`` vocabulary is created when the config has ``feat == 'pos'``
        or a truthy ``use_pos``. A token vocabulary is created only when
        ``transformer`` is falsy. If a token vocabulary exists and ``min_freq``
        is configured, token forms are counted over ``dataset`` and the token
        vocabulary is rebuilt from the reserved tokens plus every form whose
        count reaches ``min_freq``; otherwise the dataset is simply walked once
        while progress is logged. Finally the vocabularies are locked and
        summarized, and their sizes and special indices are written to the config.

        Args:
            dataset: Sized iterable of samples; on the frequency path each
                sample is indexed with ``'token'``.
            logger: Forwarded to ``self.vocabs.summary``.
            transformer: Truthy to skip the token vocabulary entirely.
        """
        config = self.config
        vocabs = self.vocabs

        relations = vocabs.get('rel', None)
        if relations is None:
            relations = Vocab(unk_token=None,
                              pad_token=config.get('pad_rel', None))
            vocabs.put(rel=relations)

        wants_pos = (config.get('feat', None) == 'pos'
                     or config.get('use_pos', False))
        if wants_pos:
            vocabs['pos'] = Vocab(unk_token=None, pad_token=None)

        timer = CountdownTimer(len(dataset))

        tokens = None
        if not transformer:
            tokens = Vocab()
            vocabs.token = tokens
            unk_override = config.get('unk', None)
            if unk_override is not None:
                tokens.unk_token = unk_override

        if not (tokens and config.get('min_freq', None)):
            # No frequency filtering: just walk the dataset and report progress.
            for _ in dataset:
                timer.log('vocab building [blink][yellow]...[/yellow][/blink]',
                          ratio_percentage=True)
        else:
            counts = Counter()
            counts.update(form for sample in dataset
                          for form in sample['token'])
            # Reserved entries go first so they are re-inserted before any word.
            kept = [tokens.pad_token, tokens.unk_token]
            if ROOT in tokens:
                kept.append(ROOT)
            kept.extend(form for form, seen in counts.items()
                        if seen >= config.min_freq)
            tokens.token_to_idx.clear()
            for form in kept:
                tokens(form)

        relations.set_unk_as_safe_unk()  # Some relations in the dev set are OOV
        vocabs.lock()
        vocabs.summary(logger=logger)

        if tokens:
            config.n_words = len(vocabs['token'])
        if 'pos' in vocabs:
            pos_vocab = vocabs['pos']
            config.n_feats = len(pos_vocab)
            pos_vocab.set_unk_as_safe_unk()
        config.n_rels = len(vocabs['rel'])
        if tokens:
            locked_tokens = vocabs['token']
            config.pad_index = locked_tokens.pad_idx
            config.unk_index = locked_tokens.unk_idx