import logging
from typing import Any, Callable, Dict, Iterable, Union

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# The following names are HanLP-internal and assumed to be in scope
# (exact import paths vary by version): TransformList, append_bos,
# BiaffineDependencyParser, CRFConstituencyParsing, CountdownTimer,
# PadSequenceDataLoader, MetricDict, Metric, apply_lemma_rule,
# CoNLLSentence, CoNLLUWord.


def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
                     logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if transform is None:
        transform = TransformList()
    # Prepend a BOS token so the dependency root has a dedicated position.
    transform.insert(0, append_bos)
    dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        BiaffineDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        # max_seq_len only toggles pruning; the threshold is fixed at
        # 510 = 512 minus the two special tokens of a BERT-style encoder.
        dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset, length_field='FORM'),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset,
        pad=self.get_pad_dict())
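# Hypothetical usage sketch of the method above. `parser`, `logger` and the
# training file are illustrative names, not defined in this module:
#
#   loader = parser.build_dataloader('train.conllu',
#                                    transform=TransformList(),
#                                    training=True, device=0, logger=logger,
#                                    gradient_accumulation=2)
#   for batch in loader:
#       ...  # padded tensors plus raw fields, batched by FORM length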
def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
                     logger: logging.Logger = None, cache=False, gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        CRFConstituencyParsing.build_vocabs(self, dataset, logger)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        # noinspection PyCallByClass
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset)
def update_metrics(self, batch: Dict[str, Any],
                   output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                   prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
    # prediction unpacks to (arc_preds, rel_preds); output[1] carries the mask.
    BiaffineDependencyParser.update_metric(self, *prediction, batch['arc'], batch['rel_id'],
                                           output[1], batch.get('punct_mask', None), metric, batch)
def update_metrics(self, metrics, batch, outputs, mask):
    arc_preds, rel_preds, puncts = outputs['arc_preds'], outputs['rel_preds'], \
        batch.get('punct_mask', None)
    BiaffineDependencyParser.update_metric(self, arc_preds, rel_preds, batch['arc'], batch['rel_id'],
                                           mask, puncts, metrics['deps'], batch)
    # Update the classification metrics for the remaining UD tasks.
    for task, key in zip(['lemmas', 'upos', 'feats'], ['lemma_id', 'pos_id', 'feat_id']):
        metric: Metric = metrics[task]
        pred = outputs['class_probabilities'][task]
        gold = batch[key]
        metric(pred.detach(), gold, mask=mask)
    return metrics
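# Sketch of what a mask-aware categorical accuracy does with the
# metric(pred, gold, mask=mask) call above (plain torch, not HanLP's Metric
# API; shapes are illustrative):
import torch

pred = torch.randn(2, 7, 5)          # (B, T, num_classes)
gold = torch.randint(0, 5, (2, 7))   # (B, T)
mask = torch.rand(2, 7) > 0.3        # valid-token mask
correct = (pred.argmax(-1) == gold) & mask
accuracy = correct.sum().float() / mask.sum().clamp(min=1)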
def prediction_to_human(self, outputs: dict, batch):
    arcs, rels = outputs['arc_preds'], outputs['rel_preds']
    # Strip the BOS column before decoding per-token classifications.
    upos = outputs['class_probabilities']['upos'][:, 1:, :].argmax(-1).tolist()
    feats = outputs['class_probabilities']['feats'][:, 1:, :].argmax(-1).tolist()
    lemmas = outputs['class_probabilities']['lemmas'][:, 1:, :].argmax(-1).tolist()
    lem_vocab = self.vocabs['lemma'].idx_to_token
    pos_vocab = self.vocabs['pos'].idx_to_token
    feat_vocab = self.vocabs['feat'].idx_to_token
    # noinspection PyCallByClass,PyTypeChecker
    for tree, form, lemma, pos, feat in zip(
            BiaffineDependencyParser.prediction_to_head_rel(self, arcs, rels, batch),
            batch['token'], lemmas, upos, feats):
        form = form[1:]
        assert len(form) == len(tree)
        # Lemmas are predicted as edit rules; apply each rule to its form.
        lemma = [apply_lemma_rule(t, lem_vocab[r]) for t, r in zip(form, lemma)]
        pos = [pos_vocab[x] for x in pos]
        feat = [feat_vocab[x] for x in feat]
        yield CoNLLSentence([
            CoNLLUWord(id=i + 1, form=fo, lemma=l, upos=p, feats=fe, head=a, deprel=r)
            for i, (fo, (a, r), l, p, fe) in enumerate(zip(form, tree, lemma, pos, feat))
        ])
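# Sketch of the BOS-stripping argmax used above: predictions carry a BOS
# column at position 0, so [:, 1:, :] drops it before argmax (illustrative
# shapes only):
import torch

probs = torch.randn(2, 8, 12)        # (B, 1 + num_tokens, num_tags)
tags = probs[:, 1:, :].argmax(-1)    # (B, num_tokens)
assert tags.shape == (2, 7)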
def decode_output(self, outputs, mask, batch):
    arc_scores = outputs['class_probabilities']['deps']['s_arc']
    rel_scores = outputs['class_probabilities']['deps']['s_rel']
    arc_preds, rel_preds = BiaffineDependencyParser.decode(self, arc_scores, rel_scores, mask, batch)
    outputs['arc_preds'], outputs['rel_preds'] = arc_preds, rel_preds
    return outputs
def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch=None):
    parse_loss = BiaffineDependencyParser.compute_loss(self, arc_scores, rel_scores, arcs, rels,
                                                       mask, criterion, batch)
    if self.model.training:
        # Add a masked-language-model auxiliary loss during training.
        gold_input_ids = batch['gold_input_ids']
        pred_input_ids = batch['pred_input_ids']
        input_ids_mask = batch['input_ids_mask']
        token_span = batch['token_span']
        # Map subword-level ids and masks down to token level via the first
        # subword offset of each token.
        gold_input_ids = batch['gold_input_ids'] = gold_input_ids.gather(1, token_span[:, :, 0])
        input_ids_mask = batch['input_ids_mask'] = input_ids_mask.gather(1, token_span[:, :, 0])
        mlm_loss = F.cross_entropy(pred_input_ids[input_ids_mask], gold_input_ids[input_ids_mask])
        return parse_loss + mlm_loss
    return parse_loss
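# Standalone sketch of the token_span gather above: token_span[:, :, 0] holds
# the first-subword offset of each token, so gather(1, ...) maps a
# subword-level tensor (B, L) to a token-level tensor (B, T). Shapes are
# illustrative.
import torch

B, L, T = 2, 10, 6
gold_input_ids = torch.randint(0, 100, (B, L))   # subword-level ids
first_subword = torch.randint(0, L, (B, T))      # stands in for token_span[:, :, 0]
token_level_ids = gold_input_ids.gather(1, first_subword)
assert token_level_ids.shape == (B, T)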
def forward(self, hidden, batch: Dict[str, torch.Tensor], mask) -> Dict[str, Any]:
    mask_without_root = mask.clone()
    mask_without_root[:, 0] = False
    logits = {}
    class_probabilities = {}
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}
    loss = 0
    arc = batch.get('arc', None)
    # Run through each of the tasks on the shared encoder and save predictions
    for task in self.decoders:
        if self.scalar_mix:
            decoder_input = self.scalar_mix[task](hidden, mask)
        else:
            decoder_input = hidden
        if task == "deps":
            s_arc, s_rel = self.decoders[task](decoder_input, mask)
            pred_output = {'class_probabilities': {'s_arc': s_arc, 's_rel': s_rel}}
            if arc is not None:
                # noinspection PyTypeChecker
                pred_output['loss'] = BiaffineDependencyParser.compute_loss(
                    None, s_arc, s_rel, arc, batch['rel_id'], mask_without_root,
                    torch.nn.functional.cross_entropy)
        else:
            pred_output = self.decoders[task](decoder_input, mask_without_root,
                                              batch.get(self.gold_keys[task], None))
        if 'logits' in pred_output:
            logits[task] = pred_output["logits"]
        if 'class_probabilities' in pred_output:
            class_probabilities[task] = pred_output["class_probabilities"]
        if 'loss' in pred_output:
            # Keep track of the loss if we have the gold tags available
            loss += pred_output["loss"]
    if arc is not None:
        output_dict["loss"] = loss
    return output_dict
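# Minimal sketch of the per-task scalar mixture over encoder layers, the role
# self.scalar_mix plays above: a softmax-normalized learned weight per layer
# plus a global scale (illustrative, not HanLP's implementation):
import torch

layers = torch.randn(4, 2, 8, 16)               # (num_layers, B, T, H)
weights = torch.softmax(torch.randn(4), dim=0)  # learned scalar per layer
gamma = torch.tensor(1.0)                       # learned global scale
mixed = gamma * (weights[:, None, None, None] * layers).sum(0)  # (B, T, H)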
def input_is_flat(self, data) -> bool:
    return BiaffineDependencyParser.input_is_flat(self, data, self.config.use_pos)
def build_metric(self, **kwargs):
    return BiaffineDependencyParser.build_metric(self, **kwargs)
def input_is_flat(self, data):
    # noinspection PyCallByClass,PyTypeChecker
    return BiaffineDependencyParser.input_is_flat(self, data, False)