示例#1
0
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     """Build a padding dataloader over ``data``.

     Dataset construction, vocabulary building and dataset caching are
     delegated to ``BiaffineSemanticDependencyParser``. The batch sampler is
     built by ``self.sampler_builder`` from per-sample lengths.

     Args:
         data: Either a ``str`` (presumably a path), in which case the dataset
             cache is purged and lengths are measured on the ``'token'``
             field, or in-memory samples, measured on the ``'FORM'`` field.
         transform: Forwarded to ``build_dataset``.
         training: Enables shuffling in the sampler; also forwarded to
             ``cache_dataset``.
         device: Forwarded to the dataloader.
         logger: Forwarded to vocabulary building and dataset caching.
         gradient_accumulation: Forwarded to ``self.sampler_builder.build``.
         **kwargs: Accepted but not used.

     Returns:
         A ``PadSequenceDataLoader`` padded with ``self.get_pad_dict()``.
     """
     parser_cls = BiaffineSemanticDependencyParser
     dataset = parser_cls.build_dataset(self, data, transform)
     from_string = isinstance(data, str)
     if from_string:
         # Input given as a string: drop any previously cached samples.
         dataset.purge_cache()
     length_field = 'token' if from_string else 'FORM'
     if self.vocabs.mutable:
         parser_cls.build_vocabs(self, dataset, logger, transformer=True)
     if dataset.cache:
         countdown = CountdownTimer(len(dataset))
         parser_cls.cache_dataset(self, dataset, countdown, training, logger)
     lens = self.compute_lens(data, dataset, length_field=length_field)
     sampler = self.sampler_builder.build(
         lens, shuffle=training, gradient_accumulation=gradient_accumulation)
     return PadSequenceDataLoader(batch_sampler=sampler,
                                  device=device,
                                  dataset=dataset,
                                  pad=self.get_pad_dict())
示例#2
0
 def update_metrics(self, batch: Dict[str, Any],
                    output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                  Iterable[torch.Tensor], Any],
                    prediction: Dict[str, Any], metric: Union[MetricDict,
                                                              Metric]):
     """Accumulate one batch of predictions into ``metric``.

     ``prediction`` is star-unpacked as the leading positional arguments of
     ``BiaffineSemanticDependencyParser.update_metric``. Presumably it is an
     ``(arc_preds, rel_preds)`` pair, so the ``Dict`` annotation may be
     inaccurate — TODO confirm.

     Gold arcs come from ``batch['arc']`` and gold relation ids from
     ``batch['rel_id']``. ``output[1]`` and ``output[-1]`` are forwarded next
     (presumably the mask and the punctuation mask).
     """
     gold_arcs, gold_rels = batch['arc'], batch['rel_id']
     mask, punct_mask = output[1], output[-1]
     BiaffineSemanticDependencyParser.update_metric(
         self, *prediction, gold_arcs, gold_rels, mask, punct_mask, metric,
         batch)
