예제 #1
0
    def __call__(self, config):
        """Train a BiaffineParser end to end and report dev/test scores.

        Loads the train/dev/test corpora, builds the vocabulary (or reloads
        it from ``config.vocab``), trains with Adam plus an ExponentialLR
        scheduler, and checkpoints the parser to ``config.model + ".<epoch>"``
        whenever the dev metric improves after the first
        ``config.patience`` epochs. Training stops early once
        ``config.patience`` epochs pass without a new checkpoint, and the
        best checkpoint is then evaluated on the test set.

        Args:
            config: run configuration. Read for file paths and
                hyper-parameters; updated in place with vocabulary sizes and
                special-token indices before the model is built.
        """
        print("Preprocess the data")
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)
        if os.path.exists(config.vocab):
            # NOTE(review): torch.load unpickles the file; only point
            # config.vocab at files you trust.
            vocab = torch.load(config.vocab)
        else:
            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
            vocab.read_embeddings(Embedding.load(config.fembed, config.unk))
            torch.save(vocab, config.vocab)
        config.update({
            'n_words': vocab.n_train_words,
            'n_tags': vocab.n_tags,
            'n_rels': vocab.n_rels,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        print(vocab)

        print("Load the dataset")
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
        # set the data loaders; only the training data is shuffled
        train_loader = batchify(dataset=trainset,
                                batch_size=config.batch_size,
                                n_buckets=config.buckets,
                                shuffle=True)
        dev_loader = batchify(dataset=devset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
        test_loader = batchify(dataset=testset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets)
        print(f"{'train:':6} {len(trainset):5} sentences in total, "
              f"{len(train_loader):3} batches provided")
        print(f"{'dev:':6} {len(devset):5} sentences in total, "
              f"{len(dev_loader):3} batches provided")
        print(f"{'test:':6} {len(testset):5} sentences in total, "
              f"{len(test_loader):3} batches provided")

        print("Create the model")
        parser = BiaffineParser(config, vocab.embeddings)
        if torch.cuda.is_available():
            parser = parser.cuda()
        print(f"{parser}\n")

        model = Model(vocab, parser)

        total_time = timedelta()
        best_e, best_metric = 1, Metric()
        # Whether any checkpoint has been written; the final reload needs one.
        saved = False
        # Epochs actually run (early stopping can end the loop early, and
        # config.epochs may be 0, in which case the loop never binds `epoch`).
        n_epochs = 0
        model.optimizer = Adam(model.parser.parameters(),
                               config.lr,
                               (config.beta_1, config.beta_2),
                               config.epsilon)
        # Per-step gamma chosen so that config.steps scheduler steps multiply
        # the learning rate by config.decay.
        model.scheduler = ExponentialLR(model.optimizer,
                                        config.decay ** (1 / config.steps))

        for epoch in range(1, config.epochs + 1):
            n_epochs = epoch
            start = datetime.now()
            # train one epoch and update the parameters
            model.train(train_loader)

            print(f"Epoch {epoch} / {config.epochs}:")
            loss, train_metric = model.evaluate(train_loader, config.punct)
            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
            loss, dev_metric = model.evaluate(dev_loader, config.punct)
            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = model.evaluate(test_loader, config.punct)
            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")

            t = datetime.now() - start
            # Save the model if it is the best so far; checkpoints are only
            # taken once the first config.patience epochs have passed.
            if dev_metric > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_metric
                model.parser.save(config.model + f".{best_e}")
                saved = True
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            # Early stopping: config.patience epochs without a new checkpoint.
            if epoch - best_e >= config.patience:
                break

        if saved:
            # NOTE(review): the freshly built parser was moved to CUDA above,
            # but the reloaded one is not — confirm BiaffineParser.load
            # handles device placement.
            model.parser = BiaffineParser.load(config.model + f".{best_e}")
            eval_e = best_e
        else:
            # No epoch qualified for a checkpoint (e.g. config.epochs <=
            # config.patience), so there is no file to reload; test the
            # parser as left by the last epoch instead of crashing on a
            # missing checkpoint.
            print("no checkpoint was saved, testing the last-epoch parser")
            eval_e = n_epochs
        loss, metric = model.evaluate(test_loader, config.punct)

        if saved:
            print(f"max score of dev is {best_metric.score:.2%} "
                  f"at epoch {best_e}")
        print(f"the score of test at epoch {eval_e} is {metric.score:.2%}")
        if n_epochs:
            print(f"average time of each epoch is {total_time / n_epochs}s")
        print(f"{total_time}s elapsed")
예제 #2
0
    def __call__(self, args):
        """Train a UCCA parser, then reload the best checkpoint and test it.

        Reads the configuration from ``args.config_path``. ``config.json``
        and ``vocab.pt`` are written to ``args.save_path``, which is also the
        path given to ``Trainer``. After training, ``parser.pt`` is reloaded
        from that directory and evaluated on ``args.test_id_path`` and/or
        ``args.test_ood_path`` when those are given.

        Args:
            args: command-line arguments (config, data, embedding and output
                paths).

        Raises:
            ValueError: if ``config.ucca.type`` is not a supported parser
                type.
        """
        config = get_config(args.config_path)
        # Explicit check rather than `assert`, which `python -O` strips.
        if config.ucca.type not in ["chart", "top-down", "global-chart"]:
            raise ValueError(
                "unsupported ucca.type %r; expected one of "
                "'chart', 'top-down', 'global-chart'" % (config.ucca.type,))

        # Everything below is written into save_path; create it if needed
        # instead of failing on the first open().
        os.makedirs(args.save_path, exist_ok=True)
        with open(os.path.join(args.save_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, default=lambda o: o.__dict__, indent=4)

        print("save all files to %s" % (args.save_path))
        # read the training and dev files
        print("loading datasets and transforming to trees...")
        train = Corpus(args.train_path)
        dev = Corpus(args.dev_path)
        print(train, "\n", dev)

        # init vocab
        print("collecting words and labels in training dataset...")
        vocab = Vocab(train)
        print(vocab)

        # prepare pre-trained embedding
        if args.emb_path:
            print("reading pre-trained embedding...")
            pre_emb = Embedding.load(args.emb_path)
            print(
                "pre-trained words:%d, dim=%d in %s"
                % (len(pre_emb), pre_emb.dim, args.emb_path)
            )
        else:
            pre_emb = None
        embedding = vocab.read_embedding(config.ucca.word_dim, pre_emb)
        vocab_path = os.path.join(args.save_path, "vocab.pt")
        torch.save(vocab, vocab_path)

        # init parser
        print("initializing model...")
        ucca_parser = UCCA_Parser(vocab, config.ucca, pre_emb=embedding)
        if torch.cuda.is_available():
            ucca_parser = ucca_parser.cuda()

        # prepare data
        print("preparing input data...")
        train_loader = Data.DataLoader(
            dataset=train.generate_inputs(vocab, True),
            batch_size=config.ucca.batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )
        dev_loader = Data.DataLoader(
            dataset=dev.generate_inputs(vocab, False),
            batch_size=10,
            shuffle=False,
            collate_fn=collate_fn,
        )

        optimizer = optim.Adam(ucca_parser.parameters(), lr=config.ucca.lr)
        ucca_evaluator = UCCA_Evaluator(
            parser=ucca_parser,
            gold_dic=args.dev_path,
        )

        trainer = Trainer(
            parser=ucca_parser,
            optimizer=optimizer,
            evaluator=ucca_evaluator,
            batch_size=config.ucca.batch_size,
            epoch=config.ucca.epoch,
            patience=config.ucca.patience,
            path=args.save_path,
        )
        trainer.train(train_loader, dev_loader)

        # Reload the best parser. Drop every local reference to the
        # training-time model first: the optimizer holds its parameters, and
        # the trainer and evaluator were handed the parser (and presumably
        # keep it). Deleting only `ucca_parser` would leave it alive, so
        # empty_cache() could not release its GPU memory.
        del ucca_parser, trainer, optimizer, ucca_evaluator
        torch.cuda.empty_cache()
        print("reloading the best parser for testing...")
        state_path = os.path.join(args.save_path, "parser.pt")
        config_path = os.path.join(args.save_path, "config.json")
        # NOTE(review): the reloaded parser is not explicitly moved to CUDA
        # here — confirm UCCA_Parser.load handles device placement.
        ucca_parser = UCCA_Parser.load(vocab_path, config_path, state_path)

        # Evaluate the in-domain and out-of-domain test sets, each optional.
        for test_path in (args.test_id_path, args.test_ood_path):
            if not test_path:
                continue
            print("evaluating test data : %s" % (test_path))
            test = Corpus(test_path)
            print(test)
            test_loader = Data.DataLoader(
                dataset=test.generate_inputs(vocab, False),
                batch_size=10,
                shuffle=False,
                collate_fn=collate_fn,
            )
            ucca_evaluator = UCCA_Evaluator(
                parser=ucca_parser,
                gold_dic=test_path,
            )
            try:
                ucca_evaluator.compute_accuracy(test_loader)
            finally:
                # Run the evaluator's cleanup (presumably removing its
                # temporary output) even if evaluation raises.
                ucca_evaluator.remove_temp()
예제 #3
0
파일: cmd.py 프로젝트: shtechair/ACE
    def __call__(self, args):
        """Build or load the CoNLL fields and record model sizes in ``args``.

        When ``args.fields`` does not exist (or ``args.preprocess`` is set),
        creates the word field plus the optional char (``args.use_char``),
        POS (``args.use_pos``) and BERT (``args.use_bert``) feature fields,
        builds their vocabularies from ``args.ftrain`` and saves them to
        ``args.fields``. Otherwise, loads the saved fields and recovers the
        same feature attributes from them. Also prepares the punctuation
        indices and the arc/relation losses.

        Args:
            args: run configuration; ``interpolation`` defaults to 0.5 when
                missing, and ``args`` is updated in place with vocabulary
                sizes and the word field's special-token indices.

        Raises:
            ValueError: if the saved fields do not match the current
                ``use_char`` / ``use_bert`` settings.
        """
        self.args = args
        if not hasattr(self.args, 'interpolation'):
            self.args.interpolation = 0.5
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)

            self.CHAR_FEAT = None
            self.POS_FEAT = None
            self.BERT_FEAT = None
            # All word-level inputs share the FORM column. Their order (WORD,
            # then CHAR, then BERT) is relied on when the fields are reloaded.
            self.FEAT = [self.WORD]
            if args.use_char:
                self.CHAR_FEAT = CharField('chars',
                                           pad=pad,
                                           unk=unk,
                                           bos=bos,
                                           fix_len=args.fix_len,
                                           tokenize=list)
                self.FEAT.append(self.CHAR_FEAT)
            if args.use_pos:
                self.POS_FEAT = Field('tags', bos=bos)
            if args.use_bert:
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.BERT_FEAT = BertField('bert',
                                           pad='[PAD]',
                                           bos='[CLS]',
                                           tokenize=tokenizer.encode)
                self.FEAT.append(self.BERT_FEAT)

            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)

            self.fields = CoNLL(FORM=self.FEAT,
                                CPOS=self.POS_FEAT,
                                HEAD=self.HEAD,
                                DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields)
            if args.fembed:
                embed = Embedding.load(args.fembed, args.unk)
            else:
                embed = None
            self.WORD.build(train, args.min_freq, embed)
            if args.use_char:
                self.CHAR_FEAT.build(train)
            if args.use_pos:
                self.POS_FEAT.build(train)
            if args.use_bert:
                self.BERT_FEAT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            # NOTE(review): torch.load unpickles args.fields; only load
            # trusted files.
            self.fields = torch.load(args.fields)
            # Recover the feature fields positionally, mirroring how FORM is
            # assembled above: WORD, then CHAR (if use_char), then BERT (if
            # use_bert). Assumes CoNLL keeps FORM as the list/tuple it was
            # given, and that the file was built with the same
            # use_char / use_bert settings as this run.
            form = self.fields.FORM
            form = list(form) if isinstance(form, (list, tuple)) else [form]
            expected = 1 + bool(args.use_char) + bool(args.use_bert)
            if len(form) != expected:
                raise ValueError(
                    f"{args.fields} stores {len(form)} FORM field(s), but "
                    f"use_char={args.use_char} and use_bert={args.use_bert} "
                    f"expect {expected}; rebuild it with preprocessing enabled")
            self.FEAT = form
            self.WORD = form[0]
            extra = iter(form[1:])
            self.CHAR_FEAT = next(extra) if args.use_char else None
            self.BERT_FEAT = next(extra) if args.use_bert else None
            self.POS_FEAT = self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = torch.tensor([
            i for s, i in self.WORD.vocab.stoi.items() if ispunct(s)
        ]).to(args.device)
        self.rel_criterion = nn.CrossEntropyLoss()
        # Binary arc scoring uses an unreduced BCE loss (reduction='none').
        self.arc_criterion = (nn.BCEWithLogitsLoss(reduction='none')
                              if args.binary else nn.CrossEntropyLoss())

        print(f"{self.WORD}\n{self.HEAD}\n{self.REL}")
        update_info = {}
        if args.use_char:
            update_info['n_char_feats'] = len(self.CHAR_FEAT.vocab)
        if args.use_pos:
            update_info['n_pos_feats'] = len(self.POS_FEAT.vocab)
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })
        args.update(update_info)
