示例#1
0
文件: sdp.py 项目: ericxsun/parser
    def _predict(self, loader):
        """Predict labelled graphs for every batch yielded by ``loader``.

        Returns:
            dict: ``'labels'`` holds one structure per sentence produced by
            ``CoNLL.build_relations`` from the decoded label chart, where cells
            outside the mask (label ``-1``) become ``None``. ``'probs'`` holds one
            CPU tensor per sentence cut from the edge scores returned by
            ``self.model.inference`` when ``self.args.prob`` is set, and is
            ``None`` otherwise.
        """
        self.model.eval()

        keep_probs = self.args.prob
        label_charts = []
        edge_probs = [] if keep_probs else None
        for words, *feats in progress_bar(loader):
            token_mask = words.ne(self.args.pad_index)
            # 3-D inputs (several pieces per token): a token is real if any piece is
            if len(words.shape) >= 3:
                token_mask = token_mask.any(-1)
            # pairwise mask over (row, column) positions; row 0 is excluded
            mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
            mask[:, 0] = 0
            # row 1 is untouched, so this counts all non-padding positions
            lens = mask[:, 1].sum(-1).tolist()

            s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
            s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), mask)
            chart_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)

            for n, chart in zip(lens, chart_preds):
                label_charts.append(chart[1:n, :n].tolist())
            if keep_probs:
                edge_probs.extend(prob[1:n, :n].cpu() for n, prob in zip(lens, s_edge.unbind()))

        vocab = self.LABEL.vocab
        relations = []
        for chart in label_charts:
            rows = [[vocab[idx] if idx >= 0 else None for idx in row] for row in chart]
            relations.append(CoNLL.build_relations(rows))

        return {'labels': relations, 'probs': edge_probs}
示例#2
0
    def _train(self, loader):
        """Run one training epoch of the tagger.

        Two regimes are selected by ``self.args.em_alg``:

        * EM: the model cache is reset via ``self.model.zero_cache()``,
          ``self.model.baum_welch`` is called on every batch, and
          ``self.model.step()`` is called once after the whole epoch.
        * Gradient-based: the negative mean of ``self.model.get_logP`` is
          minimised; gradients are clipped to ``self.args.clip`` and the
          optimizer and scheduler are stepped after every batch.

        The progress bar shows the mean ``logP`` of the current batch.
        """
        self.model.train()

        bar = progress_bar(loader)
        if self.args.em_alg:
            self.model.zero_cache()
        for words, _ in bar:
            # every non-padding position is scored; no leading token is masked out
            mask = words.ne(self.WORD.pad_index)

            emit_probs, trans_probs = self.model(words, mask)

            if self.args.em_alg:
                logP = self.model.baum_welch(words, mask, emit_probs,
                                             trans_probs)
            else:
                self.optimizer.zero_grad()
                logP = self.model.get_logP(emit_probs, trans_probs, mask)
                loss = -logP.mean()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
            bar.set_postfix_str(f" logP: {logP.mean():.4f}")

        # EM: apply the statistics accumulated over the whole epoch
        if self.args.em_alg:
            self.model.step()
示例#3
0
文件: sdp.py 项目: ericxsun/parser
    def _train(self, loader):
        """Run one training epoch with gradient accumulation.

        Each batch's loss is divided by ``self.args.update_steps`` and
        back-propagated. Every ``update_steps`` batches, the accumulated
        gradient is clipped to ``self.args.clip``, then the optimizer and
        scheduler are stepped and the gradients are reset. Predictions are
        scored with a ``ChartMetric``. The progress bar shows the learning
        rate, the (already scaled) loss and the running metric. The final
        postfix is logged.
        """
        self.model.train()

        bar, metric = progress_bar(loader), ChartMetric()
        # Start from clean gradients. If the previous epoch ended between two
        # updates, its partially accumulated gradient must not leak into the
        # first update of this epoch.
        self.optimizer.zero_grad()

        for i, (words, *feats, labels) in enumerate(bar, 1):
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
            mask[:, 0] = 0
            s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
            loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd,
                                           s_label, labels, mask)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient once, right before the
                # update, rather than clipping each partial sum
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            label_preds = self.model.decode(s_edge, s_label)
            metric(label_preds.masked_fill(~mask, -1),
                   labels.masked_fill(~mask, -1))
            bar.set_postfix_str(
                f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}"
            )
        logger.info(f"{bar.postfix}")
