def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if self.args.feat in ('char', 'bert', 'elmo'):
        self.WORD, self.FEAT = self.transform.FORM
    else:
        self.WORD, self.FEAT = self.transform.FORM, self.transform.CPOS
    self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL
    # cache the vocab indices of punctuation tokens so they can be masked out later
    self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                if ispunct(s)]).to(self.args.device)
    if self.args.elmo_options:
        self.elmo = ElmoEmbedder(self.args.elmo_options, self.args.elmo_weights, -1)
    else:
        self.efml = EFML(self.args.elmo_weights)
        self.elmo = False
    # pick the cross-lingual embedding mapper requested on the command line
    if self.args.map_method == 'vecmap':
        self.mapper = Vecmap(vars(self.args))
    elif self.args.map_method == 'elmogan':
        self.mapper = Elmogan(vars(self.args))
    elif self.args.map_method == 'muse':
        self.mapper = Muse(vars(self.args))
    else:
        self.mapper = None
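# Illustrative sketch only: how a cross-lingual mapper of the kind selected
# above is typically applied. `W` (an orthogonal map learned offline, e.g. by
# VecMap or MUSE on a bilingual dictionary) and `apply_mapping` are
# hypothetical names, not this repo's API.
import torch

def apply_mapping(embeddings: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Project source-language embeddings into the target embedding space.

    embeddings: [batch, seq_len, dim] contextual vectors (e.g. from ELMo).
    W:          [dim, dim] linear map learned offline.
    """
    return embeddings @ W

# usage sketch: map the vectors before feeding the parser's encoder
# mapped = apply_mapping(elmo_layers, W) if self.mapper is not None else elmo_layers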
def _train(self, loader):
    self.model.train()
    bar, metric = progress_bar(loader), AttachmentMetric()
    for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
        word_mask = words.ne(self.args.pad_index)
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        # ignore the first token of each sentence
        mask[:, 0] = 0
        s_arc, s_sib, s_rel = self.model(words, feats)
        loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
        loss = loss / self.args.update_steps
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        if i % self.args.update_steps == 0:
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
        if self.args.partial:
            mask &= arcs.ge(0)
        # ignore all punctuation if not specified
        if not self.args.punct:
            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
        metric(arc_preds, rel_preds, arcs, rels, mask)
        bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
    logger.info(f"{bar.postfix}")
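# Minimal, self-contained sketch of the gradient-accumulation pattern used in
# `_train` above: scale the loss by 1/update_steps, backprop every batch, and
# only step the optimizer every `update_steps` batches, so the effective batch
# size grows without extra memory. All names below are illustrative. (Unlike
# `_train`, which clips every batch, this sketch clips once per update.)
import torch
import torch.nn as nn

def accumulate_train(model, loader, optimizer, scheduler, update_steps=4, clip=5.0):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(loader, 1):
        loss = nn.functional.cross_entropy(model(inputs), targets)
        # dividing keeps the accumulated gradient equal to the mean over the
        # effective (larger) batch rather than the sum
        (loss / update_steps).backward()
        if i % update_steps == 0:
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()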
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if self.args.feat in ('char', 'bert'):
        self.WORD, self.FEAT = self.transform.FORM
    else:
        self.WORD, self.FEAT = self.transform.FORM, self.transform.CPOS
    self.ARC, self.REL = self.transform.HEAD, self.transform.DEPREL
    self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                if ispunct(s)]).to(self.args.device)
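# Tiny runnable example of the `puncts` construction in both __init__ methods
# above: collect the vocab indices whose surface form is pure punctuation, so
# they can later be excluded from evaluation. `ispunct` is re-implemented here
# for self-containment; the repo imports its own version.
import unicodedata
import torch

def ispunct(token: str) -> bool:
    return all(unicodedata.category(ch).startswith('P') for ch in token)

stoi = {'<pad>': 0, '<unk>': 1, 'the': 2, ',': 3, 'parser': 4, '.': 5}
puncts = torch.tensor([i for s, i in stoi.items() if ispunct(s)])
print(puncts)  # tensor([3, 5])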
def train(WORD, CHAR, ARC, REL, transform, encoder, epoch=60, word_dim=100):
    model = BiaffineDependencyModel(n_words=WORD.vocab.n_init,
                                    n_feats=len(CHAR.vocab),
                                    n_rels=len(REL.vocab),
                                    pad_index=WORD.pad_index,
                                    unk_index=WORD.unk_index,
                                    bos_index=WORD.bos_index,
                                    feat_pad_index=CHAR.pad_index,
                                    encoder=encoder,
                                    n_embed=word_dim)
    model.load_pretrained(WORD.embed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    model.to(device)
    puncts = torch.tensor([i for s, i in WORD.vocab.stoi.items()
                           if ispunct(s)]).to(device)
    train, dev, test = get_dataset(transform)
    # train.sentences = train.sentences[:30000]
    # truncate dev/test sets to 200 sentences
    dev.sentences = dev.sentences[:200]
    test.sentences = test.sentences[:200]
    print('train sentences: %d  dev sentences: %d  test sentences: %d'
          % (len(train.sentences), len(dev.sentences), len(test.sentences)))
    if encoder == 'lstm':
        optimizer = Adam(model.parameters(), lr=2e-3, betas=(0.9, 0.9), eps=1e-12)
    else:
        optimizer = ScheduledOptim(
            Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09), 2.0, 800, 4000)
    train_parser(train, dev, test, model, optimizer, transform, WORD, puncts,
                 encoder, epochs=epoch, path=encoder + '_model')
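# A sketch of the Noam-style learning-rate schedule that a
# `ScheduledOptim(optimizer, lr_mul, d_model, n_warmup_steps)` wrapper
# typically implements; reading the 2.0/800/4000 arguments above as
# (lr_mul, d_model, n_warmup_steps) is an assumption about this wrapper.
def noam_lr(step: int, lr_mul: float = 2.0, d_model: int = 800,
            n_warmup_steps: int = 4000) -> float:
    # linear warmup for n_warmup_steps, then inverse-sqrt decay
    return lr_mul * d_model ** -0.5 * min(step ** -0.5,
                                          step * n_warmup_steps ** -1.5)

# e.g. noam_lr(1) ≈ 2.8e-7, peaks near step 4000 at ≈ 1.1e-3,
# then decays as 1/sqrt(step)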
@torch.no_grad()  # no gradients are needed during evaluation
def _evaluate(self, loader):
    self.model.eval()
    total_loss, metric = 0, AttachmentMetric()
    for words, texts, *feats, arcs, sibs, rels in loader:
        word_mask = words.ne(self.args.pad_index)
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        # ignore the first token of each sentence
        mask[:, 0] = 0
        s_arc, s_sib, s_rel = self.model(words, feats)
        loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
                                      self.args.mbr, self.args.partial)
        arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask,
                                                 self.args.tree, self.args.mbr,
                                                 self.args.proj)
        if self.args.partial:
            mask &= arcs.ge(0)
        # ignore all punctuation if not specified
        if not self.args.punct:
            mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
        total_loss += loss.item()
        metric(arc_preds, rel_preds, arcs, rels, mask)
    total_loss /= len(loader)
    return total_loss, metric
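# Self-contained illustration of the punctuation-masking trick used in both
# `_train` and `_evaluate` above: `masked_scatter_` rewrites exactly the True
# positions of `mask` (in order), so punctuation tokens are switched off while
# padding and the BOS position stay off. All tensors here are toy data.
import torch

texts = [['No', ',', 'thanks', '.']]
mask = torch.tensor([[False, True, True, True, True]])  # position 0 is BOS
is_punct = torch.tensor([w in {',', '.'} for s in texts for w in s])
mask.masked_scatter_(mask, ~is_punct)
print(mask)  # tensor([[False,  True, False,  True, False]])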