def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
                     logger: logging.Logger = None, tokenizer: PreTrainedTokenizer = None,
                     **kwargs) -> DataLoader:
    assert tokenizer
    dataset = TextTokenizingDataset(data, cache=isinstance(data, str), delimiter=self.config.sent_delimiter,
                                    generate_idx=isinstance(data, list),
                                    max_seq_len=self.config.max_seq_len,
                                    sent_delimiter=self.config.sent_delimiter,
                                    transform=[
                                        TransformerSequenceTokenizer(tokenizer,
                                                                     'text',
                                                                     ret_prefix_mask=True,
                                                                     ret_subtokens=True,
                                                                     ),
                                        # delta=-2 discounts the special tokens (e.g., [CLS]/[SEP])
                                        # wrapped around the subtoken sequence
                                        FieldLength('text_input_ids', 'text_input_ids_length', delta=-2),
                                        generate_token_span_tuple])
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'text_input_ids', 'text'),
                                                 shuffle=training),
        device=device,
        dataset=dataset)

def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None, vocabs=None,
                     sampler_builder=None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if vocabs is None:
        vocabs = self.vocabs
    transform = TransformList(unpack_ner, FieldLength('token'))
    if isinstance(self.config.embed, Embedding):
        transform.append(self.config.embed.transform(vocabs=vocabs))
    transform.append(self.vocabs)
    dataset = self.build_dataset(data, vocabs, transform)
    if vocabs.mutable:
        self.build_vocabs(dataset, logger, vocabs)
    # Bucket by word-level length when a word vocab exists, otherwise by subtoken length
    if 'token' in vocabs:
        lens = [len(x['token']) for x in dataset]
    else:
        lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(batch_sampler=sampler, device=device, dataset=dataset)

def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, sampler_builder,
                     gradient_accumulation, **kwargs) -> DataLoader:
    # shuffle = False  # We need to find the smallest grad_acc
    dataset = HeadDrivenPhraseStructureDataset(data, transform=[append_bos_eos])
    if self.config.get('transform', None):
        dataset.append_transform(self.config.transform)
    dataset.append_transform(self.vocabs)
    if isinstance(self.config.embed, Embedding):
        transform = self.config.embed.transform(vocabs=self.vocabs)
        if transform:
            dataset.append_transform(transform)
    dataset.append_transform(self.vocabs)
    field_length = FieldLength('token')
    dataset.append_transform(field_length)
    if isinstance(data, str):
        dataset.purge_cache()  # Enable cache
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger)
    if 'token' in self.vocabs:
        lens = [x[field_length.dst] for x in dataset]
    else:
        lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(batch_sampler=sampler, batch_size=batch_size, device=device, dataset=dataset)

def build_dataloader(self, data, batch_size, sampler_builder: SamplerBuilder = None, gradient_accumulation=1,
                     shuffle=False, device=None, logger: logging.Logger = None, **kwargs) -> DataLoader:
    if isinstance(data, TransformableDataset):
        dataset = data
    else:
        dataset = self.build_dataset(data, [
            self.config.embed.transform(vocabs=self.vocabs),
            self.vocabs,
            FieldLength('token')
        ])
    if self.vocabs.mutable:
        # noinspection PyTypeChecker
        self.build_vocabs(dataset, logger)
    lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)

def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
    encoder: ContextualWordEmbedding = self.config.encoder
    encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
    length_transform = FieldLength('token', 'token_length')
    transform = TransformList(encoder_transform, length_transform)
    extra_transform = self.config.get('transform', None)
    if extra_transform:
        transform.insert(0, extra_transform)
    return encoder_transform, transform

def build_dataset(self, data, transform, logger=None):
    _transform = [
        unpack_tree_to_features,
        self.vocabs,
        FieldLength('token'),
        transform
    ]
    if self.config.get('no_subcategory', True):
        _transform.insert(0, remove_subcategory)
    dataset = ConstituencyDataset(data,
                                  transform=_transform,
                                  cache=isinstance(data, str))
    return dataset

def build_dataloader(self, data, shuffle, device, embed: Embedding, training=False, logger=None,
                     gradient_accumulation=1, sampler_builder=None, batch_size=None, bos='\0',
                     **kwargs) -> DataLoader:
    first_transform = TransformList(functools.partial(append_bos, bos=bos))
    embed_transform = embed.transform(vocabs=self.vocabs)
    transformer_transform = self._get_transformer_transform_from_transforms(embed_transform)
    if embed_transform:
        # Detach the transformer transform while vocabs are built, then re-attach it below
        if transformer_transform and isinstance(embed_transform, TransformList):
            embed_transform.remove(transformer_transform)
        first_transform.append(embed_transform)
    dataset = self.build_dataset(data, first_transform=first_transform)
    if self.config.get('transform', None):
        dataset.append_transform(self.config.transform)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger, self._transformer_trainable())
    if transformer_transform and isinstance(embed_transform, TransformList):
        embed_transform.append(transformer_transform)
    dataset.append_transform(FieldLength('token', 'sent_length'))
    if isinstance(data, str):
        dataset.purge_cache()
    if len(dataset) > 1000 and isinstance(data, str):
        timer = CountdownTimer(len(dataset))
        self.cache_dataset(dataset, timer, training, logger)
    if sampler_builder:
        lens = [sample['sent_length'] for sample in dataset]
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    loader = PadSequenceDataLoader(dataset=dataset,
                                   batch_sampler=sampler,
                                   batch_size=batch_size,
                                   pad=self.get_pad_dict(),
                                   device=device,
                                   vocabs=self.vocabs)
    return loader

def build_dataset(self, data, generate_idx, logger, transform=None):
    dataset = CoNLL2012SRLDataset(data,
                                  transform=[filter_v_args, unpack_srl, group_pa_by_p],
                                  doc_level_offset=self.config.doc_level_offset,
                                  generate_idx=generate_idx)
    if transform:
        dataset.append_transform(transform)
    if isinstance(self.config.get('embed', None), Embedding):
        transform = self.config.embed.transform(vocabs=self.vocabs)
        if transform:
            dataset.append_transform(transform)
    dataset.append_transform(self.vocabs)
    dataset.append_transform(FieldLength('token'))
    if isinstance(data, str):
        dataset.purge_cache()  # Enable cache
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger)
    return dataset

def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, **kwargs) -> DataLoader:
    dataset = CONLL12CorefDataset(data, [FieldLength('text')])
    if isinstance(self.config.embed, Embedding):
        transform = self.config.embed.transform(vocabs=self.vocabs)
        if transform:
            dataset.append_transform(transform)
    dataset.append_transform(self.vocabs)
    if isinstance(data, str):
        dataset.purge_cache()  # Enable cache
    if self.vocabs.mutable:
        self.build_vocabs(dataset)
    return PadSequenceDataLoader(batch_size=batch_size,
                                 shuffle=shuffle,
                                 device=device,
                                 dataset=dataset,
                                 pad={'spans': 0, 'span_labels': -1})

def build_dataloader(self, data, shuffle, device, training=False, logger=None, gradient_accumulation=1,
                     sampler_builder=None, batch_size=None, **kwargs) -> DataLoader:
    dataset = self.build_dataset(data)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger, self.config.transformer)
    transformer_tokenizer = self.transformer_tokenizer
    if transformer_tokenizer:
        dataset.transform.append(self.build_tokenizer_transform())
    dataset.append_transform(FieldLength('token', 'sent_length'))
    if isinstance(data, str):
        dataset.purge_cache()
    if len(dataset) > 1000 and isinstance(data, str):
        timer = CountdownTimer(len(dataset))
        self.cache_dataset(dataset, timer, training, logger)
    if self.config.transformer:
        lens = [len(sample['input_ids']) for sample in dataset]
    else:
        lens = [sample['sent_length'] for sample in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    loader = PadSequenceDataLoader(dataset=dataset,
                                   batch_sampler=sampler,
                                   batch_size=batch_size,
                                   num_workers=0 if isdebugging() else 2,
                                   pad=self.get_pad_dict(),
                                   device=device,
                                   vocabs=self.vocabs)
    return loader

def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
                     sampler_builder=None, gradient_accumulation=1,
                     transformer: ContextualWordEmbedding = None, **kwargs) -> DataLoader:
    transform = [
        generate_lemma_rule,
        append_bos,
        self.vocabs,
        transformer.transform(),
        FieldLength('token')
    ]
    if not self.config.punct:
        transform.append(PunctuationMask('token', 'punct_mask'))
    dataset = self.build_dataset(data, transform)
    if self.vocabs.mutable:
        # noinspection PyTypeChecker
        self.build_vocabs(dataset, logger)
    lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = SortingSamplerBuilder(batch_size).build(lens, shuffle, gradient_accumulation)
    return PadSequenceDataLoader(
        dataset,
        batch_size,
        shuffle,
        device=device,
        batch_sampler=sampler,
        pad={'arc': 0},
    )

def last_transform(self):
    return TransformList(self.vocabs, FieldLength(self.config.token_key))
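

# A minimal sketch of how a transform pipeline such as last_transform() above behaves on a
# single sample dict. The classes below are simplified stand-ins (assumptions), not the
# library's actual implementations: a vocab transform is modeled as any callable mapping
# sample -> sample, FieldLength is assumed to write len(sample[src]) + delta into
# sample[dst], and TransformList is assumed to apply its members in order.
from typing import Dict


class MiniFieldLength:
    def __init__(self, src: str, dst: str = None, delta: int = 0):
        # dst defaults to '<src>_length', mirroring fields like 'token_length' used above
        self.src, self.dst, self.delta = src, dst or src + '_length', delta

    def __call__(self, sample: Dict) -> Dict:
        sample[self.dst] = len(sample[self.src]) + self.delta
        return sample


class MiniTransformList(list):
    def __call__(self, sample: Dict) -> Dict:
        # Apply each transform in insertion order, threading the sample dict through
        for transform in self:
            sample = transform(sample)
        return sample


if __name__ == '__main__':
    pipeline = MiniTransformList([
        lambda s: {**s, 'token_id': list(range(len(s['token'])))},  # stand-in vocab lookup
        MiniFieldLength('token'),
    ])
    print(pipeline({'token': ['I', 'love', 'parsing']}))
    # -> {'token': [...], 'token_id': [0, 1, 2], 'token_length': 3}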