예제 #4
0
    def __call__(self, args):
        """Build or load the CoNLL fields and record feature sizes in ``args``.

        When ``args.fields`` does not exist (or ``args.preprocess`` is set),
        creates the word / feature / arc / relation fields for the chosen
        ``args.feat`` ('char', 'bert', or POS tags otherwise). It then builds
        their vocabularies from ``args.ftrain``, saves them to
        ``args.fields``, and keeps the loaded training corpus on
        ``self.trainset``. Otherwise, loads the saved fields, sets
        ``self.trainset`` to None, and re-attaches the BERT tokenizer's
        ``tokenize`` when ``args.feat == 'bert'``.

        Args:
            args: run configuration; updated in place with vocabulary sizes
                and special-token indices.
        """
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            logger.info("Preprocess the data")
            # BOS for the tag/arc/rel fields. Only the BERT branch assigns it
            # (to the tokenizer's own BOS); the char and tag branches read
            # self.bos as-is. Ensure it exists, defaulting to the
            # module-level `bos` unless the instance or class already has one.
            self.bos = getattr(self, 'bos', bos)
            self.WORD = Field('words',
                              pad=pad,
                              unk=unk,
                              bos=bos,
                              lower=args.lower)
            if args.feat == 'char':
                self.FEAT = SubwordField('chars',
                                         pad=pad,
                                         unk=unk,
                                         bos=bos,
                                         fix_len=args.fix_len,
                                         tokenize=list)
            elif args.feat == 'bert':
                tokenizer = SubwordField.tokenizer(args.bert_model)
                self.FEAT = SubwordField('bert',
                                         tokenizer=tokenizer,
                                         fix_len=args.fix_len)
                self.bos = self.FEAT.bos or bos
                if hasattr(tokenizer, 'vocab'):
                    self.FEAT.vocab = tokenizer.vocab
                else:
                    # Tokenizer without a `vocab` attribute: build the
                    # token -> id mapping by enumerating all ids.
                    self.FEAT.vocab = FieldVocab(
                        tokenizer.unk_token_id, {
                            tokenizer._convert_id_to_token(i): i
                            for i in range(len(tokenizer))
                        })
            else:
                self.FEAT = Field('tags', bos=self.bos)
            self.ARC = Field('arcs',
                             bos=self.bos,
                             use_vocab=False,
                             fn=numericalize)
            self.REL = Field('rels', bos=self.bos)
            if args.feat == 'bert':
                if args.n_embed:
                    self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                        HEAD=self.ARC,
                                        DEPREL=self.REL)
                    self.WORD.bos = self.bos  # ensure representations of the same length
                else:
                    self.fields = CoNLL(FORM=self.FEAT,
                                        HEAD=self.ARC,
                                        DEPREL=self.REL)
                    self.WORD = None
            elif args.feat == 'char':
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields, args.max_sent_length)
            if args.fembed:
                embed = Embedding.load(args.fembed, args.unk)
            else:
                embed = None
            if self.WORD:
                self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            if args.feat == 'bert':
                # Do not save the tokenize function, or else it might be
                # incompatible with new releases. Detach it for the save and
                # always restore it afterwards, even if saving fails.
                tokenize = self.FEAT.tokenize
                self.FEAT.tokenize = None
                try:
                    torch.save(self.fields, args.fields)
                finally:
                    self.FEAT.tokenize = tokenize
            else:
                torch.save(self.fields, args.fields)
            self.trainset = train  # pass it on to subclasses
        else:
            self.trainset = None
            # NOTE(review): torch.load unpickles args.fields; only load
            # trusted files.
            self.fields = torch.load(args.fields)
            if args.feat == 'bert':
                # The tokenize function was not saved; re-attach it.
                tokenizer = SubwordField.tokenizer(args.bert_model)
                if args.n_embed:
                    self.fields.FORM[1].tokenize = tokenizer.tokenize
                else:
                    self.fields.FORM.tokenize = tokenizer.tokenize
            if args.feat in ('char', 'bert'):
                if isinstance(self.fields.FORM, tuple):
                    self.WORD, self.FEAT = self.fields.FORM
                else:
                    self.WORD, self.FEAT = None, self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = torch.tensor(
            [i for s, i in self.WORD.vocab.stoi.items()
             if ispunct(s)]).to(args.device) if self.WORD else []

        # override parameters from embeddings:
        if self.WORD:
            args.update({
                'n_words': self.WORD.vocab.n_init,
                'pad_index': self.WORD.pad_index,
                'unk_index': self.WORD.unk_index,
                'bos_index': self.WORD.bos_index,
            })
        args.update({
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'feat_pad_index': self.FEAT.pad_index,
        })

        logger.info("Features:")
        if self.WORD:
            logger.info(f"   {self.WORD}")
        logger.info(f"   {self.FEAT}\n   {self.ARC}\n   {self.REL}")
