def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
                     logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
    """Build a padded, length-bucketed dataloader for this parsing task.

    Args:
        data: Input forwarded to ``BiaffineDependencyParser.build_dataset``. When it is a
            ``str`` (presumably a file path — confirm with callers), the dataset cache is
            purged and length-based pruning may be applied.
        transform: Transform list to run on each sample. ``append_bos`` is inserted at its
            front. ``None`` (the default) is treated as an empty ``TransformList``; it used to
            raise ``AttributeError``.
        training: Forwarded to dataset caching and used as the sampler's ``shuffle`` flag.
        device: Device handed to the dataloader.
        logger: Logger forwarded to vocab building, caching and pruning.
        gradient_accumulation: Forwarded to the sampler builder.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        A ``PadSequenceDataLoader`` over the built dataset, padded with ``self.get_pad_dict()``.
    """
    if transform is None:
        # The default previously crashed on ``None.insert``; start from an empty pipeline.
        transform = TransformList()
    # Put ``append_bos`` at the front of the transform list. This mutates the caller's list
    # in place, so reusing the same list across calls would insert it again.
    transform.insert(0, append_bos)
    dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        BiaffineDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        # NOTE(review): ``max_seq_len`` only switches pruning on; the threshold itself is the
        # fixed 510 (presumably a 512-token transformer limit minus two special tokens) —
        # confirm whether ``max_seq_len`` was meant to be the threshold.
        dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset, length_field='FORM'),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset,
        pad=self.get_pad_dict())
def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
                     logger: logging.Logger = None, cache=False, gradient_accumulation=1, **kwargs) -> DataLoader:
    """Create a padded dataloader for constituency-parsing data.

    Args:
        data: Input forwarded to ``CRFConstituencyParsing.build_dataset``. When it is a
            ``str``, the dataset cache is purged first.
        transform: Transform(s) forwarded unchanged to ``build_dataset``.
        training: Forwarded to dataset caching and used as the sampler's ``shuffle`` flag.
        device: Device handed to the dataloader.
        logger: Logger forwarded to vocab building and caching.
        cache: Accepted for interface compatibility; not read by this method.
        gradient_accumulation: Forwarded to the sampler builder.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        A ``PadSequenceDataLoader`` over the built dataset.
    """
    dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
    from_string = isinstance(data, str)
    if from_string:
        dataset.purge_cache()
    if self.vocabs.mutable:
        CRFConstituencyParsing.build_vocabs(self, dataset, logger)
    if dataset.cache:
        # Caching is borrowed from the dependency parser, called unbound with this task as ``self``.
        # noinspection PyCallByClass
        BiaffineDependencyParser.cache_dataset(self, dataset, CountdownTimer(len(dataset)), training, logger)
    lengths = self.compute_lens(data, dataset)
    sampler = self.sampler_builder.build(lengths, shuffle=training,
                                         gradient_accumulation=gradient_accumulation)
    return PadSequenceDataLoader(dataset=dataset, device=device, batch_sampler=sampler)