示例#3
0
 def compute_loss(self, batch: Dict[str, Any],
                  output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
         Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
     """Compute the loss for one batch.

     ``output`` is unpacked as ``((arc_scores, rel_scores), mask, punct_mask)``.
     The third element (punctuation mask) is not used here. Gold arcs are read
     from ``batch['arc']`` and gold relation ids from ``batch['rel_id']``.
     """
     scores, mask, _ = output
     arc_scores, rel_scores = scores
     gold_arcs = batch['arc']
     gold_rels = batch['rel_id']
     return BiaffineSemanticDependencyParser.compute_loss(
         self, arc_scores, rel_scores, gold_arcs, gold_rels, mask,
         criterion, batch)
示例#4
0
 def decode_output(self,
                   output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                   mask: torch.BoolTensor,
                   batch: Dict[str, Any],
                   decoder, **kwargs) -> Union[Dict[str, Any], Any]:
     """Decode raw model ``output`` into predictions.

     The ``mask`` argument is ignored. The mask carried as the second element
     of ``output`` is used instead. ``decoder`` and ``**kwargs`` are also
     unused.
     """
     (arc_scores, rel_scores), output_mask, _ = output
     return BiaffineSemanticDependencyParser.decode(
         self, arc_scores, rel_scores, output_mask, batch)
示例#5
0
 def compute_mask(arc_scores_2nd, batch, mask_1st):
     """Return the second-order mask for ``batch``, building it on first use.

     If ``batch`` already has a non-``None`` ``'mask_2nd'`` entry, that entry
     is returned unchanged. Otherwise a mask is built with
     ``BiaffineSemanticDependencyParser.convert_to_3d_mask`` from
     ``arc_scores_2nd`` and ``mask_1st``, stored back into
     ``batch['mask_2nd']``, and returned.
     """
     cached = batch.get('mask_2nd', None)
     if cached is not None:
         return cached
     built = BiaffineSemanticDependencyParser.convert_to_3d_mask(
         arc_scores_2nd, mask_1st)
     batch['mask_2nd'] = built
     return built
示例#6
0
 def build_metric(self, **kwargs):
     """Build a ``MetricDict`` with two entries.

     ``'1st'`` is built by the parent class's ``build_metric`` and ``'2nd'``
     by ``BiaffineSemanticDependencyParser.build_metric``. Both receive the
     same ``**kwargs``.
     """
     metric_1st = super().build_metric(**kwargs)
     # noinspection PyCallByClass
     metric_2nd = BiaffineSemanticDependencyParser.build_metric(self, **kwargs)
     return MetricDict({'1st': metric_1st, '2nd': metric_2nd})
示例#7
0
 def update_metric(self,
                   arc_preds,
                   rel_preds,
                   arcs,
                   rels,
                   mask,
                   puncts,
                   metric,
                   batch=None):
     """Update both sub-metrics of ``metric`` for one batch.

     - Index 0 of ``arc_preds``/``rel_preds`` is scored by the parent class
       against ``arcs``/``rels`` into ``metric['1st']``.
     - Index 1 is scored by ``BiaffineSemanticDependencyParser.update_metric``
       against ``batch['arc_2nd']``/``batch['rel_2nd_id']`` into
       ``metric['2nd']``. It uses ``batch['mask_2nd']`` as the mask and
       ``puncts`` converted by ``convert_to_3d_puncts`` against that mask.

     NOTE(review): despite the ``None`` default, ``batch`` must be a mapping
     holding the ``*_2nd`` entries. It is subscripted unconditionally below.
     """
     super().update_metric(arc_preds[0], rel_preds[0], arcs, rels, mask,
                           puncts, metric['1st'], batch)
     mask_2nd = batch['mask_2nd']
     puncts_2nd = BiaffineSemanticDependencyParser.convert_to_3d_puncts(
         puncts, mask_2nd)
     # noinspection PyCallByClass
     BiaffineSemanticDependencyParser.update_metric(
         self, arc_preds[1], rel_preds[1], batch['arc_2nd'],
         batch['rel_2nd_id'], mask_2nd, puncts_2nd, metric['2nd'], batch)
示例#8
0
 def compute_loss(self,
                  arc_scores,
                  rel_scores,
                  arcs,
                  rels,
                  mask,
                  criterion,
                  batch=None):
     """Return the sum of the first- and second-order losses.

     The scores are split with ``self.unpack_scores``.
     - First-order loss: computed by the parent class against ``arcs``/``rels``
       with ``criterion[0]``.
     - Second-order loss: computed by
       ``BiaffineSemanticDependencyParser.compute_loss`` against
       ``batch['arc_2nd']``/``batch['rel_2nd_id']`` with ``criterion[1]``,
       under the mask returned by ``self.compute_mask``.
     """
     arc_1st, arc_2nd, rel_1st, rel_2nd = self.unpack_scores(
         arc_scores, rel_scores)
     loss_1st = super().compute_loss(arc_1st, rel_1st, arcs, rels, mask,
                                     criterion[0], batch)
     mask_2nd = self.compute_mask(arc_2nd, batch, mask)
     # noinspection PyCallByClass
     loss_2nd = BiaffineSemanticDependencyParser.compute_loss(
         self, arc_2nd, rel_2nd, batch['arc_2nd'], batch['rel_2nd_id'],
         mask_2nd, criterion[1], batch)
     return loss_1st + loss_2nd
示例#9
0
    def decode(self,
               arc_scores,
               rel_scores,
               mask,
               batch=None,
               predicting=None):
        """Decode first- and second-order predictions and pair them up.

        Two paths:
        - If ``batch['outputs']`` holds a non-``None`` first-order output, the
          cached ``(output_1st, output_2nd)`` pair is reused as-is.
        - Otherwise the scores are split with ``self.unpack_scores``. The
          first-order part is decoded by the parent class. The second-order
          part is decoded by ``BiaffineSemanticDependencyParser.decode`` under
          the mask returned by ``self.compute_mask``.

        When ``self.config.get('no_cycle')`` is truthy, the second-order output
        is replaced by ``(None, graphs)``. Here ``graphs`` holds, per sentence,
        the result of ``add_secondary_arcs_by_preds`` with its first (root)
        entry removed. This mode is only supported while predicting.

        Args:
            arc_scores: Arc scores, split by ``self.unpack_scores``.
            rel_scores: Relation scores, split by ``self.unpack_scores``.
            mask: First-order mask.
            batch: Batch mapping. Despite the ``None`` default it is required:
                ``.get`` and item access are performed on it.
            predicting: Must be truthy when ``no_cycle`` is configured.

        Returns:
            ``tuple(zip(output_1st, output_2nd))``.

        Raises:
            AssertionError: ``no_cycle`` is configured but ``predicting`` is
                falsy.
        """
        output_1st, output_2nd = batch.get('outputs', (None, None))
        if output_1st is None:
            arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd = self.unpack_scores(
                arc_scores, rel_scores)
            output_1st = super().decode(arc_scores_1st, rel_scores_1st, mask)
            mask = self.compute_mask(arc_scores_2nd, batch, mask)
            # noinspection PyCallByClass
            output_2nd = BiaffineSemanticDependencyParser.decode(
                self, arc_scores_2nd, rel_scores_2nd, mask, batch)
            if self.config.get('no_cycle'):
                # Explicit raise rather than ``assert``: an assert statement is
                # stripped under ``python -O``, which would let this unsupported
                # evaluation path run silently. AssertionError is kept as the
                # exception type for existing callers.
                if not predicting:
                    raise AssertionError(
                        'No cycle constraint for evaluation is not implemented yet. If you are '
                        'interested, welcome to submit a pull request.')
                root_rel_idx = self.vocabs['rel'].token_to_idx.get(
                    self.config.get('root', None), None)
                arc_pred_1st, _, arc_pred_2nd, rel_pred_2nd = *output_1st, *output_2nd
                # Move scores and predictions to host numpy arrays for the graph builder.
                arc_scores_2nd = arc_scores_2nd.transpose(
                    1, 2).cpu().detach().numpy()
                arc_pred_2nd = arc_pred_2nd.cpu().detach().numpy()
                rel_pred_2nd = rel_pred_2nd.cpu().detach().numpy()
                trees = arc_pred_1st.cpu().detach().numpy()
                graphs = []
                for sent_arc_scores, sent_arc_preds, sent_rel_preds, tree, tokens in zip(
                        arc_scores_2nd, arc_pred_2nd, rel_pred_2nd, trees,
                        batch['token']):
                    sent_len = len(tokens)
                    # NOTE(review): only the arc predictions and the tree are
                    # trimmed to sent_len. Scores and relation predictions are
                    # passed whole — confirm add_secondary_arcs_by_preds
                    # expects that.
                    graph = add_secondary_arcs_by_preds(
                        sent_arc_scores, sent_arc_preds[:sent_len, :sent_len],
                        sent_rel_preds, tree[:sent_len], root_rel_idx)
                    graphs.append(graph[1:])  # Remove root
                output_2nd = None, graphs

        return tuple(zip(output_1st, output_2nd))
示例#10
0
 def input_is_flat(self, data) -> bool:
     """Report whether ``data`` counts as flat input.

     Delegates to ``BiaffineSemanticDependencyParser.input_is_flat``,
     forwarding ``self.config.use_pos``.
     """
     use_pos = self.config.use_pos
     return BiaffineSemanticDependencyParser.input_is_flat(self, data, use_pos)
示例#11
0
 def build_metric(self, **kwargs):
     """Build the metric by delegating to ``BiaffineSemanticDependencyParser``."""
     delegate = BiaffineSemanticDependencyParser.build_metric
     return delegate(self, **kwargs)
示例#12
0
 def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
     """Build samples from ``inputs`` via ``BiaffineSemanticDependencyParser``.

     ``self.config.use_pos`` is forwarded. ``cls_is_bos`` and ``sep_is_eos``
     are accepted but not used here.
     """
     use_pos = self.config.use_pos
     samples = BiaffineSemanticDependencyParser.build_samples(self, inputs, use_pos)
     return samples
示例#13
0
 def build_criterion(self, **kwargs):
     """Return a pair of criteria.

     Element 0 comes from the parent class's ``build_criterion`` and element 1
     from ``BiaffineSemanticDependencyParser.build_criterion``. Both receive
     the same ``**kwargs``.
     """
     criterion_1st = super().build_criterion(**kwargs)
     # noinspection PyCallByClass
     criterion_2nd = BiaffineSemanticDependencyParser.build_criterion(
         self, **kwargs)
     return criterion_1st, criterion_2nd