def train(self, base_path: Union[Path, str], fix_len=20, min_freq=2, buckets=32, batch_size=5000, lr=2e-3, mu=.9,
          nu=.9, epsilon=1e-12, clip=5.0, decay=.75, decay_steps=5000, patience=100, max_epochs=10):
    r"""
    Build vocabularies from the training corpus, train a new ``DependencyParser``
    and keep the checkpoint with the best dev metric.

    The word / feature / arc / relation fields are rebuilt from
    ``self.corpus.train``; the feature type and the pretrained-embedding source
    are read from ``self.parser``. After every epoch the model is evaluated on
    the dev and test sets, and whenever the dev metric improves it is saved to
    ``base_path`` (on the master process only). Training stops after
    ``max_epochs`` epochs or after ``patience`` epochs without improvement. The
    best checkpoint is then reloaded and evaluated on the test set.

    Args:
        base_path (Union[Path, str]): Where the best model is saved; its parent
            directory is created if it does not exist.
        fix_len (int): ``fix_len`` passed to the subword field when the feature
            type is ``'char'`` or ``'bert'``.
        min_freq (int): Minimum frequency passed to ``WORD.build``.
        buckets (int): Number of buckets passed to ``Dataset.build``.
        batch_size (int): Batch size passed to ``Dataset.build``; divided by the
            world size when ``torch.distributed`` is initialized.
        lr (float): Adam learning rate.
        mu (float): Adam ``beta1``.
        nu (float): Adam ``beta2``.
        epsilon (float): Adam ``eps``.
        clip (float): Maximum gradient norm used by ``clip_grad_norm_``.
        decay (float): The learning rate is multiplied by this factor every
            ``decay_steps`` batches (ExponentialLR with gamma
            ``decay ** (1 / decay_steps)``, stepped once per batch).
        decay_steps (int): See ``decay``.
        patience (int): Stop once this many epochs pass without a dev improvement.
        max_epochs (int): Maximum number of epochs to train.
    """
    ################################################################################################################
    # BUILD
    ################################################################################################################
    args = Config()
    args.feat = self.parser.feat
    args.embed = self.parser.embed
    os.makedirs(os.path.dirname(base_path), exist_ok=True)

    logger.info("Building the fields")
    WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
    if args.feat == 'char':
        FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=fix_len)
    elif args.feat == 'bert':
        from transformers import AutoTokenizer
        # NOTE(review): `args.bert` and `args.max_len` do not appear to be set on
        # this Config anywhere in this method — confirm where they come from.
        tokenizer = AutoTokenizer.from_pretrained(args.bert)
        args.max_len = min(args.max_len or tokenizer.max_len, tokenizer.max_len)
        FEAT = SubwordField('bert',
                            pad=tokenizer.pad_token,
                            unk=tokenizer.unk_token,
                            bos=tokenizer.bos_token or tokenizer.cls_token,
                            fix_len=fix_len,
                            tokenize=tokenizer.tokenize)
        FEAT.vocab = tokenizer.get_vocab()
    else:
        FEAT = Field('tags', bos=bos)
    ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
    REL = Field('rels', bos=bos)
    if args.feat in ('char', 'bert'):
        transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
    else:
        transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

    train = Dataset(transform, self.corpus.train)
    WORD.build(train, min_freq, (Embedding.load(args.embed, unk) if self.parser.embed else None))
    FEAT.build(train)
    REL.build(train)
    args.update({
        'n_words': WORD.vocab.n_init,
        'n_feats': len(FEAT.vocab),
        'n_rels': len(REL.vocab),
        'pad_index': WORD.pad_index,
        'unk_index': WORD.unk_index,
        'bos_index': WORD.bos_index,
        'feat_pad_index': FEAT.pad_index,
    })
    parser = DependencyParser(
        n_words=args.n_words,
        n_feats=args.n_feats,
        # The relation scorer must be sized by the REL vocabulary, not the
        # feature vocabulary (this previously passed args.n_feats).
        n_rels=args.n_rels,
        pad_index=args.pad_index,
        unk_index=args.unk_index,
        feat_pad_index=args.feat_pad_index,
        transform=transform)

    # Reuse the existing parser's embedding layers in the new model.
    word_field_embeddings = self.parser.embeddings[0]
    # NOTE(review): hard-coded vocabulary size; presumably this should follow
    # the freshly built WORD vocabulary — confirm.
    word_field_embeddings.n_vocab = 1000
    parser.embeddings = self.parser.embeddings
    parser.embeddings[0] = word_field_embeddings
    parser.load_pretrained(WORD.embed).to(device)

    ################################################################################################################
    # TRAIN
    ################################################################################################################
    # NOTE(review): this rebinds `args` to a fresh Config, discarding the build
    # settings above; it is only used as **kwargs for the training Dataset.
    args = Config()
    parser.transform.train()
    if dist.is_initialized():
        batch_size = batch_size // dist.get_world_size()
    logger.info('Loading the data')
    train = Dataset(parser.transform, self.corpus.train, **args)
    dev = Dataset(parser.transform, self.corpus.dev)
    test = Dataset(parser.transform, self.corpus.test)
    train.build(batch_size, buckets, True, dist.is_initialized())
    dev.build(batch_size, buckets)
    test.build(batch_size, buckets)
    logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

    logger.info(f'{parser}')
    # `parser` stays the bare model: DistributedDataParallel does not expose the
    # wrapped module's custom attributes (WORD, decode, forward_loss, puncts,
    # evaluate, save, load), so only train()/forward/parameters() go through
    # `network`, which is the DDP wrapper in distributed mode.
    network = parser
    if dist.is_initialized():
        network = DDP(parser, device_ids=[dist.get_rank()], find_unused_parameters=True)

    optimizer = Adam(network.parameters(), lr, (mu, nu), epsilon)
    scheduler = ExponentialLR(optimizer, decay**(1 / decay_steps))

    elapsed = timedelta()
    best_e, best_metric = 1, Metric()

    for epoch in range(1, max_epochs + 1):
        start = datetime.now()
        logger.info(f'Epoch {epoch} / {max_epochs}:')

        network.train()

        loader = train.loader
        bar, metric = progress_bar(loader), AttachmentMetric()
        for words, feats, arcs, rels in bar:
            optimizer.zero_grad()

            mask = words.ne(parser.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = network.forward(words, feats)
            loss = parser.forward_loss(s_arc, s_rel, arcs, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(network.parameters(), clip)
            optimizer.step()
            scheduler.step()

            arc_preds, rel_preds = parser.decode(s_arc, s_rel, mask)
            # ignore all punctuation if not specified
            if not self.parser.args['punct']:
                mask &= words.unsqueeze(-1).ne(parser.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(
                f'lr: {scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}'
            )

        loss, dev_metric = parser.evaluate(dev.loader)
        logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
        loss, test_metric = parser.evaluate(test.loader)
        logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            if is_master():
                parser.save(base_path)
            logger.info(f'{t}s elapsed (saved)\n')
        else:
            logger.info(f'{t}s elapsed\n')
        elapsed += t
        if epoch - best_e >= patience:
            break

    # Re-evaluate the best checkpoint written during training.
    loss, metric = parser.load(base_path).evaluate(test.loader)

    logger.info(f'Epoch {best_e} saved')
    logger.info(f"{'dev:':6} - {best_metric}")
    logger.info(f"{'test:':6} - {metric}")
    logger.info(f'{elapsed}s elapsed, {elapsed / epoch}s/epoch')
from underthesea.models.dependency_parser import DependencyParser

# Load a dependency parser from a local checkpoint directory, parse a single
# pre-tokenized Vietnamese sentence quietly, and print the parsed sentences.
parser = DependencyParser.load('./tmp/resources/parsers/dp')
tokens = [
    'Đó', 'là', 'kết quả', 'của', 'cuộc', 'vật lộn', 'bền bỉ',
    'gần', '17', 'năm', 'của', 'Huỳnh Đỗi', '.',
]
dataset = parser.predict([tokens], verbose=False)
print(dataset.sentences)
def init_parser():
    """Return the module-wide ``vi-dp-v1`` dependency parser, loading it lazily.

    The parser is cached in the module-level ``uts_parser`` global. It is loaded
    via ``DependencyParser.load('vi-dp-v1')`` whenever that global is falsy, and
    the cached value is returned otherwise.
    """
    global uts_parser
    if uts_parser:
        return uts_parser
    uts_parser = DependencyParser.load('vi-dp-v1')
    return uts_parser
from os.path import join

from underthesea.file_utils import MODELS_FOLDER
from underthesea.models.dependency_parser import DependencyParser

# Load the ``vi-dp-v1.3.2a2`` parser from the underthesea models folder, parse
# one pre-tokenized Vietnamese sentence, and print the parsed sentences.
model_dir = join(MODELS_FOLDER, 'parsers', 'vi-dp-v1.3.2a2')
parser = DependencyParser.load(model_dir)

tokens = [
    'Đó', 'là', 'kết quả', 'của', 'cuộc', 'vật lộn', 'bền bỉ',
    'gần', '17', 'năm', 'của', 'Huỳnh Đỗi', '.',
]
dataset = parser.predict([tokens])
print(dataset.sentences)