示例#4
0
    def _predict(self, loader):
        """Predict constituency trees, optionally with per-sentence span scores.

        Returns:
            dict: ``'trees'`` holds one tree per input, built by ``Tree.build``
            from the decoded ``(i, j, label)`` spans. ``'probs'`` is present only
            when ``self.args.prob`` is set. It holds one CPU tensor per sentence,
            cut from the span scores (CRF output when ``self.args.mbr``).
        """
        self.model.eval()

        built_trees, span_probs = [], []

        for words, feats, trees in progress_bar(loader):
            _, seq_len = words.shape
            size = seq_len - 1
            lens = words.ne(self.args.pad_index).sum(1) - 1
            # column j is kept when j < lens; the triangle restricts to i < j
            in_range = lens.new_tensor(range(size)) < lens.view(-1, 1, 1)
            mask = in_range & in_range.new_ones(size, size).triu_(1)

            s_span, s_label = self.model(words, feats)
            if self.args.mbr:
                s_span = self.model.crf(s_span, mask, mbr=True)
            chart_preds = self.model.decode(s_span, s_label, mask)

            vocab = self.CHART.vocab
            built_trees.extend(
                Tree.build(tree, [(i, j, vocab[label]) for i, j, label in chart])
                for tree, chart in zip(trees, chart_preds)
            )
            if self.args.prob:
                for n, prob in zip(lens, s_span.unbind()):
                    span_probs.append(prob[:n - 1, 1:n].cpu())

        preds = {'trees': built_trees}
        if self.args.prob:
            preds['probs'] = span_probs

        return preds
