def train(self,
              base_path: Union[Path, str],
              fix_len=20,
              min_freq=2,
              buckets=32,
              batch_size=5000,
              lr=2e-3,
              mu=.9,
              nu=.9,
              epsilon=1e-12,
              clip=5.0,
              decay=.75,
              decay_steps=5000,
              patience=100,
              max_epochs=10):
        r"""
        Build a fresh ``DependencyParser`` from ``self.corpus`` and train it.

        Vocabularies (words, features, relations) are built from the training split.
        A new parser is created with them and reuses the embedding modules of
        ``self.parser``. It is then trained with Adam and an exponentially decaying
        learning rate.

        After every epoch the parser is evaluated on the dev and test splits. Whenever
        the dev metric improves, the master process saves the model to ``base_path``.
        Training stops after ``max_epochs`` epochs, or earlier when the dev metric has
        not improved for ``patience`` epochs. The best saved model is then reloaded and
        evaluated on the test split.

        Reads ``self.parser.feat``, ``self.parser.embed``, ``self.parser.embeddings``,
        ``self.parser.args['punct']`` and ``self.corpus.train/dev/test``.

        Args:
            base_path (Union[Path, str]): Path the best model is saved to. Its parent
                directory (if any) is created when missing.
            fix_len (int): ``fix_len`` passed to the subword feature field when the
                feature type is ``'char'`` or ``'bert'``.
            min_freq (int): Frequency threshold passed to ``WORD.build``.
            buckets (int): Number of buckets passed to ``Dataset.build``.
            batch_size (int): Batch size passed to ``Dataset.build``. It is divided by
                the world size when ``torch.distributed`` is initialized.
            lr (float): Initial learning rate for Adam.
            mu (float): Adam ``beta1``.
            nu (float): Adam ``beta2``.
            epsilon (float): Adam ``eps``.
            clip (float): Max gradient norm used for gradient clipping.
            decay (float): Factor the learning rate is multiplied by every
                ``decay_steps`` optimizer steps.
            decay_steps (int): Number of optimizer steps per ``decay`` factor.
            patience (int): Stop early once the dev metric has not improved for this
                many epochs.
            max_epochs (int): Maximum number of epochs to train; must be >= 1.

        Raises:
            ValueError: If ``max_epochs`` is smaller than 1.
        """
        if max_epochs < 1:
            # The summary at the end divides by the number of epochs run and reloads
            # a checkpoint that only exists after at least one epoch.
            raise ValueError(f'max_epochs must be at least 1, got {max_epochs}')

        ################################################################################################################
        # BUILD
        ################################################################################################################
        args = Config()
        args.feat = self.parser.feat
        args.embed = self.parser.embed
        # os.makedirs('') raises, so only create a parent dir when base_path has one.
        base_dir = os.path.dirname(base_path)
        if base_dir:
            os.makedirs(base_dir, exist_ok=True)
        logger.info("Building the fields")
        WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
        if args.feat == 'char':
            FEAT = SubwordField('chars',
                                pad=pad,
                                unk=unk,
                                bos=bos,
                                fix_len=fix_len)
        elif args.feat == 'bert':
            from transformers import AutoTokenizer
            # NOTE(review): args.bert / args.max_len are never set in this method;
            # confirm Config supplies them. Newer transformers releases expose
            # `model_max_length` instead of `max_len` — verify against the pinned version.
            tokenizer = AutoTokenizer.from_pretrained(args.bert)
            args.max_len = min(args.max_len or tokenizer.max_len,
                               tokenizer.max_len)
            FEAT = SubwordField('bert',
                                pad=tokenizer.pad_token,
                                unk=tokenizer.unk_token,
                                bos=tokenizer.bos_token or tokenizer.cls_token,
                                fix_len=fix_len,
                                tokenize=tokenizer.tokenize)
            FEAT.vocab = tokenizer.get_vocab()
        else:
            # Any other feature type uses the tag column (mapped to CPOS below).
            FEAT = Field('tags', bos=bos)

        ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
        REL = Field('rels', bos=bos)
        if args.feat in ('char', 'bert'):
            transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
        else:
            transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

        train = Dataset(transform, self.corpus.train)
        WORD.build(
            train, min_freq,
            (Embedding.load(args.embed, unk) if self.parser.embed else None))
        FEAT.build(train)
        REL.build(train)
        args.update({
            'n_words': WORD.vocab.n_init,
            'n_feats': len(FEAT.vocab),
            'n_rels': len(REL.vocab),
            'pad_index': WORD.pad_index,
            'unk_index': WORD.unk_index,
            'bos_index': WORD.bos_index,
            'feat_pad_index': FEAT.pad_index,
        })
        parser = DependencyParser(
            n_words=args.n_words,
            n_feats=args.n_feats,
            # Size the relation classifier by the relation vocabulary
            # (this was previously args.n_feats by mistake).
            n_rels=args.n_rels,
            pad_index=args.pad_index,
            unk_index=args.unk_index,
            feat_pad_index=args.feat_pad_index,
            transform=transform)
        # Reuse the embedding modules configured on self.parser. This shares (and
        # mutates) self.parser.embeddings instead of copying it.
        # NOTE(review): n_vocab is hard-coded to 1000 although the word vocabulary
        # has args.n_words entries — confirm this is intended.
        word_field_embeddings = self.parser.embeddings[0]
        word_field_embeddings.n_vocab = 1000
        parser.embeddings = self.parser.embeddings
        parser.embeddings[0] = word_field_embeddings
        parser.load_pretrained(WORD.embed).to(device)

        ################################################################################################################
        # TRAIN
        ################################################################################################################
        # NOTE(review): args is reset to an empty Config here, so `**args` below
        # passes no options to the training Dataset — confirm this is intended.
        args = Config()
        parser.transform.train()
        if dist.is_initialized():
            # batch_size is the global budget; each process gets its share.
            batch_size = batch_size // dist.get_world_size()
        logger.info('Loading the data')
        train = Dataset(parser.transform, self.corpus.train, **args)
        dev = Dataset(parser.transform, self.corpus.dev)
        test = Dataset(parser.transform, self.corpus.test)
        train.build(batch_size, buckets, True, dist.is_initialized())
        dev.build(batch_size, buckets)
        test.build(batch_size, buckets)
        logger.info(
            f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")
        logger.info(f'{parser}')
        # Keep a handle on the bare model. DistributedDataParallel only exposes
        # nn.Module API (forward, parameters, train, ...), not custom attributes
        # such as WORD, decode, puncts, evaluate, save or load.
        model = parser
        if dist.is_initialized():
            parser = DDP(parser,
                         device_ids=[dist.get_rank()],
                         find_unused_parameters=True)

        optimizer = Adam(parser.parameters(), lr, (mu, nu), epsilon)
        # Stepped once per batch: the lr is multiplied by `decay` every `decay_steps` steps.
        scheduler = ExponentialLR(optimizer, decay**(1 / decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()

        for epoch in range(1, max_epochs + 1):
            start = datetime.now()
            logger.info(f'Epoch {epoch} / {max_epochs}:')

            parser.train()

            loader = train.loader
            bar, metric = progress_bar(loader), AttachmentMetric()
            for words, feats, arcs, rels in bar:
                optimizer.zero_grad()

                mask = words.ne(model.WORD.pad_index)
                # ignore the first token of each sentence
                mask[:, 0] = 0
                # Go through the (possibly DDP-wrapped) module so gradients are synced.
                s_arc, s_rel = parser.forward(words, feats)
                loss = model.forward_loss(s_arc, s_rel, arcs, rels, mask)
                loss.backward()
                nn.utils.clip_grad_norm_(parser.parameters(), clip)
                optimizer.step()
                scheduler.step()

                arc_preds, rel_preds = model.decode(s_arc, s_rel, mask)
                # ignore all punctuation if not specified
                if not self.parser.args['punct']:
                    mask &= words.unsqueeze(-1).ne(model.puncts).all(-1)
                metric(arc_preds, rel_preds, arcs, rels, mask)
                bar.set_postfix_str(
                    f'lr: {scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}'
                )

            loss, dev_metric = model.evaluate(dev.loader)
            logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
            loss, test_metric = model.evaluate(test.loader)
            logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                if is_master():
                    model.save(base_path)
                logger.info(f'{t}s elapsed (saved)\n')
            else:
                logger.info(f'{t}s elapsed\n')
            elapsed += t
            # Early stopping: no dev improvement for `patience` epochs.
            if epoch - best_e >= patience:
                break
        # Reload the best checkpoint and report its test score.
        loss, metric = model.load(base_path).evaluate(test.loader)

        logger.info(f'Epoch {best_e} saved')
        logger.info(f"{'dev:':6} - {best_metric}")
        logger.info(f"{'test:':6} - {metric}")
        logger.info(f'{elapsed}s elapsed, {elapsed / epoch}s/epoch')
# Example #2
from underthesea.models.dependency_parser import DependencyParser

# Load a dependency-parser model stored under a local relative path and parse
# one pre-tokenized sentence (multi-syllable words are single tokens).
parser = DependencyParser.load('./tmp/resources/parsers/dp')

tokens = ['Đó', 'là', 'kết quả', 'của', 'cuộc', 'vật lộn', 'bền bỉ',
          'gần', '17', 'năm', 'của', 'Huỳnh Đỗi', '.']
dataset = parser.predict([tokens], verbose=False)
print(dataset.sentences)
# Example #3
def init_parser():
    """Return the shared module-level parser, loading 'vi-dp-v1' on first use.

    The module-level name ``uts_parser`` must already exist. While it is falsy,
    the model is loaded and cached there; later calls return the cached object.
    """
    global uts_parser
    if uts_parser:
        return uts_parser
    uts_parser = DependencyParser.load('vi-dp-v1')
    return uts_parser
# Example #4
from os.path import join

from underthesea.file_utils import MODELS_FOLDER
from underthesea.models.dependency_parser import DependencyParser

# Resolve the 'vi-dp-v1.3.2a2' model inside the shared models folder and load it.
parsers_dir = join(MODELS_FOLDER, 'parsers')
base_path = join(parsers_dir, 'vi-dp-v1.3.2a2')
parser = DependencyParser.load(base_path)

# predict() receives a list of pre-tokenized sentences; here there is only one.
sentence = ['Đó', 'là', 'kết quả', 'của', 'cuộc', 'vật lộn', 'bền bỉ',
            'gần', '17', 'năm', 'của', 'Huỳnh Đỗi', '.']
sentences = [sentence]
dataset = parser.predict(sentences)
print(dataset.sentences)