Example #1
File: cmd.py Project: shtechair/ACE
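These snippets are excerpts and omit the file's import block. A minimal sketch of what the code below relies on; the standard-library and PyTorch imports follow directly from usage, while the project-local utilities (Field, CharField, BertField, NGramField, SegmentField, CoNLL, Corpus, Embedding, Metric, eisner, ispunct, and the pad/unk/bos/eos constants) live somewhere in this repo, so the module paths sketched in comments are only placeholders.

import os
import pdb

import torch
import torch.nn as nn

# BertTokenizer comes from a BERT package; depending on the repo's vintage this
# is either `pytorch_pretrained_bert` or `transformers`.
from transformers import BertTokenizer

# Hypothetical locations for the repo's own utilities:
# from parser.utils.field import (Field, CharField, BertField,
#                                 NGramField, SegmentField)
# from parser.utils.corpus import CoNLL, Corpus
# from parser.utils.embedding import Embedding
# from parser.utils.metric import Metric
# from parser.utils.alg import eisner
# from parser.utils.fn import ispunct
# from parser.utils.common import pad, unk, bos, eos
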
class CMD(object):
    def __call__(self, args):
        self.args = args
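        # older configs may not define 'interpolation'; fall back to an even
        # split between the arc and rel losses (see get_loss below)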
        if not hasattr(self.args, 'interpolation'):
            self.args.interpolation = 0.5
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
            # if args.feat == 'char':
            #     self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
            #                           fix_len=args.fix_len, tokenize=list)
            # elif args.feat == 'bert':
            #     tokenizer = BertTokenizer.from_pretrained(args.bert_model)
            #     self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
            #                           tokenize=tokenizer.encode)
            # else:
            #     self.FEAT = Field('tags', bos=bos)

            self.CHAR_FEAT = None
            self.POS_FEAT = None
            self.BERT_FEAT = None
            self.FEAT = [self.WORD]
            if args.use_char:
                self.CHAR_FEAT = CharField('chars',
                                           pad=pad,
                                           unk=unk,
                                           bos=bos,
                                           fix_len=args.fix_len,
                                           tokenize=list)
                self.FEAT.append(self.CHAR_FEAT)
            if args.use_pos:
                self.POS_FEAT = Field('tags', bos=bos)
            if args.use_bert:
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.BERT_FEAT = BertField('bert',
                                           pad='[PAD]',
                                           bos='[CLS]',
                                           tokenize=tokenizer.encode)
                self.FEAT.append(self.BERT_FEAT)

            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)

            self.fields = CoNLL(FORM=self.FEAT,
                                CPOS=self.POS_FEAT,
                                HEAD=self.HEAD,
                                DEPREL=self.REL)
            # if args.feat in ('char', 'bert'):
            #     self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
            #                         HEAD=self.HEAD, DEPREL=self.REL)
            # else:
            #     self.fields = CoNLL(FORM=self.WORD, CPOS=self.FEAT,
            #                         HEAD=self.HEAD, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields)
            if args.fembed:
                embed = Embedding.load(args.fembed, args.unk)
            else:
                embed = None
            self.WORD.build(train, args.min_freq, embed)
            if args.use_char:
                self.CHAR_FEAT.build(train)
            if args.use_pos:
                self.POS_FEAT.build(train)
            if args.use_bert:
                self.BERT_FEAT.build(train)
            # self.FEAT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL
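        # cache the vocab indices of punctuation tokens; evaluate() uses these
        # to drop punctuation from the metric unless args.punct is set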
        self.puncts = torch.tensor([
            i for s, i in self.WORD.vocab.stoi.items() if ispunct(s)
        ]).to(args.device)
        self.rel_criterion = nn.CrossEntropyLoss()
        self.arc_criterion = nn.CrossEntropyLoss()
        if args.binary:
            self.arc_criterion = nn.BCEWithLogitsLoss(reduction='none')

        # print(f"{self.WORD}\n{self.FEAT}\n{self.HEAD}\n{self.REL}")
        print(f"{self.WORD}\n{self.HEAD}\n{self.REL}")
        update_info = {}
        # pdb.set_trace()
        if args.use_char:
            update_info['n_char_feats'] = len(self.CHAR_FEAT.vocab)
        if args.use_pos:
            update_info['n_pos_feats'] = len(self.POS_FEAT.vocab)
        args.update({
            'n_words': self.WORD.vocab.n_init,
            # 'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })
        args.update(update_info)

    def train(self, loader):
        self.model.train()
        for vals in loader:
            words = vals[0]
            feats = vals[1:-2]
            arcs, rels = vals[-2:]
            self.optimizer.zero_grad()

            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss = self.get_loss(arc_scores,
                                 rel_scores,
                                 arcs,
                                 rels,
                                 mask,
                                 words=words)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()

        loss, metric = 0, Metric()

        for vals in loader:
            words = vals[0]
            feats = vals[1:-2]
            arcs, rels = vals[-2:]
            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss += self.get_loss(arc_scores,
                                  rel_scores,
                                  arcs,
                                  rels,
                                  mask,
                                  words=words)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
        loss /= len(loader)

        return loss, metric

    @torch.no_grad()
    def predict(self, loader):
        self.model.eval()

        all_arcs, all_rels = [], []
        for vals in loader:
            words = vals[0]
            feats = vals[2:]

            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            arc_scores, rel_scores = self.model(words, feats)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            all_arcs.extend(arc_preds[mask].split(lens))
            all_rels.extend(rel_preds[mask].split(lens))
        all_arcs = [seq.tolist() for seq in all_arcs]
        all_rels = [self.REL.vocab.id2token(seq.tolist()) for seq in all_rels]

        return all_arcs, all_rels

    def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, words=None):
        if self.args.binary:
            full_mask = mask.clone()
            full_mask[:, 0] = 1
            binary_mask = mask.unsqueeze(-1) * full_mask.unsqueeze(-2)
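            # binary_mask pairs every dependent position (rows, excluding the
            # bos token) with every candidate head (columns, including bos)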

            arc_target = torch.zeros_like(arc_scores)
            res = arc_target.scatter(-1, arcs.unsqueeze(-1), 1)
            arc_scores = arc_scores * binary_mask
            arc_loss = self.arc_criterion(arc_scores, res)
            # sampling the zero part (disabled alternative):
            # zero_mask = 1 - res
            # keep_prob = 2 * res.shape[1] / (res.shape[1] * res.shape[2])
            # sample_val = zero_mask.new_empty(zero_mask.shape).bernoulli_(keep_prob)
            # binary_mask = sample_val * zero_mask * binary_mask + res
            arc_loss = (arc_loss * binary_mask).sum() / binary_mask.sum()
            if torch.isnan(arc_loss).any():
                pdb.set_trace()
            arc_scores, arcs = arc_scores[mask], arcs[mask]
        else:
            arc_scores, arcs = arc_scores[mask], arcs[mask]
            arc_loss = self.arc_criterion(arc_scores, arcs)
        rel_scores, rels = rel_scores[mask], rels[mask]
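        # for each token, keep only the relation scores at its gold head position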
        rel_scores = rel_scores[torch.arange(len(arcs)), arcs]

        rel_loss = self.rel_criterion(rel_scores, rels)
        # if self.args.binary:
        loss = 2 * ((1 - self.args.interpolation) * arc_loss +
                    self.args.interpolation * rel_loss)
        # else:
        #     loss = arc_loss + rel_loss

        return loss

    def decode(self, arc_scores, rel_scores, mask):
        if self.args.tree:
            arc_preds = eisner(arc_scores, mask)
        else:
            arc_preds = arc_scores.argmax(-1)
        rel_preds = rel_scores.argmax(-1)
        rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds
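
The last two lines of decode are the easy thing to get wrong: rel_preds starts as the best relation for every candidate head, and the gather keeps only the entry at each token's predicted head. A self-contained sketch with arbitrary dummy shapes (batch=2, seq=5, n_rels=4):

import torch

batch, seq, n_rels = 2, 5, 4
arc_scores = torch.randn(batch, seq, seq)          # score of head j for dependent i
rel_scores = torch.randn(batch, seq, seq, n_rels)  # relation scores per (i, j) pair

arc_preds = arc_scores.argmax(-1)                  # (2, 5): best head per token
rel_preds = rel_scores.argmax(-1)                  # (2, 5, 5): best rel per head candidate
rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
print(arc_preds.shape, rel_preds.shape)            # torch.Size([2, 5]) torch.Size([2, 5])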
Example #2
    def __call__(self, args):
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
            if args.feat == 'char':
                self.FEAT = CharField('chars',
                                      pad=pad,
                                      unk=unk,
                                      bos=bos,
                                      fix_len=args.fix_len,
                                      tokenize=list)
            elif args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert',
                                      pad='[PAD]',
                                      bos='[CLS]',
                                      tokenize=tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
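            # HEAD holds integer head positions directly (fn=int casts the raw
            # column), so it is built without a vocabulary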
            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)
            if args.feat in ('char', 'bert'):
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.HEAD,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.HEAD,
                                    DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields)
            if args.fembed:
                embed = Embedding.load(args.fembed, args.unk)
            else:
                embed = None
            self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = torch.tensor([
            i for s, i in self.WORD.vocab.stoi.items() if ispunct(s)
        ]).to(args.device)
        self.criterion = nn.CrossEntropyLoss()

        print(f"{self.WORD}\n{self.FEAT}\n{self.HEAD}\n{self.REL}")
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })
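
Throughout these examples, args supports both attribute access (args.feat) and a dict-style update(). A minimal stand-in for such a config object (hypothetical, not the repo's actual class):

class Config(object):
    """Bare-bones config: attribute access plus dict-style update()."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, kwargs):
        self.__dict__.update(kwargs)

args = Config(feat='char', preprocess=True)
args.update({'n_words': 100})
print(args.feat, args.n_words)  # char 100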
Example #3
class CMD(object):

    def __call__(self, args):
        self.args = args
        logging.basicConfig(filename=args.output, filemode='w', format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
        
        args.ud_dataset = {
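                # per-language entries: (train, dev, test, fastText embedding)
                # paths; the numbered keys (en20, ar40, ...) are reduced
                # versions of the training split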
                'en': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en20': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train20.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en40': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train40.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en60': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train60.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en80': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train80.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'ar': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar20': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train20.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar40': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train40.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar60': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train60.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar80': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train80.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'bg': (
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-train.conllx",
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-dev.conllx",
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-test.conllx",
                    "data/fastText_data/wiki.bg.btb.vec.new",
                ),
                'da': (
                    "data/ud/UD_Danish-DDT/da_ddt-ud-train.conllx",
                    "data/ud/UD_Danish-DDT/da_ddt-ud-dev.conllx",
                    "data/ud/UD_Danish-DDT/da_ddt-ud-test.conllx",
                    "data/fastText_data/wiki.da.ddt.vec.new",
                ),
                'de': (
                    "data/ud/UD_German-GSD/de_gsd-ud-train.conllx",
                    "data/ud/UD_German-GSD/de_gsd-ud-dev.conllx",
                    "data/ud/UD_German-GSD/de_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.de.gsd.vec.new",
                ),
                'es': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es20': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train20.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es40': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train40.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es60': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train60.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es80': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train80.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'fa': (
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-train.conllx",
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-dev.conllx",
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-test.conllx",
                    "data/fastText_data/wiki.fa.seraji.vec.new",
                ),
                'fr': (
                    "data/ud/UD_French-GSD/fr_gsd-ud-train.conllx",
                    "data/ud/UD_French-GSD/fr_gsd-ud-dev.conllx",
                    "data/ud/UD_French-GSD/fr_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.fr.gsd.vec.new",
                ),
                'he': (
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-train.conllx",
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-dev.conllx",
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-test.conllx",
                    "data/fastText_data/wiki.he.htb.vec.new",
                ),
                'hi': (
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-train.conllx",
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-dev.conllx",
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-test.conllx",
                    "data/fastText_data/wiki.hi.hdtb.vec.new",
                ),
                'hr': (
                    "data/ud/UD_Croatian-SET/hr_set-ud-train.conllx",
                    "data/ud/UD_Croatian-SET/hr_set-ud-dev.conllx",
                    "data/ud/UD_Croatian-SET/hr_set-ud-test.conllx",
                    "data/fastText_data/wiki.hr.set.vec.new",
                ),
                'id': (
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-train.conllx",
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-dev.conllx",
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.id.gsd.vec.new",
                ),
                'it': (
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-train.conllx",
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-dev.conllx",
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-test.conllx",
                    "data/fastText_data/wiki.it.isdt.vec.new",
                ),
                'ja': (
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-train.conllx",
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-dev.conllx",
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.ja.gsd.vec.new",
                ),
                'ko': (
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-train.conllx",
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-dev.conllx",
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-test.conllx",
                    "data/fastText_data/wiki.ko.gsdkaist.vec.new",
                ),
                'nl': (
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-train.conllx",
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-dev.conllx",
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-test.conllx",
                    "data/fastText_data/wiki.nl.alpinolassysmall.vec.new",
                ),
                'no': (
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-train.conllx",
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-dev.conllx",
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-test.conllx",
                    "data/fastText_data/wiki.no.bokmaalnynorsk.vec.new",
                ),
                'pt': (
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-train.conllx",
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-dev.conllx",
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-test.conllx",
                    "data/fastText_data/wiki.pt.bosquegsd.vec.new",
                ),
                'sv': (
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-train.conllx",
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-dev.conllx",
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-test.conllx",
                    "data/fastText_data/wiki.sv.talbanken.vec.new",
                ),
                'tr': (
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-train.conllx",
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-dev.conllx",
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-test.conllx",
                    "data/fastText_data/wiki.tr.imst.vec.new",
                ),
                'zh': (
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-train.conllx",
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-dev.conllx",
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.zh.gsd.vec.new",
                )}

        (self.args.ftrain, self.args.fdev,
         self.args.ftest, self.args.fembed) = args.ud_dataset[args.lang]

        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            logging.info("Preprocess the data")
            
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)

            tokenizer = BertTokenizer.from_pretrained(args.bert_model)
            self.BERT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                    tokenize=tokenizer.encode)

            if args.feat == 'char':
                self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
                                      fix_len=args.fix_len, tokenize=list)
            elif args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      tokenize=tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)
            if args.feat in ('char', 'bert'):
                self.fields = CoNLL(FORM=(self.WORD, self.BERT, self.FEAT),
                                    HEAD=self.HEAD, DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=(self.WORD, self.BERT), CPOS=self.FEAT,
                                    HEAD=self.HEAD, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields, args.max_len)
            if args.fembed:
                if args.bert is False:
                    # fasttext
                    embed = Embedding.load(args.fembed, args.lang, unk=args.unk)
                else:
                    embed = None
            else:
                embed = None
            
            self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.BERT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.BERT, self.FEAT = self.fields.FORM
            else:
                (self.WORD, self.BERT), self.FEAT = self.fields.FORM, self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL


        self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                    if ispunct(s)]).to(args.device)
        self.criterion = nn.CrossEntropyLoss()

        logging.info(f"{self.WORD}\n{self.FEAT}\n{self.BERT}\n{self.HEAD}\n{self.REL}")
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_bert': len(self.BERT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })
        logging.info(f"n_words {args.n_words} n_feats {args.n_feats} n_bert {args.n_bert} pad_index {args.pad_index} bos_index {args.bos_index}")

    def train(self, loader, self_train=None):
        self.model.train()

        cnt = 0
        for words, bert, feats, arcs, rels in loader:
            if self_train is not None:
                arcs = self_train[cnt]

            self.optimizer.zero_grad()
            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores = self.model(words, bert, feats)
            crf_weight = arc_scores
            arc_scores = self.model.decoder(arc_scores, feats)  # joint_weights
            if self.args.crf:
                if self.args.T_Reg:
                    source_score = self.model.T_Reg(words, bert, feats, self.args.source_model)
                    loss = self.model.crf(crf_weight, arc_scores + self.args.T_beta*source_score, arcs, words, feats)  # crf_weights, joint_weights, heads
                else:
                    loss = self.model.crf(crf_weight, arc_scores, arcs, words, feats)  # crf_weights, joint_weights, heads
            else:
                loss = self.get_loss(arc_scores, arcs, mask)
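            # optional extra penalties on top of the parsing loss, each gated
            # by its own flag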
            if self.args.W_Reg:
                mseloss = self.model.W_Reg()
                loss += mseloss
            if self.args.E_Reg:
                eloss = self.model.E_Reg(words, bert, feats, self.args.source_model, arc_scores)
                loss += eloss

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()
            cnt += 1

    @torch.no_grad()
    def evaluate(self, loader, self_train=None):
        self.model.eval()

        loss, metric = 0, Metric()
        
        cnt = 0
        for words, bert, feats, arcs, rels in loader:
            if self_train is not None:
                arcs = self_train[cnt]

            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores = self.model(words, bert, feats)
            crf_weight = arc_scores
            arc_scores = self.model.decoder(arc_scores, feats)  # joint_weights
            if self.args.crf:
                cur_loss = self.model.crf(crf_weight, arc_scores, arcs, words, feats)  # crf_weights, joint_weights, heads, words, pos
                if self.args.unsupervised:
                    arc_preds = self.model.decode_paskin(arc_scores)
                else:
                    arc_preds = self.model.decode_crf(arc_scores, mask)
                loss += cur_loss
            else:
                loss += self.get_loss(arc_scores, arcs, mask)
                arc_preds = self.model.decode(arc_scores, mask)

            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, arcs, mask)
            cnt += 1

        loss /= len(loader)

        return loss, metric


    @torch.no_grad()
    def get_preds(self, loader):
        self.model.eval()
        loss, metric = 0, Metric()
        arcs_preds = []
            
        for words, bert, feats, arcs, rels in loader:
            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores = self.model(words, bert, feats)
            crf_weight = arc_scores
            arc_scores = self.model.decoder(arc_scores, feats)  # joint_weights
            if self.args.crf:
                cur_loss = self.model.crf(crf_weight, arc_scores, arcs, words, feats)  # crf_weights, joint_weights, heads, words, pos
                if self.args.unsupervised:
                    arc_preds = self.model.decode_paskin(arc_scores)
                else:
                    arc_preds = self.model.decode_crf(arc_scores, mask)
                loss += cur_loss
            else:
                loss += self.get_loss(arc_scores, arcs, mask)
                arc_preds = self.model.decode(arc_scores, mask)
            # collect the predictions from either decoding branch
            arcs_preds.append(arc_preds)

        return arcs_preds

    def get_loss(self, arc_scores, arcs, mask):
        arc_scores, arcs = arc_scores[mask], arcs[mask]
        arc_loss = self.criterion(arc_scores, arcs)

        return arc_loss
Example #4
class CMD(object):

    def __call__(self, args):
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")

            self.CHAR = Field('chars', pad=pad, unk=unk,
                              bos=bos, eos=eos, lower=True)
                              
            # TODO span as label, modify chartfield to spanfield
            self.SEG = SegmentField('segs')

            if args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert',
                                      pad='[PAD]',
                                      bos='[CLS]',
                                      eos='[SEP]',
                                      tokenize=tokenizer.encode)
                self.fields = CoNLL(CHAR=(self.CHAR, self.FEAT),
                                    SEG=self.SEG)
            elif args.feat == 'bigram':
                self.BIGRAM = NGramField(
                    'bichar', n=2, pad=pad, unk=unk, bos=bos, eos=eos, lower=True)
                self.fields = CoNLL(CHAR=(self.CHAR, self.BIGRAM),
                                    SEG=self.SEG)
            elif args.feat == 'trigram':
                self.BIGRAM = NGramField(
                    'bichar', n=2, pad=pad, unk=unk, bos=bos, eos=eos, lower=True)
                self.TRIGRAM = NGramField(
                    'trichar', n=3, pad=pad, unk=unk, bos=bos, eos=eos, lower=True)
                self.fields = CoNLL(CHAR=(self.CHAR,
                                          self.BIGRAM,
                                          self.TRIGRAM),
                                    SEG=self.SEG)
            else:
                self.fields = CoNLL(CHAR=self.CHAR,
                                    SEG=self.SEG)

            train = Corpus.load(args.ftrain, self.fields)
            embed = Embedding.load(
                'data/tencent.char.200.txt',
                args.unk) if args.embed else None
            self.CHAR.build(train, args.min_freq, embed)
            if hasattr(self, 'FEAT'):
                self.FEAT.build(train)
            if hasattr(self, 'BIGRAM'):
                embed = Embedding.load(
                    'data/tencent.bi.200.txt',
                    args.unk) if args.embed else None
                self.BIGRAM.build(train, args.min_freq,
                                  embed=embed,
                                  dict_file=args.dict_file)
            if hasattr(self, 'TRIGRAM'):
                embed = Embedding.load(
                    'data/tencent.tri.200.txt',
                    args.unk) if args.embed else None
                self.TRIGRAM.build(train, args.min_freq,
                                   embed=embed,
                                   dict_file=args.dict_file)
            # TODO
            self.SEG.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat == 'bert':
                self.CHAR, self.FEAT = self.fields.CHAR
            elif args.feat == 'bigram':
                self.CHAR, self.BIGRAM = self.fields.CHAR
            elif args.feat == 'trigram':
                self.CHAR, self.BIGRAM, self.TRIGRAM = self.fields.CHAR
            else:
                self.CHAR = self.fields.CHAR
            # TODO
            self.SEG = self.fields.SEG
        # TODO loss function
        # self.criterion = nn.CrossEntropyLoss()
        # # [B, E, M, S]
        # self.trans = (torch.tensor([1., 0., 0., 1.]).log().to(args.device),
        #               torch.tensor([0., 1., 0., 1.]).log().to(args.device),
        #               torch.tensor([[0., 1., 1., 0.],
        #                             [1., 0., 0., 1.],
        #                             [0., 1., 1., 0.],
        #                             [1., 0., 0., 1.]]).log().to(args.device))

        args.update({
            'n_chars': self.CHAR.vocab.n_init,
            'pad_index': self.CHAR.pad_index,
            'unk_index': self.CHAR.unk_index
        })

        # TODO
        vocab = f"{self.CHAR}\n"
        if hasattr(self, 'FEAT'):
            args.update({
                'n_feats': self.FEAT.vocab.n_init,
            })
            vocab += f"{self.FEAT}\n"
        if hasattr(self, 'BIGRAM'):
            args.update({
                'n_bigrams': self.BIGRAM.vocab.n_init,
            })
            vocab += f"{self.BIGRAM}\n"
        if hasattr(self, 'TRIGRAM'):
            args.update({
                'n_trigrams': self.TRIGRAM.vocab.n_init,
            })
            vocab += f"{self.TRIGRAM}\n"

        print(f"Override the default configs\n{args}")
        print(vocab[:-1])
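
    # The bigram/trigram branches above assume that NGramField(n=k) maps
    # each character position i to the k-gram starting at i, tail-padded
    # with eos (e.g. n=2 over "abcd" -> ["ab", "bc", "cd", "d</s>"]); the
    # field's actual implementation is not shown in this example.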

    def train(self, loader):
        self.model.train()

        for data in loader:
            # TODO label
            if self.args.feat == 'bert':
                chars, feats, segs = data
                feed_dict = {"chars": chars, "feats": feats}
            elif self.args.feat == 'bigram':
                chars, bigram, segs = data
                feed_dict = {"chars": chars, "bigram": bigram}
            elif self.args.feat == 'trigram':
                chars, bigram, trigram, segs = data
                feed_dict = {"chars": chars,
                             "bigram": bigram, "trigram": trigram}
            else:
                chars, segs = data
                feed_dict = {"chars": chars}

            self.optimizer.zero_grad()

            batch_size, seq_len = chars.shape
            # fencepost lengths: (B)
            lens = chars.ne(self.args.pad_index).sum(1) - 1
            # mask out fencepost positions beyond each sentence's length
            # (B, 1, L-1)
            mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
            # keep only valid spans (i < j) within the sentence;
            # for example, seq_len=10, fenceposts=7, pad=2:
            # for each sentence we get an (L-1) x (L-1) matrix in which
            # spans with i >= j and padded positions are masked out
            # [[False,  True,  True,  True,  True,  True,  True, False, False],
            #  [False, False,  True,  True,  True,  True,  True, False, False],
            #  [False, False, False,  True,  True,  True,  True, False, False],
            #  [False, False, False, False,  True,  True,  True, False, False],
            #  [False, False, False, False, False,  True,  True, False, False],
            #  [False, False, False, False, False, False,  True, False, False],
            #  [False, False, False, False, False, False, False, False, False],
            #  [False, False, False, False, False, False, False, False, False],
            #  [False, False, False, False, False, False, False, False, False]]
            # (B, L-1, L-1)
            mask = mask & mask.new_ones(seq_len-1, seq_len-1).triu_(1)
            # (B, L-1, L-1), (B, L-1, 1)
            s_span, s_link = self.model(feed_dict, self.args.link)

            # with torch.autograd.set_detect_anomaly(True):
            loss = self.get_loss(s_span, segs, mask, s_link)
            loss.backward()

            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.clip)
            self.optimizer.step()
            self.scheduler.step()
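
    # Note: the LR scheduler is stepped once per batch alongside the
    # optimizer, so a per-step schedule (e.g. warmup/decay) is presumably
    # intended here rather than a per-epoch one.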

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()

        total_loss, metric = 0, SegF1Metric()

        for data in loader:
            if self.args.feat == 'bert':
                chars, feats, segs = data
                feed_dict = {"chars": chars, "feats": feats}
            elif self.args.feat == 'bigram':
                chars, bigram, segs = data
                feed_dict = {"chars": chars, "bigram": bigram}
            elif self.args.feat == 'trigram':
                chars, bigram, trigram, segs = data
                feed_dict = {"chars": chars,
                             "bigram": bigram, "trigram": trigram}
            else:
                chars, segs = data
                feed_dict = {"chars": chars}

            batch_size, seq_len = chars.shape
            lens = chars.ne(self.args.pad_index).sum(1) - 1
            mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
            mask = mask & mask.new_ones(seq_len-1, seq_len-1).triu_(1)

            s_span, s_link = self.model(feed_dict, self.args.link)

            loss = self.get_loss(s_span, segs, mask, s_link)

            pred_segs = self.decode(s_span, mask, s_link)
            # gold spans as a list of (i, j) index pairs per sentence
            # gold_segs = [torch.nonzero(gold).tolist() for gold in segs]
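            # tensor2scalar (defined elsewhere) presumably converts the
            # index tensors from nonzero(as_tuple=True) into Python ints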
            gold_segs = [list(zip(*tensor2scalar(torch.nonzero(gold, as_tuple=True))))
                         for gold in segs]

            total_loss += loss.item()
            metric(pred_segs, gold_segs)

        total_loss /= len(loader)

        # TODO metric
        return total_loss, metric

    @torch.no_grad()
    def predict(self, loader):
        self.model.eval()

        all_segs = []
        for data in loader:
            if self.args.feat == 'bert':
                chars, feats = data
                feed_dict = {"chars": chars, "feats": feats}
            elif self.args.feat == 'bigram':
                chars, bigram = data
                feed_dict = {"chars": chars, "bigram": bigram}
            elif self.args.feat == 'trigram':
                chars, bigram, trigram = data
                feed_dict = {"chars": chars,
                             "bigram": bigram, "trigram": trigram}
            else:
                chars = data
                feed_dict = {"chars": chars}
 
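            # note: unlike train()/evaluate(), only the flat character
            # mask (B, L) is built here; the span-level mask is presumably
            # handled inside directed_acyclic_graph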
            mask = chars.ne(self.args.pad_index)
            
            s_span, s_link = self.model(feed_dict, self.args.link)

            pred_segs = directed_acyclic_graph(s_span, mask, s_link)

            all_segs.extend(pred_segs)

        return all_segs

    def get_loss(self, s_span, segs, mask, s_link):
        """crf loss

        Args:
            scores (Tensor(B, N, N)): scores for candidate words (i, j)
            segs (Tensor(B, N, N)): groud truth words
            mask (Tensor(B, N, N)): actual 

        Returns:
            loss [type]: 
            span_probs (Tensor(B, N, N)): marginal probability for candidate words
        """

        # span_mask = spans & mask
        # span_loss, span_probs = crf(s_span, mask, spans, self.args.marg)

        loss = neg_log_likelihood(s_span, segs, mask, s_link)

        return loss
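
    # neg_log_likelihood is imported from elsewhere; given the docstring
    # above and the commented-out CRF pieces, it presumably computes
    # log Z(x) - score(gold segmentation) over the span chart, with
    # s_link scoring whether adjacent spans are connected.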

    def decode(self, s_span, mask, s_link):
        # decode the best segmentation from the span chart
        preds = directed_acyclic_graph(s_span, mask, s_link)

        return preds
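
The commented-out [B, E, M, S] transition tensors in __call__ hint at a CRF over a BMES tagging scheme. As a general illustration of the technique only (neg_log_likelihood in this project operates on span charts and is not shown here), a linear-chain CRF negative log-likelihood for one sentence can be computed with the forward algorithm:

import torch

def crf_nll(emit, tags, trans, strans, etrans):
    # emit: (L, T) emission scores; tags: (L,) gold tag ids
    # trans: (T, T) transitions; strans/etrans: (T,) start/end scores
    L, T = emit.shape
    # score of the gold path
    score = strans[tags[0]] + emit[0, tags[0]]
    for i in range(1, L):
        score = score + trans[tags[i - 1], tags[i]] + emit[i, tags[i]]
    score = score + etrans[tags[-1]]
    # log-partition Z(x) via the forward algorithm
    alpha = strans + emit[0]  # (T,)
    for i in range(1, L):
        alpha = torch.logsumexp(alpha.unsqueeze(1) + trans, dim=0) + emit[i]
    log_z = torch.logsumexp(alpha + etrans, dim=0)
    return log_z - score

# e.g. with the BMES transitions above, strans = log([1., 0., 0., 1.])
# allows a sentence to start only with B or S.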