示例#5
0
文件: const.py 项目: yzhangcs/parser
    def _predict(self, loader):
        """Predict constituency trees, optionally with per-sentence span scores.

        Returns:
            dict: ``'trees'`` holds one tree per input, built by ``Tree.build``
            from the decoded ``(i, j, label)`` spans. ``'probs'`` holds, when
            ``self.args.prob`` is set, one CPU tensor of shape ``(n, n)`` per
            sentence. Row ``r`` and column ``c`` correspond to span
            ``(r, c + 1)``. The values are CRF marginals when ``self.args.mbr``
            is set, raw span scores otherwise. ``'probs'`` is ``None`` when
            ``self.args.prob`` is not set.
        """
        self.model.eval()

        preds = {'trees': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, *feats, trees = batch
            # drop position 0; for 3-D inputs a position is real if any piece is
            word_mask = words.ne(self.args.pad_index)[:, 1:]
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # span (i, j) is valid when both ends are real positions and i < j
            mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
            # number of valid right ends for i = 0, i.e. the sentence length n
            lens = mask[:, 0].sum(-1)
            s_span, s_label = self.model(words, feats)
            if self.args.mbr:
                s_span = ConstituencyCRF(s_span, lens).marginals
            chart_preds = self.model.decode(s_span, s_label, mask)
            preds['trees'].extend([
                Tree.build(tree, [(i, j, self.CHART.vocab[label])
                                  for i, j, label in chart])
                for tree, chart in zip(trees, chart_preds)
            ])
            if self.args.prob:
                # Valid cells are rows 0..n-1 and columns 1..n. Slicing
                # [:n - 1, 1:n] would drop the last row and column, including
                # the whole-sentence span (0, n).
                preds['probs'].extend(
                    [prob[:n, 1:n + 1].cpu() for n, prob in zip(lens.tolist(), s_span)])

        return preds
示例#6
0
    def _train(self, loader):
        """Run one training epoch for the arc/sibling/relation parser.

        For every batch the loss is back-propagated with gradient clipping, and
        the optimizer and scheduler are stepped. Predictions are decoded and
        fed to an ``AttachmentMetric`` that is shown on the progress bar.
        """
        self.model.train()

        bar = progress_bar(loader)
        metric = AttachmentMetric()

        for words, feats, arcs, sibs, rels in bar:
            self.optimizer.zero_grad()

            # the first token of each sentence is never scored
            mask = words.ne(self.WORD.pad_index)
            mask[:, 0] = 0

            s_arc, s_sib, s_rel = self.model(words, feats)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
                                          self.args.mbr, self.args.partial)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask)
            # Evaluation mask: drop unannotated arcs (partial mode) and, unless
            # punctuation is requested, tokens whose index appears in self.puncts.
            if self.args.partial:
                mask &= arcs.ge(0)
            if not self.args.punct:
                is_punct = words.unsqueeze(-1).eq(self.puncts).any(-1)
                mask &= ~is_punct
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
示例#7
0
    def _predict(self, loader):
        """Predict heads and relations for every sentence in ``loader``.

        Returns:
            dict: ``'arcs'`` holds a list of predicted head indices per sentence.
            ``'rels'`` holds the predicted relation labels, looked up in
            ``self.REL.vocab``. ``'probs'`` is present only when
            ``self.args.prob`` is set. It holds per-sentence CPU tensors taken
            from the CRF output when ``self.args.mbr`` is set, and from
            ``softmax(-1)`` of the arc scores otherwise.
        """
        self.model.eval()

        arc_seqs, rel_seqs, arc_probs = [], [], []
        for words, feats in progress_bar(loader):
            mask = words.ne(self.WORD.pad_index)
            # the first token of each sentence is excluded
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()

            s_arc, s_sib, s_rel = self.model(words, feats)
            if self.args.mbr:
                s_arc = self.model.crf((s_arc, s_sib), mask, mbr=True)
            arc_preds, rel_preds = self.model.decode(
                s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)

            arc_seqs.extend(arc_preds[mask].split(lens))
            rel_seqs.extend(rel_preds[mask].split(lens))
            if self.args.prob:
                dist = s_arc if self.args.mbr else s_arc.softmax(-1)
                arc_probs.extend([p[1:n + 1, :n + 1].cpu() for n, p in zip(lens, dist.unbind())])

        preds = {
            'arcs': [seq.tolist() for seq in arc_seqs],
            'rels': [self.REL.vocab[seq.tolist()] for seq in rel_seqs],
        }
        if self.args.prob:
            preds['probs'] = arc_probs

        return preds
示例#8
0
文件: const.py 项目: yzhangcs/parser
    def _train(self, loader):
        """Run one training epoch of the constituency parser with gradient accumulation.

        Each batch's loss is divided by ``self.args.update_steps`` and
        back-propagated. Every ``update_steps`` batches, the accumulated
        gradient is clipped to ``self.args.clip``, then the optimizer and
        scheduler are stepped and the gradients are reset. The progress bar
        shows the learning rate and the (already scaled) loss. The final
        postfix is logged.
        """
        self.model.train()

        bar = progress_bar(loader)
        # Start from clean gradients. If the previous epoch ended between two
        # updates, its partially accumulated gradient must not leak into the
        # first update of this epoch.
        self.optimizer.zero_grad()

        for i, batch in enumerate(bar, 1):
            words, *feats, trees, charts = batch
            # drop position 0; for 3-D inputs a position is real if any piece is
            word_mask = words.ne(self.args.pad_index)[:, 1:]
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # span (i, j) is valid when both ends are real positions and i < j
            mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
            s_span, s_label = self.model(words, feats)
            loss, _ = self.model.loss(s_span, s_label, charts, mask,
                                      self.args.mbr)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient once, right before the update
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            bar.set_postfix_str(
                f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}"
            )
        logger.info(f"{bar.postfix}")