예제 #5
0
    def __call__(self, args):
        """Set up fields, punctuation ids and the loss, then update ``args``.

        The fields are the words, one feature field selected by ``args.feat``
        ('char', 'bert', or POS tags otherwise), the heads and the relations.
        They are built from ``args.ftrain`` and saved to ``args.fields`` when
        that file is missing or ``args.preprocess`` is set. Otherwise they
        are loaded from that file.

        Args:
            args: run configuration; receives the vocabulary sizes and the
                word field's pad/unk/bos indices.
        """
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)

        if os.path.exists(args.fields) and not args.preprocess:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD = self.fields.FORM
                self.FEAT = self.fields.CPOS
            self.HEAD = self.fields.HEAD
            self.REL = self.fields.DEPREL
        else:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
            feat_kind = args.feat
            if feat_kind == 'char':
                self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
                                      fix_len=args.fix_len, tokenize=list)
            elif feat_kind == 'bert':
                bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      tokenize=bert_tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)

            # char/bert features share the FORM column with the words; POS
            # tags go into the CPOS column instead.
            if feat_kind in ('char', 'bert'):
                columns = {'FORM': (self.WORD, self.FEAT)}
            else:
                columns = {'FORM': self.WORD, 'CPOS': self.FEAT}
            self.fields = CoNLL(**columns, HEAD=self.HEAD, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields)
            embed = Embedding.load(args.fembed, args.unk) if args.fembed else None
            self.WORD.build(train, args.min_freq, embed)
            for fld in (self.FEAT, self.REL):
                fld.build(train)
            torch.save(self.fields, args.fields)

        punct_ids = [idx for tok, idx in self.WORD.vocab.stoi.items()
                     if ispunct(tok)]
        self.puncts = torch.tensor(punct_ids).to(args.device)
        self.criterion = nn.CrossEntropyLoss()

        print("\n".join(f"{fld}" for fld in
                        (self.WORD, self.FEAT, self.HEAD, self.REL)))
        word_field = self.WORD
        args.update({
            'n_words': word_field.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': word_field.pad_index,
            'unk_index': word_field.unk_index,
            'bos_index': word_field.bos_index
        })
