コード例 #1
0
ファイル: dependency.py プロジェクト: ashim95/parser
    def _train(self, loader):
        """Run one training epoch over ``loader``, stepping the optimizer per batch.

        Each batch is unpacked as ``(words, feats, arcs, sibs, rels)``. After the
        optimizer and scheduler step, the batch is decoded and the running
        ``AttachmentMetric`` is updated; the progress bar shows the current
        learning rate, the batch loss and the metric.
        """
        self.model.train()

        progress = progress_bar(loader)
        metric = AttachmentMetric()
        args = self.args

        for words, feats, arcs, sibs, rels in progress:
            self.optimizer.zero_grad()

            # non-pad positions; column 0 (the first token of each sentence)
            # is excluded
            mask = words.ne(self.WORD.pad_index)
            mask[:, 0] = 0

            s_arc, s_sib, s_rel = self.model(words, feats)
            # loss() also hands back the arc scores used for decoding below
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels,
                                          mask, args.mbr, args.partial)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask)

            # Narrow the evaluation mask: under partial annotation keep only
            # positions whose gold head is non-negative (presumably the
            # annotated ones), and drop punctuation unless it is requested.
            if args.partial:
                mask &= arcs.ge(0)
            if not args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)

            metric(arc_preds, rel_preds, arcs, rels, mask)
            lr = self.scheduler.get_last_lr()[0]
            progress.set_postfix_str(f"lr: {lr:.4e} - loss: {loss:.4f} - {metric}")
コード例 #2
0
    def _evaluate(self, loader):
        """Evaluate the model on every batch of ``loader``.

        Each batch is unpacked as ``(words, feats, arcs, rels)``. The loss is
        computed with ``self.args.mbr`` / ``self.args.partial``. Predictions are
        decoded with ``self.args.tree`` / ``self.args.proj``, and the attachment
        metric is accumulated over the positions left in the mask.

        Returns:
            ``(total_loss, metric)``: the loss averaged over the batches
            actually processed (``0.0`` when ``loader`` yields none) and the
            accumulated ``AttachmentMetric``.
        """
        self.model.eval()
        # NOTE(review): autograd is not disabled inside this method; confirm a
        # torch.no_grad() decorator or context is applied where it is called.

        total_loss, metric = 0, AttachmentMetric()
        # Count the batches actually seen rather than dividing by len(loader).
        # An empty loader then no longer raises ZeroDivisionError, and loaders
        # without __len__ work. When len(loader) equals the number of batches
        # yielded, the result is identical.
        n_batches = 0

        for words, feats, arcs, rels in loader:
            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_rel = self.model(words, feats)
            # loss() also returns the arc scores used for decoding below
            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
                                          self.args.mbr,
                                          self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
                                                     self.args.tree,
                                                     self.args.proj)
            # with partial annotation, score only positions whose gold head is
            # non-negative
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            total_loss += loss.item()
            n_batches += 1
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss = total_loss / n_batches if n_batches else 0.0

        return total_loss, metric
コード例 #3
0
ファイル: dep.py プロジェクト: ericxsun/parser
    def _train(self, loader):
        """Run one training epoch with gradient accumulation.

        Each batch is unpacked as ``(words, texts, *feats, arcs, rels)``. The
        loss is divided by ``self.args.update_steps``, and gradients accumulate
        across that many consecutive batches. On every ``update_steps``-th batch:
        - the accumulated gradient is clipped to ``self.args.clip``;
        - the optimizer and scheduler step;
        - gradients are reset.
        The final progress-bar postfix is logged at the end of the epoch.

        NOTE(review): when the number of batches is not a multiple of
        ``update_steps``, gradients of the trailing partial window are neither
        applied nor cleared before this method returns.
        """
        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()

        for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
            word_mask = words.ne(self.args.pad_index)
            # with a 3-D `words` tensor, a position is a real token if any
            # entry along the last dim is non-pad
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            # loss() also returns the arc scores used for decoding below
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
            # scale so the accumulated gradient averages over the window
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # Clip the fully accumulated gradient once, right before the
                # update. Clipping after every backward() rescales partial
                # sums, so earlier batches in the window get shrunk repeatedly
                # and the applied gradient is not the clipped total.
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                # masked_scatter_ fills the True slots of `mask` in row-major
                # order; assumes the flattened tokens of `texts` line up
                # one-to-one with those slots -- TODO confirm
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")
コード例 #4
0
    def _train(self, loader):
        """Run one training epoch, feeding the model ELMo features.

        ``loader`` is the training Dataset's loader and yields
        ``(words, feats, arcs, rels)`` batches. For each batch, the raw
        ``feats`` are:
        - embedded with ``self.elmo.embed_batch`` when ``self.elmo`` is set,
          and with ``self.efml.sents2elmo(..., output_layer=-2)`` otherwise;
        - mapped through ``self.mapper.map_batch`` when
          ``self.args.map_method`` is ``'vecmap'``;
        - packed into a zero-padded float tensor whose last dimension
          concatenates three 1024-wide layers.
        That tensor replaces ``feats`` as the model's second input. The
        optimizer and scheduler step once per batch.
        """
        self.model.train()

        bar = progress_bar(loader)
        metric = AttachmentMetric()
        width = 1024  # size of one ELMo layer vector, fixed by this loop

        for words, feats, arcs, rels in bar:
            self.optimizer.zero_grad()
            embeddings = (self.elmo.embed_batch(feats) if self.elmo
                          else self.efml.sents2elmo(feats, output_layer=-2))
            # TODO: apply the mapping if and only if the method is vecmap
            if self.args.map_method == 'vecmap':
                # self.mapper is created in the class constructor
                embeddings = self.mapper.map_batch(embeddings)

            mask = words.ne(self.WORD.pad_index)
            mask[:, 0] = 0  # ignore the first token of each sentence

            # According to the original author, the model takes all its input
            # from these three ELMo layers; `words` is still passed alongside.
            # Layer k of token t goes into slice [k*width, (k+1)*width) of the
            # last dim; the token count per sentence is taken from layer 1.
            stacked = torch.zeros(words.shape + (3 * width,))
            for row, layers in enumerate(embeddings):
                for col in range(len(layers[1])):
                    for k in range(3):
                        lo = k * width
                        stacked[row, col, lo:lo + width] = torch.Tensor(layers[k][col])

            device = str(self.args.device)
            # TODO: fix to allow cpu or gpu
            feats = stacked.to('cpu' if device == '-1' else 'cuda:' + device)
            # the data enters the model here: y = model(x)
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            if self.args.partial:
                mask &= arcs.ge(0)
            if not self.args.punct:  # ignore punctuation unless asked to keep it
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            lr = self.scheduler.get_last_lr()[0]
            bar.set_postfix_str(f"lr: {lr:.4e} - loss: {loss:.4f} - {metric}")