示例#9
0
    def _train(self, loader):
        """Run one training epoch for the edge/label graph parser.

        For every batch the loss is back-propagated with gradient clipping, and
        the optimizer and scheduler are stepped. A ``ChartMetric`` compares
        predicted and gold labels. Cells without a (predicted or gold) edge, or
        outside the mask, are set to ``-1`` before the comparison.
        """
        self.model.train()

        bar = progress_bar(loader)
        metric = ChartMetric()

        for words, *feats, edges, labels in bar:
            self.optimizer.zero_grad()

            token_mask = words.ne(self.WORD.pad_index)
            mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
            # row 0 is excluded
            mask[:, 0] = 0

            s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
            loss, s_edge = self.model.loss(s_edge, s_sib, s_cop, s_grd, s_label,
                                           edges, labels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            edge_preds, label_preds = self.model.decode(s_edge, s_label)
            pred_keep = edge_preds.gt(0) & mask
            gold_keep = edges.gt(0) & mask
            metric(label_preds.masked_fill(~pred_keep, -1),
                   labels.masked_fill(~gold_keep, -1))
            bar.set_postfix_str(
                f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}"
            )
示例#10
0
文件: dep.py 项目: ericxsun/parser
    def _train(self, loader):
        """Run one training epoch of the dependency parser with gradient accumulation.

        Each batch's loss is divided by ``self.args.update_steps`` and
        back-propagated. Every ``update_steps`` batches, the accumulated
        gradient is clipped to ``self.args.clip``, then the optimizer and
        scheduler are stepped and the gradients are reset. Predictions are
        scored with an ``AttachmentMetric``. Unannotated arcs are excluded in
        partial mode, and punctuation (per ``ispunct`` on ``texts``) is
        excluded unless ``self.args.punct`` is set. The final progress-bar
        postfix is logged.
        """
        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()
        # Start from clean gradients. If the previous epoch ended between two
        # updates, its partially accumulated gradient must not leak into the
        # first update of this epoch.
        self.optimizer.zero_grad()

        for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
            word_mask = words.ne(self.args.pad_index)
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            s_arc, s_sib, s_rel = self.model(words, feats)
            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, rels, mask)
            loss = loss / self.args.update_steps
            loss.backward()
            if i % self.args.update_steps == 0:
                # clip the fully accumulated gradient once, right before the update
                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            # Ignore all punctuation if not specified. This must run BEFORE the
            # partial filter: masked_scatter_ fills the True positions in order
            # with one flag per word of ``texts``. Once the partial filter has
            # removed positions, the flags would shift onto the wrong tokens.
            # Assumes ``texts`` has exactly one entry per unmasked position
            # (TODO confirm).
            if not self.args.punct:
                mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
            if self.args.partial:
                mask &= arcs.ge(0)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
        logger.info(f"{bar.postfix}")
