Example 1
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     if transform is None:  # guard the default so append_bos can be inserted
         transform = TransformList()
     # Run append_bos first so position 0 holds the BOS/root placeholder
     transform.insert(0, append_bos)
     dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
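     # When data is a file path, purge any stale cached samples before rebuilding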
     if isinstance(data, str):
         dataset.purge_cache()
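     # Build vocabularies on the first pass, while they are still mutable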
     if self.vocabs.mutable:
         BiaffineDependencyParser.build_vocabs(self,
                                               dataset,
                                               logger,
                                               transformer=True)
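     # Featurize and cache the whole dataset up front, with a progress timer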
     if dataset.cache:
         timer = CountdownTimer(len(dataset))
         BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                                training, logger)
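     # When a max_seq_len is configured and data comes from a file, drop samples
     # longer than 510 subwords (512 minus two special transformer tokens)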
     max_seq_len = self.config.get('max_seq_len', None)
     if max_seq_len and isinstance(data, str):
         dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
     return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
         self.compute_lens(data, dataset, length_field='FORM'),
         shuffle=training,
         gradient_accumulation=gradient_accumulation),
                                  device=device,
                                  dataset=dataset,
                                  pad=self.get_pad_dict())
Example 2
 def build_dataloader(self,
                      data,
                      transform: Callable = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      cache=False,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
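     # Same pipeline shape as Example 1, but dataset and vocab building are
     # delegated to CRFConstituencyParsing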
     dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
     if isinstance(data, str):
         dataset.purge_cache()
     if self.vocabs.mutable:
         CRFConstituencyParsing.build_vocabs(self, dataset, logger)
     if dataset.cache:
         timer = CountdownTimer(len(dataset))
         # noinspection PyCallByClass
         BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                                training, logger)
     return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
         self.compute_lens(data, dataset),
         shuffle=training,
         gradient_accumulation=gradient_accumulation),
                                  device=device,
                                  dataset=dataset)
Example 3
 def update_metrics(self, batch: Dict[str, Any],
                    output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                  Iterable[torch.Tensor], Any],
                    prediction: Dict[str, Any], metric: Union[MetricDict,
                                                              Metric]):
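     # *prediction unpacks to (arc_preds, rel_preds); output[1] is the token
     # mask (compare the explicit call in Example 4)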
     BiaffineDependencyParser.update_metric(self, *prediction, batch['arc'],
                                            batch['rel_id'], output[1],
                                            batch.get('punct_mask', None),
                                            metric, batch)
Example 4
 def update_metrics(self, metrics, batch, outputs, mask):
     arc_preds, rel_preds = outputs['arc_preds'], outputs['rel_preds']
     puncts = batch.get('punct_mask', None)
     BiaffineDependencyParser.update_metric(self, arc_preds, rel_preds,
                                            batch['arc'], batch['rel_id'],
                                            mask, puncts, metrics['deps'],
                                            batch)
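     # The remaining heads are plain classifiers: update each task's metric
     # from its class probabilities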
     for task, key in zip(['lemmas', 'upos', 'feats'],
                          ['lemma_id', 'pos_id', 'feat_id']):
         metric: Metric = metrics[task]
         pred = outputs['class_probabilities'][task]
         gold = batch[key]
         metric(pred.detach(), gold, mask=mask)
     return metrics
Example 5
 def prediction_to_human(self, outputs: dict, batch):
     arcs, rels = outputs['arc_preds'], outputs['rel_preds']
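     # Skip position 0 (the BOS/root slot) before taking the argmax over classes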
     probs = outputs['class_probabilities']
     upos = probs['upos'][:, 1:, :].argmax(-1).tolist()
     feats = probs['feats'][:, 1:, :].argmax(-1).tolist()
     lemmas = probs['lemmas'][:, 1:, :].argmax(-1).tolist()
     lem_vocab = self.vocabs['lemma'].idx_to_token
     pos_vocab = self.vocabs['pos'].idx_to_token
     feat_vocab = self.vocabs['feat'].idx_to_token
     # noinspection PyCallByClass,PyTypeChecker
     for tree, form, lemma, pos, feat in zip(
             BiaffineDependencyParser.prediction_to_head_rel(
                 self, arcs, rels, batch), batch['token'], lemmas, upos,
             feats):
         form = form[1:]
         assert len(form) == len(tree)
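         # Lemmas are predicted as edit rules; apply_lemma_rule rebuilds each
         # surface lemma from its form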
         lemma = [
             apply_lemma_rule(t, lem_vocab[r]) for t, r in zip(form, lemma)
         ]
         pos = [pos_vocab[x] for x in pos]
         feat = [feat_vocab[x] for x in feat]
         yield CoNLLSentence([
             CoNLLUWord(id=i + 1,
                        form=fo,
                        lemma=l,
                        upos=p,
                        feats=fe,
                        head=a,
                        deprel=r)
             for i, (fo, (a, r), l, p,
                     fe) in enumerate(zip(form, tree, lemma, pos, feat))
         ])
Example 6
 def decode_output(self, outputs, mask, batch):
     deps_scores = outputs['class_probabilities']['deps']
     arc_scores, rel_scores = deps_scores['s_arc'], deps_scores['s_rel']
     arc_preds, rel_preds = BiaffineDependencyParser.decode(
         self, arc_scores, rel_scores, mask, batch)
     outputs['arc_preds'], outputs['rel_preds'] = arc_preds, rel_preds
     return outputs
Example 7
 def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask,
                  criterion, batch=None):
     parse_loss = BiaffineDependencyParser.compute_loss(
         self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch)
     if self.model.training:
         # Auxiliary masked-LM objective: align the gold ids and the mask to
         # the first subword of each token span, then score the predictions
         # at the masked positions only
         gold_input_ids = batch['gold_input_ids']
         pred_input_ids = batch['pred_input_ids']
         input_ids_mask = batch['input_ids_mask']
         token_span = batch['token_span']
         gold_input_ids = batch['gold_input_ids'] = gold_input_ids.gather(
             1, token_span[:, :, 0])
         input_ids_mask = batch['input_ids_mask'] = input_ids_mask.gather(
             1, token_span[:, :, 0])
         mlm_loss = F.cross_entropy(pred_input_ids[input_ids_mask],
                                    gold_input_ids[input_ids_mask])
         return parse_loss + mlm_loss
     return parse_loss
Example 8
    def forward(self,
                hidden,
                batch: Dict[str, torch.Tensor],
                mask) -> Dict[str, Any]:
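        # Position 0 is the artificial root: exclude it from token-level decoding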
        mask_without_root = mask.clone()
        mask_without_root[:, 0] = False

        logits = {}
        class_probabilities = {}
        output_dict = {"logits": logits,
                       "class_probabilities": class_probabilities}
        loss = 0

        arc = batch.get('arc', None)
        # Run through each of the tasks on the shared encoder and save predictions
        for task in self.decoders:
            if self.scalar_mix:
                decoder_input = self.scalar_mix[task](hidden, mask)
            else:
                decoder_input = hidden

            if task == "deps":
                s_arc, s_rel = self.decoders[task](decoder_input, mask)
                pred_output = {'class_probabilities': {'s_arc': s_arc, 's_rel': s_rel}}
                if arc is not None:
                    # noinspection PyTypeChecker
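                    # compute_loss does not use the instance, so None stands
                    # in for self here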
                    pred_output['loss'] = BiaffineDependencyParser.compute_loss(None, s_arc, s_rel, arc,
                                                                                batch['rel_id'],
                                                                                mask_without_root,
                                                                                torch.nn.functional.cross_entropy)
            else:
                pred_output = self.decoders[task](decoder_input, mask_without_root,
                                                  batch.get(self.gold_keys[task], None))
            if 'logits' in pred_output:
                logits[task] = pred_output["logits"]
            if 'class_probabilities' in pred_output:
                class_probabilities[task] = pred_output["class_probabilities"]
            if 'loss' in pred_output:
                # Keep track of the loss if we have the gold tags available
                loss += pred_output["loss"]

        if arc is not None:
            output_dict["loss"] = loss

        return output_dict
Example 9
 def input_is_flat(self, data) -> bool:
     return BiaffineDependencyParser.input_is_flat(self, data, self.config.use_pos)
Example 10
 def build_metric(self, **kwargs):
     return BiaffineDependencyParser.build_metric(self, **kwargs)
Example 11
 def input_is_flat(self, data):
     # noinspection PyCallByClass,PyTypeChecker
     return BiaffineDependencyParser.input_is_flat(self, data, False)