def init_vocab(self, data):
    """Build every vocabulary the model needs from a single training dataset.

    Must only be called when not in eval mode: for eval the vocab is expected
    to exist already.

    Args:
        data: training data handed to each vocab constructor. The ``idx``
            arguments choose which field each vocab reads (field meaning is
            defined by WordVocab / FeatureVocab — TODO confirm there).

    Returns:
        A MultiVocab with the keys 'char', 'word', 'upos', 'xpos', 'feats',
        'lemma' and 'deprel', in that order.
    """
    assert self.eval == False  # for eval vocab must exist
    shorthand = self.args['shorthand']
    # Constructed in the same order as the keys appear in the MultiVocab.
    vocabs = {}
    vocabs['char'] = CharVocab(data, shorthand)
    vocabs['word'] = WordVocab(data, shorthand, cutoff=self.cutoff, lower=True)
    vocabs['upos'] = WordVocab(data, shorthand, idx=1)
    vocabs['xpos'] = xpos_vocab_factory(data, shorthand)
    vocabs['feats'] = FeatureVocab(data, shorthand, idx=3)
    vocabs['lemma'] = WordVocab(data, shorthand, cutoff=self.cutoff, idx=4, lower=True)
    vocabs['deprel'] = WordVocab(data, shorthand, idx=6)
    return MultiVocab(vocabs)
def load(self, pretrain, filename):
    """Restore config, vocabulary and model weights from a checkpoint file.

    Args:
        pretrain: object whose ``emb`` attribute is passed to ``Parser`` as
            ``emb_matrix`` (presumably a pretrained embedding matrix — confirm).
        filename: path of a checkpoint dict containing the keys 'config',
            'vocab' and 'model'.

    Side effects:
        Sets ``self.args``, ``self.vocab`` and ``self.model``.

    Raises:
        SystemExit: with exit status 1 if the checkpoint cannot be loaded.
    """
    import sys

    try:
        # Identity map_location: tensors stay on the storage they are
        # deserialized into (CPU), regardless of the device they were saved from.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
    except Exception as err:
        # Catch Exception rather than BaseException so KeyboardInterrupt and
        # SystemExit still propagate normally.
        print("Cannot load model from {}".format(filename))
        print("Reason: {!r}".format(err), file=sys.stderr)
        # Non-zero status so callers/shells can detect the failure; the bare
        # site-provided exit() exits with status 0 and may not exist at all.
        sys.exit(1)
    self.args = checkpoint['config']
    self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
    self.model = Parser(self.args, self.vocab, emb_matrix=pretrain.emb)
    # strict=False tolerates missing/unexpected keys in the stored weights.
    self.model.load_state_dict(checkpoint['model'], strict=False)
def init_vocab(self, data_list):
    """Build a MultiVocab jointly from several training datasets.

    The character, UPOS, XPOS, feature and lemma vocabs are built from the
    concatenation of all datasets. The word vocab is the union of the first
    ``self.args['vocab_cutoff']`` non-prefix units of each per-dataset
    WordVocab, placed after VOCAB_PREFIX.

    Must only be called when not in eval mode: for eval the vocab is expected
    to exist already.

    Args:
        data_list: non-empty list of datasets (each a list); an empty list
            fails with IndexError at ``wordvocabs[0]``.

    Returns:
        A MultiVocab with the keys 'char', 'word', 'upos', 'xpos', 'feats'
        and 'lemma'.
    """
    assert self.eval == False  # for eval vocab must exist
    # Linear flatten; sum(data_list, []) re-copies the accumulator per dataset.
    data_all = [item for data in data_list for item in data]
    charvocab = CharVocab(data_all, self.args['shorthand'])
    # construct wordvocab from multiple files
    wordvocabs = [WordVocab(data, self.args['shorthand'], cutoff=0, lower=True) for data in data_list]
    prefix_len = len(VOCAB_PREFIX)
    top_k = self.args['vocab_cutoff']
    # Deduplicate while keeping first-seen order so word ids are reproducible.
    # A set would order words by string hash, which changes between processes
    # (PYTHONHASHSEED) and would give different word ids on every run.
    wordset = list(dict.fromkeys(
        unit
        for v in wordvocabs
        for unit in v._id2unit[prefix_len:prefix_len + top_k]
    ))
    # Reuse the first per-file vocab object as the container for the joint vocab.
    wordvocab = wordvocabs[0]
    wordvocab._id2unit = VOCAB_PREFIX + wordset
    wordvocab._unit2id = {w: i for i, w in enumerate(wordvocab._id2unit)}
    print('Constructing a joint word vocabulary of size {} ...'.format(len(wordvocab)))
    uposvocab = WordVocab(data_all, self.args['shorthand'], idx=1)
    xposvocab = xpos_vocab_factory(data_all, self.args['shorthand'])
    featsvocab = FeatureVocab(data_all, self.args['shorthand'], idx=3)
    lemmavocab = WordVocab(data_all, self.args['shorthand'], cutoff=self.cutoff, idx=4, lower=True)
    vocab = MultiVocab({
        'char': charvocab,
        'word': wordvocab,
        'upos': uposvocab,
        'xpos': xposvocab,
        'feats': featsvocab,
        'lemma': lemmavocab,
    })
    return vocab