예제 #6
0
    def __call__(self, args):
        self.args = args
        logging.basicConfig(filename=args.output, filemode='w', format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
        
        args.ud_dataset = {
                'en': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en20': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train20.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en40': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train40.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en60': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train60.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'en80': (
                    'data/ud/UD_English-EWT/en_ewt-ud-train80.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-dev.conllx',
                    'data/ud/UD_English-EWT/en_ewt-ud-test.conllx',
                    "data/fastText_data/wiki.en.ewt.vec.new",
                ),
                'ar': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar20': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train20.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar40': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train40.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar60': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train60.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'ar80': (
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-train80.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-dev.conllx",
                    "data/ud/UD_Arabic-PADT/ar_padt-ud-test.conllx",
                    "data/fastText_data/wiki.ar.padt.vec.new",
                ),
                'bg': (
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-train.conllx",
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-dev.conllx",
                    "data/ud/UD_Bulgarian-BTB/bg_btb-ud-test.conllx",
                    "data/fastText_data/wiki.bg.btb.vec.new",
                ),
                'da': (
                    "data/ud/UD_Danish-DDT/da_ddt-ud-train.conllx",
                    "data/ud/UD_Danish-DDT/da_ddt-ud-dev.conllx",
                    "data/ud/UD_Danish-DDT/da_ddt-ud-test.conllx",
                    "data/fastText_data/wiki.da.ddt.vec.new",
                ),
                'de': (
                    "data/ud/UD_German-GSD/de_gsd-ud-train.conllx",
                    "data/ud/UD_German-GSD/de_gsd-ud-dev.conllx",
                    "data/ud/UD_German-GSD/de_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.de.gsd.vec.new",
                ),
                'es': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es20': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train20.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es40': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train40.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es60': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train60.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'es80': (
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-train80.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-dev.conllx",
                    "data/ud/UD_Spanish-GSDAnCora/es_gsdancora-ud-test.conllx",
                    "data/fastText_data/wiki.es.gsdancora.vec.new",
                ),
                'fa': (
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-train.conllx",
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-dev.conllx",
                    "data/ud/UD_Persian-Seraji/fa_seraji-ud-test.conllx",
                    "data/fastText_data/wiki.fa.seraji.vec.new",
                ),
                'fr': (
                    "data/ud/UD_French-GSD/fr_gsd-ud-train.conllx",
                    "data/ud/UD_French-GSD/fr_gsd-ud-dev.conllx",
                    "data/ud/UD_French-GSD/fr_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.fr.gsd.vec.new",
                ),
                'he': (
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-train.conllx",
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-dev.conllx",
                    "data/ud/UD_Hebrew-HTB/he_htb-ud-test.conllx",
                    "data/fastText_data/wiki.he.htb.vec.new",
                ),
                'hi': (
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-train.conllx",
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-dev.conllx",
                    "data/ud/UD_Hindi-HDTB/hi_hdtb-ud-test.conllx",
                    "data/fastText_data/wiki.hi.hdtb.vec.new",
                ),
                'hr': (
                    "data/ud/UD_Croatian-SET/hr_set-ud-train.conllx",
                    "data/ud/UD_Croatian-SET/hr_set-ud-dev.conllx",
                    "data/ud/UD_Croatian-SET/hr_set-ud-test.conllx",
                    "data/fastText_data/wiki.hr.set.vec.new",
                ),
                'id': (
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-train.conllx",
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-dev.conllx",
                    "data/ud/UD_Indonesian-GSD/id_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.id.gsd.vec.new",
                ),
                'it': (
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-train.conllx",
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-dev.conllx",
                    "data/ud/UD_Italian-ISDT/it_isdt-ud-test.conllx",
                    "data/fastText_data/wiki.it.isdt.vec.new",
                ),
                'ja': (
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-train.conllx",
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-dev.conllx",
                    "data/ud/UD_Japanese-GSD/ja_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.ja.gsd.vec.new",
                ),
                'ko': (
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-train.conllx",
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-dev.conllx",
                    "data/ud/UD_Korean-GSDKaist/ko_gsdkaist-ud-test.conllx",
                    "data/fastText_data/wiki.ko.gsdkaist.vec.new",
                ),
                'nl': (
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-train.conllx",
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-dev.conllx",
                    "data/ud/UD_Dutch-AlpinoLassySmall/nl_alpinolassysmall-ud-test.conllx",
                    "data/fastText_data/wiki.nl.alpinolassysmall.vec.new",
                ),
                'no': (
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-train.conllx",
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-dev.conllx",
                    "data/ud/UD_Norwegian-BokmaalNynorsk/no_bokmaalnynorsk-ud-test.conllx",
                    "data/fastText_data/wiki.no.bokmaalnynorsk.vec.new",
                ),
                'pt': (
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-train.conllx",
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-dev.conllx",
                    "data/ud/UD_Portuguese-BosqueGSD/pt_bosquegsd-ud-test.conllx",
                    "data/fastText_data/wiki.pt.bosquegsd.vec.new",
                ),
                'sv': (
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-train.conllx",
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-dev.conllx",
                    "data/ud/UD_Swedish-Talbanken/sv_talbanken-ud-test.conllx",
                    "data/fastText_data/wiki.sv.talbanken.vec.new",
                ),
                'tr': (
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-train.conllx",
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-dev.conllx",
                    "data/ud/UD_Turkish-IMST/tr_imst-ud-test.conllx",
                    "data/fastText_data/wiki.tr.imst.vec.new",
                ),
                'zh': (
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-train.conllx",
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-dev.conllx",
                    "data/ud/UD_Chinese-GSD/zh_gsd-ud-test.conllx",
                    "data/fastText_data/wiki.zh.gsd.vec.new",
                )}

        self.args.ftrain = args.ud_dataset[args.lang][0]
        self.args.fdev = args.ud_dataset[args.lang][1]
        self.args.ftest = args.ud_dataset[args.lang][2]
        self.args.fembed = args.ud_dataset[args.lang][3]

        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            logging.info("Preprocess the data")
            
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)

            tokenizer = BertTokenizer.from_pretrained(args.bert_model)
            self.BERT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                    tokenize=tokenizer.encode)

            if args.feat == 'char':
                self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
                                      fix_len=args.fix_len, tokenize=list)
            elif args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      tokenize=tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
            self.HEAD = Field('heads', bos=bos, use_vocab=False, fn=int)
            self.REL = Field('rels', bos=bos)
            if args.feat in ('char', 'bert'):
                self.fields = CoNLL(FORM=(self.WORD, self.BERT, self.FEAT),
                                    HEAD=self.HEAD, DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=(self.WORD, self.BERT), CPOS=self.FEAT,
                                    HEAD=self.HEAD, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields, args.max_len)
            if args.fembed:
                if args.bert is False:
                    # fasttext
                    embed = Embedding.load(args.fembed, args.lang, unk=args.unk)
                else:
                    embed = None
            else:
                embed = None
            
            self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.BERT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.BERT, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.BERT, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.HEAD, self.REL = self.fields.HEAD, self.fields.DEPREL


        self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                    if ispunct(s)]).to(args.device)
        self.criterion = nn.CrossEntropyLoss()

        logging.info(f"{self.WORD}\n{self.FEAT}\n{self.BERT}\n{self.HEAD}\n{self.REL}")
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_bert': len(self.BERT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })
        logging.info(f"n_words {args.n_words} n_feats {args.n_feats} n_bert {args.n_bert} pad_index {args.pad_index} bos_index {args.bos_index}")