コード例 #5
0
    def _evaluate(self, loader):
        """Evaluate the model on ``loader`` using ELMo features.

        For each ``(words, feats, arcs, rels)`` batch, the raw ``feats`` are:
        - embedded with ``self.elmo.embed_batch`` when ``self.elmo`` is set,
          and with ``self.efml.sents2elmo(..., output_layer=-2)`` otherwise;
        - mapped through ``self.mapper.map_batch`` when ``self.mapper`` is set;
        - packed into a zero-padded float tensor whose last dimension
          concatenates three 1024-wide layers.
        That tensor is the model's second input.

        Returns:
            ``(total_loss, metric)``: the loss averaged over the batches
            actually processed (``0.0`` when ``loader`` yields none) and the
            accumulated ``AttachmentMetric``.
        """
        import logging

        # These diagnostics used to be unconditional print() calls on every
        # evaluation; keep them, but at DEBUG level through logging.
        log = logging.getLogger(__name__)
        log.debug("called _evaluate function")
        log.debug("%s", self.mapper)
        self.model.eval()

        dim = 1024  # width of one ELMo layer, as hard-coded by this method
        total_loss, metric = 0, AttachmentMetric()
        # Count batches actually seen instead of dividing by len(loader), so
        # an empty loader does not raise ZeroDivisionError.
        n_batches = 0

        for words, feats, arcs, rels in loader:
            if self.elmo:
                feat_embs = self.elmo.embed_batch(feats)
            else:
                feat_embs = self.efml.sents2elmo(feats, output_layer=-2)
            # NOTE(review): mapping is applied whenever self.mapper is truthy;
            # confirm training uses the same condition so train/eval inputs agree.
            if self.mapper:
                # map feat_embs with self.mapper defined in class init
                feat_embs = self.mapper.map_batch(feat_embs)
            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0

            # Fill one tensor of shape words.shape + (3 * dim,) directly. This
            # gives the same values as building three per-layer tensors and
            # concatenating them on the last dim. The token count per sentence
            # is taken from layer 1.
            packed = torch.zeros(words.shape + (3 * dim,))
            for s in range(len(feat_embs)):
                layers = feat_embs[s]
                for t in range(len(layers[1])):
                    for k in range(3):
                        packed[s, t, k * dim:(k + 1) * dim] = torch.Tensor(layers[k][t])
            if str(self.args.device) == '-1':
                feats = packed.to('cpu')
            else:
                feats = packed.to('cuda:' + str(self.args.device))

            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
                                                     self.args.tree,
                                                     self.args.proj)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            total_loss += loss.item()
            n_batches += 1
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss = total_loss / n_batches if n_batches else 0.0

        return total_loss, metric
コード例 #6
0
ファイル: dep.py プロジェクト: ericxsun/parser
    def _evaluate(self, loader):
        """Evaluate a model that scores arcs, siblings and relations on ``loader``.

        Each batch is unpacked as ``(words, texts, *feats, arcs, sibs, rels)``.
        The loss uses ``self.args.mbr`` / ``self.args.partial``, and decoding
        uses ``self.args.tree`` / ``self.args.mbr`` / ``self.args.proj``.

        Returns:
            ``(total_loss, metric)``: the loss averaged over the batches
            actually processed (``0.0`` when ``loader`` yields none) and the
            accumulated ``AttachmentMetric``.
        """
        self.model.eval()
        # NOTE(review): autograd is not disabled inside this method; confirm a
        # torch.no_grad() decorator or context is applied where it is called.

        total_loss, metric = 0, AttachmentMetric()
        # Count the batches actually seen rather than dividing by len(loader).
        # An empty loader then no longer raises ZeroDivisionError, and loaders
        # without __len__ work. When len(loader) equals the number of batches
        # yielded, the result is identical.
        n_batches = 0

        for words, texts, *feats, arcs, sibs, rels in loader:
            word_mask = words.ne(self.args.pad_index)
            # with a 3-D `words` tensor, a position is a real token if any
            # entry along the last dim is non-pad
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            # loss() also returns the arc scores used for decoding below
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                # masked_scatter_ fills the True slots of `mask` in row-major
                # order; assumes the flattened tokens of `texts` line up
                # one-to-one with those slots -- TODO confirm
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
            total_loss += loss.item()
            n_batches += 1
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss = total_loss / n_batches if n_batches else 0.0

        return total_loss, metric