import logging
from typing import Any, Callable, Dict, Iterable, Union

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# BiaffineDependencyParser, CRFConstituencyParsing, PadSequenceDataLoader, TransformList,
# CountdownTimer, append_bos, Metric and MetricDict are HanLP internals, imported from
# the corresponding hanlp.* modules.


def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
                     logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if transform is None:  # guard against the default; the original dereferenced None here
        transform = TransformList()
    # Prepend a BOS token so the dependency tree's dummy ROOT gets its own position.
    transform.insert(0, append_bos)
    dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        BiaffineDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        # 510 = 512 minus [CLS] and [SEP]: drop sentences that overflow the transformer window.
        dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, length_field='FORM'),
                                                 shuffle=training,
                                                 gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset,
        pad=self.get_pad_dict())
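# PadSequenceDataLoader batches variable-length samples by padding each field to the
# longest sequence in the batch. As a rough mental model only (a hedged, minimal
# sketch, not HanLP's implementation; `sketch_pad_batch` is hypothetical):
def sketch_pad_batch(samples, pad_value=0):
    # samples: list of 1-D LongTensors of varying lengths
    max_len = max(len(s) for s in samples)
    return torch.stack([F.pad(s, (0, max_len - len(s)), value=pad_value) for s in samples])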
def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
                     logger: logging.Logger = None, cache=False, gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        CRFConstituencyParsing.build_vocabs(self, dataset, logger)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        # noinspection PyCallByClass
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset),
                                                 shuffle=training,
                                                 gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset)
def update_metrics(self, batch: Dict[str, Any],
                   output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                   prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
    BiaffineDependencyParser.update_metric(self, *prediction, batch['arc'], batch['rel_id'], output[1],
                                           batch.get('punct_mask', None), metric, batch)
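# For orientation, the metric being updated is attachment score. A hedged sketch of
# UAS/LAS (illustrative only, not HanLP's Metric classes; `sketch_attachment_scores`
# is hypothetical): UAS counts correct heads, LAS correct (head, relation) pairs,
# typically with punctuation masked out via a mask like `punct_mask` above.
def sketch_attachment_scores(arc_preds, rel_preds, arcs, rels, mask):
    arc_hit = (arc_preds == arcs) & mask      # correct head and not masked out
    las_hit = arc_hit & (rel_preds == rels)   # ... and correct relation label
    total = mask.sum().item()
    return arc_hit.sum().item() / total, las_hit.sum().item() / total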
def compute_loss(self, batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                 criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    (arc_scores, rel_scores), mask = output
    return BiaffineDependencyParser.compute_loss(self, arc_scores, rel_scores, batch['arc'],
                                                 batch['rel_id'], mask, criterion, batch)
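# Since BiaffineDependencyParser.compute_loss is HanLP-internal, here is a hedged
# sketch of what a biaffine parser loss conventionally computes (illustrative only;
# `sketch_biaffine_loss` and all shapes are assumptions, not HanLP's code): head
# cross-entropy for every token plus label cross-entropy scored at the gold heads.
def sketch_biaffine_loss(arc_scores, rel_scores, arcs, rels, mask):
    # arc_scores: (batch, seq, seq) attachment scores; rel_scores: (batch, seq, seq, n_rels)
    # arcs/rels: (batch, seq) gold head indices / relation ids; mask: (batch, seq) bool
    arc_scores, arcs = arc_scores[mask], arcs[mask]
    rel_scores, rels = rel_scores[mask], rels[mask]
    rel_scores = rel_scores[torch.arange(len(arcs)), arcs]  # label scores at gold heads
    return F.cross_entropy(arc_scores, arcs) + F.cross_entropy(rel_scores, rels)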
def decode_output(self, output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                  mask: torch.BoolTensor, batch: Dict[str, Any], decoder,
                  **kwargs) -> Union[Dict[str, Any], Any]:
    # The mask bundled with the output replaces the `mask` argument.
    (arc_scores, rel_scores), mask = output
    return BiaffineDependencyParser.decode(self, arc_scores, rel_scores, mask, batch)
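# BiaffineDependencyParser.decode can apply tree-constrained decoding; for intuition,
# a hedged sketch of the unconstrained greedy variant only (`sketch_greedy_decode` is
# hypothetical, shapes as in the loss sketch above):
def sketch_greedy_decode(arc_scores, rel_scores):
    arc_preds = arc_scores.argmax(-1)  # best head per token: (batch, seq)
    # Best relation label at the predicted head of each token.
    rel_preds = rel_scores.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
    return arc_preds, rel_preds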
def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch=None):
    parse_loss = BiaffineDependencyParser.compute_loss(self, arc_scores, rel_scores, arcs, rels, mask,
                                                       criterion, batch)
    if not self.model.training:
        return parse_loss
    # During training, add an auxiliary masked-LM loss. Gold ids and the MLM mask are
    # subword-level, so both are first aligned to the first subword of each token.
    gold_input_ids = batch['gold_input_ids']
    pred_input_ids = batch['pred_input_ids']
    input_ids_mask = batch['input_ids_mask']
    token_span = batch['token_span']
    gold_input_ids = batch['gold_input_ids'] = gold_input_ids.gather(1, token_span[:, :, 0])
    input_ids_mask = batch['input_ids_mask'] = input_ids_mask.gather(1, token_span[:, :, 0])
    # Cross-entropy only over the masked positions.
    mlm_loss = F.cross_entropy(pred_input_ids[input_ids_mask], gold_input_ids[input_ids_mask])
    return parse_loss + mlm_loss
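# The subtle part above is the gather-based alignment: token_span[:, :, 0] holds the
# first-subword index of every token, so gathering with it projects subword-level
# tensors down to one entry per token. A tiny self-contained demo (all values and the
# function name are hypothetical):
def demo_first_subword_alignment():
    token_span = torch.tensor([[[0, 0], [1, 2], [3, 3]]])   # 3 tokens; token 1 spans subwords 1-2
    gold_input_ids = torch.tensor([[11, 22, 23, 33]])       # subword-level gold ids
    input_ids_mask = torch.tensor([[False, True, False, False]])
    first = token_span[:, :, 0]                             # tensor([[0, 1, 3]])
    gold = gold_input_ids.gather(1, first)                  # tensor([[11, 22, 33]])
    masked = input_ids_mask.gather(1, first)                # tensor([[False, True, False]])
    pred_logits = torch.randn(1, 3, 100)                    # token-level MLM logits, toy vocab of 100
    return F.cross_entropy(pred_logits[masked], gold[masked])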
def input_is_flat(self, data) -> bool:
    return BiaffineDependencyParser.input_is_flat(self, data, self.config.use_pos)
def build_metric(self, **kwargs):
    return BiaffineDependencyParser.build_metric(self, **kwargs)