예제 #7
0
파일: cmd.py 프로젝트: ironsword666/CWS
    def __call__(self, args):
        """Prepare the word-segmentation fields and publish vocabulary sizes.

        If ``args.fields`` does not exist yet, or ``args.preprocess`` is set,
        the fields are created according to ``args.feat`` ('bert', 'bigram',
        'trigram', or anything else for plain characters). Their
        vocabularies are built from the training corpus ``args.ftrain``, and
        the resulting ``CoNLL`` object is saved to ``args.fields``.
        Otherwise the saved object is loaded and its fields are unpacked
        back onto ``self``.

        In both cases ``args`` is then updated with ``n_chars``,
        ``pad_index`` and ``unk_index``. It also gets ``n_feats``,
        ``n_bigrams`` and ``n_trigrams`` for whichever optional fields are
        present. Finally, the configuration and field summaries are printed.

        Args:
            args: run configuration, read via attribute access and updated in
                place through ``args.update``.
        """
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)

        # Reuse the cached fields unless they are missing or a rebuild is forced.
        if os.path.exists(args.fields) and not args.preprocess:
            self.fields = torch.load(args.fields)
            if args.feat == 'bert':
                self.CHAR, self.FEAT = self.fields.CHAR
            elif args.feat == 'bigram':
                self.CHAR, self.BIGRAM = self.fields.CHAR
            elif args.feat == 'trigram':
                self.CHAR, self.BIGRAM, self.TRIGRAM = self.fields.CHAR
            else:
                self.CHAR = self.fields.CHAR
            # TODO: revisit once SEG becomes a span field.
            self.SEG = self.fields.SEG
        else:
            print("Preprocess the data")
            self.CHAR = Field('chars', pad=pad, unk=unk, bos=bos, eos=eos,
                              lower=True)
            # TODO: use spans as labels (turn the chart field into a span field).
            self.SEG = SegmentField('segs')

            # The CHAR column holds the character field plus whatever extra
            # feature fields args.feat selects.
            if args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      eos='[SEP]', tokenize=tokenizer.encode)
                char_column = (self.CHAR, self.FEAT)
            elif args.feat == 'bigram':
                self.BIGRAM = NGramField('bichar', n=2, pad=pad, unk=unk,
                                         bos=bos, eos=eos, lower=True)
                char_column = (self.CHAR, self.BIGRAM)
            elif args.feat == 'trigram':
                self.BIGRAM = NGramField('bichar', n=2, pad=pad, unk=unk,
                                         bos=bos, eos=eos, lower=True)
                self.TRIGRAM = NGramField('trichar', n=3, pad=pad, unk=unk,
                                          bos=bos, eos=eos, lower=True)
                char_column = (self.CHAR, self.BIGRAM, self.TRIGRAM)
            else:
                char_column = self.CHAR
            self.fields = CoNLL(CHAR=char_column, SEG=self.SEG)

            train = Corpus.load(args.ftrain, self.fields)
            char_embed = (Embedding.load('data/tencent.char.200.txt', args.unk)
                          if args.embed else None)
            self.CHAR.build(train, args.min_freq, char_embed)
            if hasattr(self, 'FEAT'):
                self.FEAT.build(train)
            # Each n-gram field gets its own pretrained embedding file plus
            # the dictionary file from args.dict_file.
            for attr, embed_path in (('BIGRAM', 'data/tencent.bi.200.txt'),
                                     ('TRIGRAM', 'data/tencent.tri.200.txt')):
                if not hasattr(self, attr):
                    continue
                ngram_embed = (Embedding.load(embed_path, args.unk)
                               if args.embed else None)
                getattr(self, attr).build(train, args.min_freq,
                                          embed=ngram_embed,
                                          dict_file=args.dict_file)
            # TODO: revisit once SEG becomes a span field.
            self.SEG.build(train)
            torch.save(self.fields, args.fields)

        # TODO: loss function (kept below as a sketch of the planned
        # cross-entropy loss with B/E/M/S transition constraints).
        # self.criterion = nn.CrossEntropyLoss()
        # # [B, E, M, S]
        # self.trans = (torch.tensor([1., 0., 0., 1.]).log().to(args.device),
        #               torch.tensor([0., 1., 0., 1.]).log().to(args.device),
        #               torch.tensor([[0., 1., 1., 0.],
        #                             [1., 0., 0., 1.],
        #                             [0., 1., 1., 0.],
        #                             [1., 0., 0., 1.]]).log().to(args.device))

        args.update({
            'n_chars': self.CHAR.vocab.n_init,
            'pad_index': self.CHAR.pad_index,
            'unk_index': self.CHAR.unk_index
        })

        # Optional fields, paired with the size key they publish on args.
        summaries = [f"{self.CHAR}"]
        for attr, size_key in (('FEAT', 'n_feats'),
                               ('BIGRAM', 'n_bigrams'),
                               ('TRIGRAM', 'n_trigrams')):
            if hasattr(self, attr):
                field = getattr(self, attr)
                args.update({size_key: field.vocab.n_init})
                summaries.append(f"{field}")

        print(f"Override the default configs\n{args}")
        print('\n'.join(summaries))