示例#11
0
    def load(self, data, lang=None, max_len=None, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                A list of instances or a filename. In a file, each non-blank line
                holds one bracketed tree; blank lines are skipped.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            max_len (int):
                Sentences whose length is greater than or equal to ``max_len`` are discarded.
                Default: ``None`` (no length filtering).

        Returns:
            A list of :class:`TreeSentence` instances.
        """
        if isinstance(data, str) and os.path.exists(data):
            with open(data, 'r') as f:
                # blank lines (e.g. a trailing empty line) are not trees;
                # nltk.Tree.fromstring raises on them
                trees = [nltk.Tree.fromstring(s) for s in f if s.strip()]
            # the root label is taken from the first tree of the file, if any
            if trees:
                self.root = trees[0].label()
        else:
            if lang is not None:
                tokenizer = Tokenizer(lang)
                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
            else:
                # a single token list is wrapped into a one-sentence batch;
                # an empty input yields no sentences
                data = [data] if data and isinstance(data[0], str) else data
            trees = [self.totree(i, self.root) for i in data]

        sentences = [TreeSentence(self, tree) for tree in progress_bar(trees)]
        if max_len is not None:
            sentences = [s for s in sentences if len(s) < max_len]

        return sentences
示例#12
0
    def _predict(self, loader):
        """Predict labelled graphs for every batch yielded by ``loader``.

        Returns:
            dict: ``'labels'`` holds one structure per sentence produced by
            ``CoNLL.build_relations``. A cell's label is kept only where an edge
            is predicted (``edge_preds > 0``) inside the mask; all other cells
            become ``None``. ``'probs'`` is present only when ``self.args.prob``
            is set. It holds per-sentence CPU tensors taken from
            ``s_edge.softmax(-1)``.
        """
        self.model.eval()

        label_charts, edge_probs = [], []
        for words, *feats in progress_bar(loader):
            token_mask = words.ne(self.WORD.pad_index)
            mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
            # row 0 is excluded; row 1 is untouched, so lens counts non-padding positions
            mask[:, 0] = 0
            lens = mask[:, 1].sum(-1).tolist()

            s_edge, s_label = self.model(words, feats)
            edge_preds, label_preds = self.model.decode(s_edge, s_label)
            keep = edge_preds.gt(0) & mask
            chart_preds = label_preds.masked_fill(~keep, -1)

            label_charts.extend(chart[1:n, :n].tolist()
                                for n, chart in zip(lens, chart_preds.unbind()))
            if self.args.prob:
                normalized = s_edge.softmax(-1)
                edge_probs.extend([p[1:n, :n].cpu() for n, p in zip(lens, normalized.unbind())])

        vocab = self.LABEL.vocab
        relations = [
            CoNLL.build_relations([[vocab[idx] if idx >= 0 else None for idx in row]
                                   for row in chart])
            for chart in label_charts
        ]
        preds = {'labels': relations}
        if self.args.prob:
            preds['probs'] = edge_probs

        return preds
示例#13
0
    def load(self, data, max_len=None, **kwargs):
        r"""
        Args:
            data (list[list] or str):
                A list of instances or a filename. In a file, each non-blank line
                holds one bracketed tree; blank lines are skipped.
            max_len (int):
                Sentences whose length is greater than or equal to ``max_len`` are discarded.
                Default: ``None`` (no length filtering).

        Returns:
            A list of :class:`TreeSentence` instances. A tree whose root has a
            single child whose first element is a leaf (not an ``nltk.Tree``)
            is skipped.
        """
        if isinstance(data, str):
            with open(data, 'r') as f:
                # blank lines (e.g. a trailing empty line) are not trees;
                # nltk.Tree.fromstring raises on them
                trees = [nltk.Tree.fromstring(string) for string in f if string.strip()]
            # the root label is taken from the first tree of the file, if any
            if trees:
                self.root = trees[0].label()
        else:
            # a single token list is wrapped into a one-sentence batch;
            # an empty input yields no sentences
            data = [data] if data and isinstance(data[0], str) else data
            trees = [self.totree(i, self.root) for i in data]

        sentences = [
            TreeSentence(self, tree)
            for tree in progress_bar(trees, leave=False)
            # skip a root with one child directly over a leaf
            # (presumably a one-word sentence without phrasal structure)
            if not (len(tree) == 1 and not isinstance(tree[0][0], nltk.Tree))
        ]
        if max_len is not None:
            sentences = [s for s in sentences if len(s) < max_len]

        return sentences
示例#14
0
    def _train(self, loader):
        """Run one training epoch, computing ELMo features for every batch on the fly.

        Processing per batch:

        * Embed the raw ``feats`` with ``self.elmo.embed_batch`` (when
          ``self.elmo`` is set) or ``self.efml.sents2elmo(..., output_layer=-2)``.
        * When ``self.args.map_method == 'vecmap'``, also map them with
          ``self.mapper.map_batch``.
        * Copy the three embedding layers into zero-padded buffers of shape
          ``words.shape + (1024,)`` and concatenate them into a
          ``words.shape + (3 * 1024,)`` tensor, which replaces ``feats`` as
          model input.

        The loss is back-propagated with gradient clipping. The optimizer and
        scheduler are stepped after every batch. An ``AttachmentMetric`` is
        shown on the progress bar.
        """
        self.model.train()

        bar, metric = progress_bar(loader), AttachmentMetric()
        # The features must live where the model's parameters live. Reading
        # the device from the model avoids hand-built device strings such as
        # 'cuda:' + str(self.args.device), which break for values like
        # 'cuda:0' or 'cpu'.
        device = next(self.model.parameters()).device
        # words, feats, etc. come from loader (train.loader, where train is a Dataset)
        for words, feats, arcs, rels in bar:
            self.optimizer.zero_grad()
            if self.elmo:
                feat_embs = self.elmo.embed_batch(feats)
            else:
                feat_embs = self.efml.sents2elmo(feats, output_layer=-2)
            # at training time the embeddings are mapped only for the vecmap
            # method; self.mapper is set up in the class constructor
            if self.args.map_method == 'vecmap':
                feat_embs = self.mapper.map_batch(feat_embs)

            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0

            # One zero-padded buffer per embedding layer; each layer is assumed
            # to be 1024-dimensional (NOTE(review): confirm against the embedder).
            layers = [torch.zeros(words.shape + (1024,)) for _ in range(3)]
            for sentence, sent_embs in enumerate(feat_embs):
                for token in range(len(sent_embs[1])):
                    for layer, buf in enumerate(layers):
                        buf[sentence][token] = torch.Tensor(sent_embs[layer][token])
            # Per the original author, the model ignores ``words`` and reads all
            # input from these features; ``words`` is still passed in.
            feats = torch.cat(layers, -1).to(device)
            s_arc, s_rel = self.model(words, feats)
            loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            metric(arc_preds, rel_preds, arcs, rels, mask)
            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
示例#15
0
    def load(self, data, lang=None, proj=False, max_len=None, **kwargs):
        r"""
        Loads the data in CoNLL-X format.
        Also supports loading data from CoNLL-U files with comments and non-integer IDs.

        Sentences are separated by empty lines. Consecutive empty lines do not
        produce empty sentences. A final sentence that is not followed by an
        empty line is kept.

        Args:
            data (list[list] or str):
                A list of instances or a filename.
            lang (str):
                Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
                ``None`` if tokenization is not required.
                Default: ``None``.
            proj (bool):
                If ``True``, discards all non-projective sentences. Default: ``False``.
            max_len (int):
                Sentences whose length is greater than or equal to ``max_len`` are discarded.
                Default: ``None`` (no length filtering).

        Returns:
            A list of :class:`CoNLLSentence` instances.
        """

        if isinstance(data, str) and os.path.exists(data):
            with open(data, 'r') as f:
                lines = [line.strip() for line in f]
        else:
            if lang is not None:
                tokenizer = Tokenizer(lang)
                data = [
                    tokenizer(i)
                    for i in ([data] if isinstance(data, str) else data)
                ]
            else:
                data = [data] if isinstance(data[0], str) else data
            lines = '\n'.join([self.toconll(i) for i in data]).split('\n')

        start, sentences = 0, []
        for i, line in enumerate(progress_bar(lines)):
            if not line:
                # lines[start:i] is the block since the previous empty line; it
                # is empty when several empty lines follow each other
                if start < i:
                    sentences.append(CoNLLSentence(self, lines[start:i]))
                start = i + 1
        # the last sentence may not be terminated by an empty line
        if start < len(lines):
            sentences.append(CoNLLSentence(self, lines[start:]))
        if proj:
            sentences = [
                s for s in sentences
                if self.isprojective(list(map(int, s.arcs)))
            ]
        if max_len is not None:
            sentences = [s for s in sentences if len(s) < max_len]

        return sentences
示例#16
0
    def _predict(self, loader):
        """Tag every sentence in ``loader``.

        Returns:
            dict: ``{'tags': ...}`` with one list per sentence. Each predicted tag
            index ``t`` at a non-padding position is rendered as the string
            ``"#C<t>#"``.
        """
        self.model.eval()

        tag_seqs = []
        for (words,) in progress_bar(loader):
            mask = words.ne(self.WORD.pad_index)
            lens = mask.sum(1).tolist()
            emit_probs, trans_probs = self.model(words, mask)
            tag_preds = self.model.decode(emit_probs, trans_probs, mask)
            tag_seqs.extend(tag_preds[mask].split(lens))

        return {'tags': [[f"#C{t}#" for t in seq.tolist()] for seq in tag_seqs]}
示例#17
0
    def _predict(self, loader):
        """Predict heads and relations, computing ELMo features for every batch on the fly.

        The raw ``feats`` are embedded with ``self.elmo.embed_batch`` (when
        ``self.elmo`` is set) or ``self.efml.sents2elmo(..., output_layer=-2)``.
        They are mapped with ``self.mapper.map_batch`` when ``self.mapper`` is
        set. The three layers are then packed into a
        ``words.shape + (3 * 1024,)`` tensor for the model.

        Returns:
            dict: ``'arcs'`` holds a list of predicted head indices per sentence.
            ``'rels'`` holds the predicted relation labels, looked up in
            ``self.REL.vocab``. ``'probs'`` is present only when
            ``self.args.prob`` is set. It holds per-sentence CPU tensors taken
            from ``s_arc.softmax(-1)``.
        """
        self.model.eval()

        arcs, rels, probs = [], [], []
        # The features must live where the model's parameters live. Reading
        # the device from the model avoids hand-built device strings such as
        # 'cuda:' + str(self.args.device), which break for values like
        # 'cuda:0' or 'cpu'.
        device = next(self.model.parameters()).device
        for words, feats in progress_bar(loader):
            if self.elmo:
                feat_embs = self.elmo.embed_batch(feats)
            else:
                feat_embs = self.efml.sents2elmo(feats, output_layer=-2)
            if self.mapper:
                # map feat_embs with self.mapper defined in class init
                feat_embs = self.mapper.map_batch(feat_embs)
            mask = words.ne(self.WORD.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            # One zero-padded buffer per embedding layer; each layer is assumed
            # to be 1024-dimensional (NOTE(review): confirm against the embedder).
            layers = [torch.zeros(words.shape + (1024,)) for _ in range(3)]
            for sentence, sent_embs in enumerate(feat_embs):
                for token in range(len(sent_embs[1])):
                    for layer, buf in enumerate(layers):
                        buf[sentence][token] = torch.Tensor(sent_embs[layer][token])
            feats = torch.cat(layers, -1).to(device)
            s_arc, s_rel = self.model(words, feats)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
                                                     self.args.tree,
                                                     self.args.proj)
            arcs.extend(arc_preds[mask].split(lens))
            rels.extend(rel_preds[mask].split(lens))
            if self.args.prob:
                arc_probs = s_arc.softmax(-1)
                probs.extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
        arcs = [seq.tolist() for seq in arcs]
        rels = [self.REL.vocab[seq.tolist()] for seq in rels]
        preds = {'arcs': arcs, 'rels': rels}
        if self.args.prob:
            preds['probs'] = probs

        return preds
示例#18
0
文件: dep.py 项目: ericxsun/parser
    def _predict(self, loader):
        """Predict heads and relations for every sentence in ``loader``.

        Returns:
            dict: ``'arcs'`` holds a list of predicted head indices per sentence.
            ``'rels'`` holds the predicted relation labels, looked up in
            ``self.REL.vocab``. ``'probs'`` holds per-sentence CPU tensors taken
            from ``s_arc.softmax(-1)`` when ``self.args.prob`` is set, and is
            ``None`` otherwise.
        """
        self.model.eval()

        want_probs = self.args.prob
        arc_seqs, rel_seqs = [], []
        arc_probs = [] if want_probs else None
        for words, texts, *feats in progress_bar(loader):
            mask = words.ne(self.args.pad_index)
            # 3-D inputs (several pieces per token): a token is real if any piece is
            if len(words.shape) >= 3:
                mask = mask.any(-1)
            # the first token of each sentence is excluded
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()

            s_arc, s_rel = self.model(words, feats)
            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
            arc_seqs.extend(arc_preds[mask].split(lens))
            rel_seqs.extend(rel_preds[mask].split(lens))
            if want_probs:
                dist = s_arc.softmax(-1)
                arc_probs.extend([p[1:n + 1, :n + 1].cpu() for n, p in zip(lens, dist.unbind())])

        return {
            'arcs': [seq.tolist() for seq in arc_seqs],
            'rels': [self.REL.vocab[seq.tolist()] for seq in rel_seqs],
            'probs': arc_probs,
        }
示例#19
0
    def _train(self, loader):
        """Run one training epoch of the span-based constituency parser.

        For every batch the loss is back-propagated with gradient clipping, and
        the optimizer and scheduler are stepped. The progress bar shows the
        learning rate and the loss.
        """
        self.model.train()

        bar = progress_bar(loader)

        for words, feats, trees, charts in bar:
            self.optimizer.zero_grad()

            _, seq_len = words.shape
            size = seq_len - 1
            lens = words.ne(self.args.pad_index).sum(1) - 1
            # column j is kept when j < lens; the triangle restricts to i < j
            in_range = lens.new_tensor(range(size)) < lens.view(-1, 1, 1)
            mask = in_range & in_range.new_ones(size, size).triu_(1)

            s_span, s_label = self.model(words, feats)
            loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f}")
示例#20
0
    def load(self, data, proj=False, max_len=None, **kwargs):
        r"""
        Loads the data in CoNLL-X format.
        Also supports loading data from CoNLL-U files with comments and non-integer IDs.

        Sentences are separated by empty lines. Consecutive empty lines do not
        produce empty sentences. A final sentence that is not followed by an
        empty line is kept.

        Args:
            data (list[list] or str):
                A list of instances or a filename.
            proj (bool):
                If ``True``, discards all non-projective sentences. Default: ``False``.
            max_len (int):
                Sentences whose length is greater than or equal to ``max_len`` are discarded.
                Default: ``None`` (no length filtering).

        Returns:
            A list of :class:`CoNLLSentence` instances.
        """

        if isinstance(data, str):
            with open(data, 'r') as f:
                lines = [line.strip() for line in f]
        else:
            data = [data] if isinstance(data[0], str) else data
            lines = '\n'.join([self.toconll(i) for i in data]).split('\n')

        start, sentences = 0, []
        for i, line in enumerate(progress_bar(lines, leave=False)):
            if not line:
                # lines[start:i] is the block since the previous empty line; it
                # is empty when several empty lines follow each other
                if start < i:
                    sentences.append(CoNLLSentence(self, lines[start:i]))
                start = i + 1
        # the last sentence may not be terminated by an empty line
        if start < len(lines):
            sentences.append(CoNLLSentence(self, lines[start:]))
        if proj:
            sentences = [
                s for s in sentences
                if self.isprojective(list(map(int, s.arcs)))
            ]
        if max_len is not None:
            sentences = [s for s in sentences if len(s) < max_len]

        return sentences
示例#21
0
    def _predict(self, loader):
        """Predict labelled graphs for every batch yielded by ``loader``.

        Returns:
            dict: ``'labels'`` holds one structure per sentence produced by
            ``CoNLL.build_relations`` from the charts returned by
            ``self.model.decode``. Negative label indices become ``None``.
            ``'probs'`` is present only when ``self.args.prob`` is set. It holds
            per-sentence CPU tensors taken from ``s_edge.softmax(-1)``.
        """
        self.model.eval()

        raw_charts, edge_probs = [], []
        for words, feats in progress_bar(loader):
            token_mask = words.ne(self.WORD.pad_index)
            mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
            lens = mask[:, 0].sum(-1).tolist()

            s_edge, s_label = self.model(words, feats)
            raw_charts.extend(self.model.decode(s_edge, s_label, mask))
            if self.args.prob:
                normalized = s_edge.softmax(-1)
                edge_probs.extend([p[:n, :n].cpu() for n, p in zip(lens, normalized.unbind())])

        vocab = self.LABEL.vocab
        label_charts = [[[vocab[idx] if idx >= 0 else None for idx in row] for row in chart]
                        for chart in raw_charts]
        preds = {'labels': [CoNLL.build_relations(chart) for chart in label_charts]}
        if self.args.prob:
            preds['probs'] = edge_probs

        return preds
示例#22
0
 def __call__(self, sentences):
     """Numericalize every field of every sentence in place.

     For each sentence and each field ``f`` in ``self.flattened_fields``, the
     sentence attribute named ``f.name`` is wrapped in a one-element list and
     passed to ``f.transform``. The single result is stored in
     ``sentence.transformed[f.name]``.

     Returns:
         ``self.flattened_fields``.
     """
     for sentence in progress_bar(sentences):
         for field in self.flattened_fields:
             raw_value = getattr(sentence, field.name)
             sentence.transformed[field.name] = field.transform([raw_value])[0]
     return self.flattened_fields