def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
                     logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
    # Run append_bos before any other transform, then reuse the biaffine
    # dependency parser's dataset, vocab and cache building.
    transform.insert(0, append_bos)
    dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
    if isinstance(data, str):
        dataset.purge_cache()
    if self.vocabs.mutable:
        BiaffineDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        # Drop samples whose subword sequence exceeds the transformer input limit.
        dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset, length_field='FORM'),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset,
        pad=self.get_pad_dict())
def build_dataset(self, data, transform=None):
    transforms = TransformList(
        functools.partial(append_bos_to_form_pos, pos_key='UPOS'),
        functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel))
    if transform:
        transforms.append(transform)
    return super(BiaffineSemanticDependencyParser, self).build_dataset(data, transforms)
def append_transform(self, transform: Callable):
    assert transform is not None, 'None transform not allowed'
    if not self.transform:
        self.transform = TransformList(transform)
    elif not isinstance(self.transform, TransformList):
        if self.transform != transform:
            self.transform = TransformList(self.transform, transform)
    else:
        if transform not in self.transform:
            self.transform.append(transform)
    return self
def insert_transform(self, index: int, transform: Callable):
    assert transform is not None, 'None transform not allowed'
    if not self.transform:
        self.transform = TransformList(transform)
    elif not isinstance(self.transform, TransformList):
        if self.transform != transform:
            self.transform = TransformList(self.transform)
            self.transform.insert(index, transform)
    else:
        if transform not in self.transform:
            self.transform.insert(index, transform)
    return self
def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
    encoder: ContextualWordEmbedding = self.config.encoder
    encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
    length_transform = FieldLength('token', 'token_length')
    transform = TransformList(encoder_transform, length_transform)
    extra_transform = self.config.get('transform', None)
    if extra_transform:
        transform.insert(0, extra_transform)
    return encoder_transform, transform
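# --- Illustrative sketch, not HanLP source ---
# Why build_transform inserts the extra transform at index 0: transforms run in
# list order, so whatever sits at position 0 sees the raw sample before the
# tokenizer transform adds subword fields. The two stand-in functions below are
# hypothetical and exist only to demonstrate that ordering effect.

def normalize_quotes(sample: dict) -> dict:
    sample['token'] = [t.replace('“', '"').replace('”', '"') for t in sample['token']]
    return sample


def fake_tokenize(sample: dict) -> dict:
    # Stand-in for TransformerSequenceTokenizer: records what it saw.
    sample['token_input'] = list(sample['token'])
    return sample


pipeline = [fake_tokenize]
pipeline.insert(0, normalize_quotes)  # same move as transform.insert(0, extra_transform)

sample = {'token': ['“Hi”']}
for t in pipeline:
    sample = t(sample)
print(sample['token_input'])  # ['"Hi"'] — the tokenizer saw the normalized text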
def build_dataloader(self, data, batch_size, sampler_builder: SamplerBuilder = None, gradient_accumulation=1,
                     shuffle=False, device=None, logger: logging.Logger = None, **kwargs) -> DataLoader:
    if isinstance(data, TransformDataset):
        dataset = data
    else:
        transform = self.config.encoder.transform()
        if self.config.get('transform', None):
            transform = TransformList(self.config.transform, transform)
        dataset = self.build_dataset(data, transform, logger)
    if self.vocabs.mutable:
        # noinspection PyTypeChecker
        self.build_vocabs(dataset, logger)
    lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
def build_dataloader(self, data, shuffle, device, embed: Embedding, training=False, logger=None,
                     gradient_accumulation=1, sampler_builder=None, batch_size=None, bos='\0',
                     **kwargs) -> DataLoader:
    first_transform = TransformList(functools.partial(append_bos, bos=bos))
    embed_transform = embed.transform(vocabs=self.vocabs)
    transformer_transform = self._get_transformer_transform_from_transforms(embed_transform)
    if embed_transform:
        # Temporarily remove the transformer tokenizer so vocabs are built first.
        if transformer_transform and isinstance(embed_transform, TransformList):
            embed_transform.remove(transformer_transform)
        first_transform.append(embed_transform)
    dataset = self.build_dataset(data, first_transform=first_transform)
    if self.config.get('transform', None):
        dataset.append_transform(self.config.transform)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger, self._transformer_trainable())
    if transformer_transform and isinstance(embed_transform, TransformList):
        embed_transform.append(transformer_transform)
    dataset.append_transform(FieldLength('token', 'sent_length'))
    if isinstance(data, str):
        dataset.purge_cache()
    if len(dataset) > 1000 and isinstance(data, str):
        # Cache large file-backed datasets so the transforms run only once.
        timer = CountdownTimer(len(dataset))
        self.cache_dataset(dataset, timer, training, logger)
    if sampler_builder:
        lens = [sample['sent_length'] for sample in dataset]
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    loader = PadSequenceDataLoader(dataset=dataset, batch_sampler=sampler, batch_size=batch_size,
                                   pad=self.get_pad_dict(), device=device, vocabs=self.vocabs)
    return loader
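# --- Illustrative sketch, not HanLP source ---
# The dataloader above collects one length per sample ('sent_length') and hands
# the list to a sampler builder, which typically groups samples of similar
# length into the same batch to reduce padding. A minimal sketch of that idea,
# assuming nothing about HanLP's SamplerBuilder beyond "lengths in, batches of
# indices out":

lens = [7, 3, 12, 5, 4, 11]          # e.g. sample['sent_length'] for each sample
batch_size = 2

# Sort indices by length, then cut into fixed-size batches of similar length.
order = sorted(range(len(lens)), key=lambda i: lens[i])
batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
print(batches)  # [[1, 4], [3, 0], [5, 2]]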
def build_dataset(self, data, bos_transform=None):
    transform = TransformList(
        functools.partial(append_bos, pos_key='UPOS'),
        functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel,
                          arc_key='arc_2nd', rel_key='rel_2nd'))
    if self.config.joint:
        transform.append(merge_head_deprel_with_2nd)
    if bos_transform:
        transform.append(bos_transform)
    return super().build_dataset(data, transform)
class Transformable(ABC):
    """Base class for objects that own an optional transform (a single callable
    or a TransformList) and can apply it to samples."""

    def __init__(self, transform: Union[Callable, List] = None) -> None:
        super().__init__()
        if isinstance(transform, list) and not isinstance(transform, TransformList):
            transform = TransformList(*transform)
        self.transform: Union[Callable, TransformList] = transform

    def append_transform(self, transform: Callable):
        assert transform is not None, 'None transform not allowed'
        if not self.transform:
            self.transform = TransformList(transform)
        elif not isinstance(self.transform, TransformList):
            if self.transform != transform:
                self.transform = TransformList(self.transform, transform)
        else:
            if transform not in self.transform:
                self.transform.append(transform)
        return self

    def insert_transform(self, index: int, transform: Callable):
        assert transform is not None, 'None transform not allowed'
        if not self.transform:
            self.transform = TransformList(transform)
        elif not isinstance(self.transform, TransformList):
            if self.transform != transform:
                self.transform = TransformList(self.transform)
                self.transform.insert(index, transform)
        else:
            if transform not in self.transform:
                self.transform.insert(index, transform)
        return self

    def transform_sample(self, sample: dict, inplace=False) -> dict:
        if not inplace:
            sample = copy(sample)
        if self.transform:
            sample = self.transform(sample)
        return sample
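# --- Illustrative sketch, not HanLP source ---
# A minimal, self-contained demonstration of the Transformable pattern above:
# transforms are callables that take a sample dict and return it (possibly
# modified), and a TransformList-style container simply applies them in order.
# MiniTransformList and the two helper functions are hypothetical stand-ins,
# shown only to illustrate how append_transform and transform_sample compose.

class MiniTransformList(list):
    """A list of callables applied to a sample dict in order."""

    def __init__(self, *transforms):
        super().__init__(transforms)

    def __call__(self, sample: dict) -> dict:
        for t in self:
            sample = t(sample)
        return sample


def lowercase_tokens(sample: dict) -> dict:
    sample['token'] = [t.lower() for t in sample['token']]
    return sample


def add_token_length(sample: dict) -> dict:
    sample['token_length'] = len(sample['token'])
    return sample


pipeline = MiniTransformList(lowercase_tokens)
pipeline.append(add_token_length)  # analogous to append_transform
result = pipeline({'token': ['Hello', 'World']})
print(result)  # {'token': ['hello', 'world'], 'token_length': 2}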
def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None, vocabs=None,
                     sampler_builder=None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if vocabs is None:
        vocabs = self.vocabs
    transform = TransformList(unpack_ner, FieldLength('token'))
    if isinstance(self.config.embed, Embedding):
        transform.append(self.config.embed.transform(vocabs=vocabs))
    transform.append(self.vocabs)
    dataset = self.build_dataset(data, vocabs, transform)
    if vocabs.mutable:
        self.build_vocabs(dataset, logger, vocabs)
    # Sample lengths for the batch sampler: token counts when a word vocab is
    # used, subword counts otherwise.
    if 'token' in vocabs:
        lens = [len(x['token']) for x in dataset]
    else:
        lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(batch_sampler=sampler, device=device, dataset=dataset)
def last_transform(self):
    return TransformList(self.vocabs, FieldLength(self.config.token_key))
def build_dataset(self, data, transform=None, **kwargs):
    if not isinstance(transform, list):
        transform = TransformList()
    transform.append(add_lemma_rules_to_sample)
    return super().build_dataset(data, transform, **kwargs)
def __init__(self, transform: Union[Callable, List] = None) -> None:
    super().__init__()
    if isinstance(transform, list) and not isinstance(transform, TransformList):
        transform = TransformList(*transform)
    self.transform: Union[Callable, TransformList] = transform
def transform(self, **kwargs) -> Callable:
    vocab = Vocab()
    vocab.load(os.path.join(get_resource(self.path), 'vocab.json'))
    return TransformList(ContextualStringEmbeddingTransform(self.field),
                         FieldToIndex(f'{self.field}_f_char', vocab),
                         FieldToIndex(f'{self.field}_b_char', vocab))
def transform(self, **kwargs):
    transforms = [e.transform(**kwargs) for e in self._embeddings]
    transforms = [t for t in transforms if t]
    return TransformList(*transforms)
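# --- Illustrative sketch, not HanLP source ---
# The transform() above asks each child embedding for its transform, drops the
# ones that return None (embeddings that need no preprocessing), and chains the
# rest. A minimal stand-in showing the same filter-and-chain idea; the helper
# function below is hypothetical.

def char_transform(sample: dict) -> dict:
    sample['char'] = [list(t) for t in sample['token']]
    return sample


transforms = [char_transform, None]            # one embedding contributes nothing
transforms = [t for t in transforms if t]      # same filtering as above

sample = {'token': ['ab', 'c']}
for t in transforms:
    sample = t(sample)
print(sample['char'])  # [['a', 'b'], ['c']]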