예제 #8
0
    def __call__(self, config):
        """Train a BiaffineParser on paired POS-tagging and dependency corpora.

        Preprocesses the POS and dependency train/dev/test corpora, or reloads
        the cached vocab and datasets from ``config.file``, and builds the
        data loaders. It then trains for up to ``config.epochs`` epochs with
        early stopping on the dev metric, saving the best parser to
        ``config.model``. Finally it reloads that checkpoint and reports its
        test score.

        Args:
            config: run configuration; read via attribute access and updated
                in place with vocabulary sizes through ``config.update``.
        """
        if not os.path.exists(config.file):
            os.mkdir(config.file)
        # Rebuild everything when forced or when no cached vocab exists yet.
        if config.preprocess or not os.path.exists(config.vocab):
            print("Preprocess the corpus")
            # POS corpora are read with column indices [1, 4] only (presumably
            # the word form and POS tag columns -- confirm in Corpus.load).
            # NOTE(review): only the POS training corpus receives config.pos;
            # its meaning is not visible here -- confirm.
            pos_train = Corpus.load(config.fptrain, [1, 4], config.pos)
            dep_train = Corpus.load(config.ftrain)
            pos_dev = Corpus.load(config.fpdev, [1, 4])
            dep_dev = Corpus.load(config.fdev)
            pos_test = Corpus.load(config.fptest, [1, 4])
            dep_test = Corpus.load(config.ftest)
            print("Create the vocab")
            # The vocab is built from both training corpora; the literal 2 is
            # presumably a minimum frequency -- confirm in Vocab.from_corpora.
            vocab = Vocab.from_corpora(pos_train, dep_train, 2)
            vocab.read_embeddings(Embedding.load(config.fembed))
            print("Load the dataset")
            # The False flag is passed only when numericalizing POS corpora
            # (presumably it skips dependency annotations -- confirm).
            pos_trainset = TextDataset(vocab.numericalize(pos_train, False),
                                       config.buckets)
            dep_trainset = TextDataset(vocab.numericalize(dep_train),
                                       config.buckets)
            pos_devset = TextDataset(vocab.numericalize(pos_dev, False),
                                     config.buckets)
            dep_devset = TextDataset(vocab.numericalize(dep_dev),
                                     config.buckets)
            pos_testset = TextDataset(vocab.numericalize(pos_test, False),
                                      config.buckets)
            dep_testset = TextDataset(vocab.numericalize(dep_test),
                                      config.buckets)
            # Cache the vocab and all six datasets so later runs can skip
            # preprocessing (the else branch below reloads them).
            torch.save(vocab, config.vocab)
            torch.save(pos_trainset, os.path.join(config.file, 'pos_trainset'))
            torch.save(dep_trainset, os.path.join(config.file, 'dep_trainset'))
            torch.save(pos_devset, os.path.join(config.file, 'pos_devset'))
            torch.save(dep_devset, os.path.join(config.file, 'dep_devset'))
            torch.save(pos_testset, os.path.join(config.file, 'pos_testset'))
            torch.save(dep_testset, os.path.join(config.file, 'dep_testset'))
        else:
            print("Load the vocab")
            vocab = torch.load(config.vocab)
            print("Load the datasets")
            pos_trainset = torch.load(os.path.join(config.file,
                                                   'pos_trainset'))
            dep_trainset = torch.load(os.path.join(config.file,
                                                   'dep_trainset'))
            pos_devset = torch.load(os.path.join(config.file, 'pos_devset'))
            dep_devset = torch.load(os.path.join(config.file, 'dep_devset'))
            pos_testset = torch.load(os.path.join(config.file, 'pos_testset'))
            dep_testset = torch.load(os.path.join(config.file, 'dep_testset'))
        # Publish vocabulary sizes / special indices for model construction.
        config.update({
            'n_words': vocab.n_init,
            'n_chars': vocab.n_chars,
            'n_pos_tags': vocab.n_pos_tags,
            'n_dep_tags': vocab.n_dep_tags,
            'n_rels': vocab.n_rels,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        # set the data loaders
        # Training batch sizes are divided by config.update_steps, presumably
        # so gradients are accumulated over update_steps mini-batches per
        # optimizer step -- confirm in Model.train. The trailing True is
        # passed only for training loaders (presumably shuffle -- confirm).
        pos_train_loader = batchify(
            pos_trainset, config.pos_batch_size // config.update_steps, True)
        dep_train_loader = batchify(dep_trainset,
                                    config.batch_size // config.update_steps,
                                    True)
        pos_dev_loader = batchify(pos_devset, config.pos_batch_size)
        dep_dev_loader = batchify(dep_devset, config.batch_size)
        pos_test_loader = batchify(pos_testset, config.pos_batch_size)
        dep_test_loader = batchify(dep_testset, config.batch_size)

        print(vocab)
        print(f"{'pos_train:':10} {len(pos_trainset):7} sentences in total, "
              f"{len(pos_train_loader):4} batches provided")
        print(f"{'dep_train:':10} {len(dep_trainset):7} sentences in total, "
              f"{len(dep_train_loader):4} batches provided")
        print(f"{'pos_dev:':10} {len(pos_devset):7} sentences in total, "
              f"{len(pos_dev_loader):4} batches provided")
        print(f"{'dep_dev:':10} {len(dep_devset):7} sentences in total, "
              f"{len(dep_dev_loader):4} batches provided")
        print(f"{'pos_test:':10} {len(pos_testset):7} sentences in total, "
              f"{len(pos_test_loader):4} batches provided")
        print(f"{'dep_test:':10} {len(dep_testset):7} sentences in total, "
              f"{len(dep_test_loader):4} batches provided")

        print("Create the model")
        parser = BiaffineParser(config, vocab.embed).to(config.device)
        print(f"{parser}\n")

        model = Model(config, vocab, parser)

        total_time = timedelta()
        # best_metric starts as a fresh metric object; a dev metric must
        # compare greater than it to be checkpointed.
        best_e, best_metric = 1, AttachmentMethod()
        model.optimizer = Adam(model.parser.parameters(), config.lr,
                               (config.mu, config.nu), config.epsilon)
        # gamma = decay ** (1 / decay_steps): after decay_steps scheduler
        # steps the learning rate has been multiplied by config.decay.
        model.scheduler = ExponentialLR(model.optimizer,
                                        config.decay**(1 / config.decay_steps))

        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            model.train(pos_train_loader, dep_train_loader)
            print(f"Epoch {epoch} / {config.epochs}:")
            # evaluate returns five values: two "l*" values printed as LP/LD
            # and three metrics; only the fifth (dependency) metric drives
            # model selection below. The training-set evaluation passes None
            # for the POS loader.
            lp, ld, mp, mdt, mdp = model.evaluate(None, dep_train_loader)
            print(f"{'train:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {mdp}")
            lp, ld, mp, mdt, dev_m = model.evaluate(pos_dev_loader,
                                                    dep_dev_loader)
            print(f"{'dev:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {dev_m}")
            lp, ld, mp, mdt, mdp = model.evaluate(pos_test_loader,
                                                  dep_test_loader)
            print(f"{'test:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {mdp}")

            t = datetime.now() - start
            # save the model if it is the best so far
            # Checkpoints are only written once epoch > config.patience, so
            # the first `patience` epochs are never saved.
            if dev_m > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_m
                model.parser.save(config.model)
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            # Early stopping: quit after `patience` epochs without a new best.
            if epoch - best_e >= config.patience:
                break
        # Reload the best checkpoint and re-evaluate it on the test set.
        # NOTE(review): if no checkpoint was written in this run (e.g.
        # config.epochs <= config.patience), this loads a stale file at
        # config.model or fails; if config.epochs < 1, `epoch` below is
        # undefined -- confirm these cases cannot occur.
        model.parser = BiaffineParser.load(config.model)
        lp, ld, mp, mdt, mdp = model.evaluate(pos_test_loader, dep_test_loader)

        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {mdp.score:.2%}")
        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")