def build_dataset(self, data, transform=None):
    transforms = TransformList(
        functools.partial(append_bos_to_form_pos, pos_key='UPOS'),
        functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel))
    if transform:
        transforms.append(transform)
    return super(BiaffineSemanticDependencyParser, self).build_dataset(data, transforms)
def build_dataset(self, data, bos_transform=None):
    transform = TransformList(
        functools.partial(append_bos, pos_key='UPOS'),
        functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel,
                          arc_key='arc_2nd', rel_key='rel_2nd'))
    if self.config.joint:
        transform.append(merge_head_deprel_with_2nd)
    if bos_transform:
        transform.append(bos_transform)
    return super().build_dataset(data, transform)
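# Illustration (not library code): functools.partial, as used above, pre-binds
# keyword arguments so a configurable transform can sit in a TransformList as a
# plain callable that takes only the sample dict. The transform and field names
# below are hypothetical stand-ins.
import functools

def tag_field(sample: dict, key: str, value) -> dict:
    # A toy transform: takes a sample dict, mutates it, and returns it.
    sample[key] = value
    return sample

mark_english = functools.partial(tag_field, key='language', value='en')
print(mark_english({'token': ['Hello', 'world']}))
# {'token': ['Hello', 'world'], 'language': 'en'}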
def build_dataloader(self, data, shuffle, device, embed: Embedding, training=False, logger=None,
                     gradient_accumulation=1, sampler_builder=None, batch_size=None, bos='\0',
                     **kwargs) -> DataLoader:
    first_transform = TransformList(functools.partial(append_bos, bos=bos))
    embed_transform = embed.transform(vocabs=self.vocabs)
    transformer_transform = self._get_transformer_transform_from_transforms(embed_transform)
    if embed_transform:
        if transformer_transform and isinstance(embed_transform, TransformList):
            # Hold back the transformer transform until the vocabs are built.
            embed_transform.remove(transformer_transform)
        first_transform.append(embed_transform)
    dataset = self.build_dataset(data, first_transform=first_transform)
    if self.config.get('transform', None):
        dataset.append_transform(self.config.transform)
    if self.vocabs.mutable:
        self.build_vocabs(dataset, logger, self._transformer_trainable())
    if transformer_transform and isinstance(embed_transform, TransformList):
        # Re-attach the transformer transform now that the vocabs are finalized.
        embed_transform.append(transformer_transform)
    dataset.append_transform(FieldLength('token', 'sent_length'))
    if isinstance(data, str):
        dataset.purge_cache()
    if len(dataset) > 1000 and isinstance(data, str):
        # Cache large datasets that were loaded from a file path.
        timer = CountdownTimer(len(dataset))
        self.cache_dataset(dataset, timer, training, logger)
    if sampler_builder:
        # Build a length-aware sampler from the per-sentence lengths.
        lens = [sample['sent_length'] for sample in dataset]
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    loader = PadSequenceDataLoader(dataset=dataset, batch_sampler=sampler, batch_size=batch_size,
                                   pad=self.get_pad_dict(), device=device, vocabs=self.vocabs)
    return loader
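# Sketch (an assumption for illustration, not the library's implementation): the
# collation performed by a PadSequenceDataLoader-style loader amounts to
# right-padding every variable-length field to the longest sequence in the batch.
# The field names below are hypothetical.
import torch

def pad_collate(batch, pad_value=0):
    # batch: a list of sample dicts whose 'token_id' field is a list of ints.
    max_len = max(len(sample['token_id']) for sample in batch)
    ids = torch.full((len(batch), max_len), pad_value, dtype=torch.long)
    for i, sample in enumerate(batch):
        ids[i, :len(sample['token_id'])] = torch.tensor(sample['token_id'], dtype=torch.long)
    return {'token_id': ids,
            'sent_length': torch.tensor([len(s['token_id']) for s in batch])}

# pad_collate([{'token_id': [2, 5, 7]}, {'token_id': [4]}])['token_id'].shape
# -> torch.Size([2, 3])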
class Transformable(ABC):
    def __init__(self, transform: Union[Callable, List] = None) -> None:
        super().__init__()
        if isinstance(transform, list) and not isinstance(transform, TransformList):
            transform = TransformList(*transform)
        self.transform: Union[Callable, TransformList] = transform

    def append_transform(self, transform: Callable):
        assert transform is not None, 'None transform not allowed'
        if not self.transform:
            self.transform = TransformList(transform)
        elif not isinstance(self.transform, TransformList):
            if self.transform != transform:
                self.transform = TransformList(self.transform, transform)
        else:
            if transform not in self.transform:
                self.transform.append(transform)
        return self

    def insert_transform(self, index: int, transform: Callable):
        assert transform is not None, 'None transform not allowed'
        if not self.transform:
            self.transform = TransformList(transform)
        elif not isinstance(self.transform, TransformList):
            if self.transform != transform:
                self.transform = TransformList(self.transform)
                self.transform.insert(index, transform)
        else:
            if transform not in self.transform:
                self.transform.insert(index, transform)
        return self

    def transform_sample(self, sample: dict, inplace=False) -> dict:
        if not inplace:
            sample = copy(sample)
        if self.transform:
            sample = self.transform(sample)
        return sample
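# Usage sketch (assumes TransformList applies its callables in order; the
# subclass and transforms below are hypothetical):
class ToyDataset(Transformable):
    pass

def lowercase(sample: dict) -> dict:
    sample['token'] = [t.lower() for t in sample['token']]
    return sample

def count_tokens(sample: dict) -> dict:
    sample['token_length'] = len(sample['token'])
    return sample

ds = ToyDataset()
ds.append_transform(lowercase)     # first append wraps the callable in a TransformList
ds.append_transform(count_tokens)  # later appends extend it; duplicates are skipped
print(ds.transform_sample({'token': ['Hello', 'World']}))
# {'token': ['hello', 'world'], 'token_length': 2}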
def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None, vocabs=None,
                     sampler_builder=None, gradient_accumulation=1, **kwargs) -> DataLoader:
    if vocabs is None:
        vocabs = self.vocabs
    transform = TransformList(unpack_ner, FieldLength('token'))
    if isinstance(self.config.embed, Embedding):
        transform.append(self.config.embed.transform(vocabs=vocabs))
    transform.append(self.vocabs)
    dataset = self.build_dataset(data, vocabs, transform)
    if vocabs.mutable:
        self.build_vocabs(dataset, logger, vocabs)
    # Sequence lengths for the sampler: token count if a token vocab exists,
    # otherwise the length of the subword id sequence.
    if 'token' in vocabs:
        lens = [len(x['token']) for x in dataset]
    else:
        lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
    else:
        sampler = None
    return PadSequenceDataLoader(batch_sampler=sampler, device=device, dataset=dataset)
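# Illustration (an assumption, not the library's sampler builder): a length-aware
# batch sampler groups sequences of similar length so little padding is wasted.
# A minimal sketch of that idea:
def sorted_length_batches(lens, batch_size):
    # Yield lists of dataset indices, chunked after sorting by length.
    order = sorted(range(len(lens)), key=lambda i: lens[i])
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]

# list(sorted_length_batches([7, 2, 9, 3], batch_size=2)) -> [[1, 3], [0, 2]]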
def build_dataset(self, data, transform=None, **kwargs):
    if not isinstance(transform, list):
        transform = TransformList()
    transform.append(add_lemma_rules_to_sample)
    return super().build_dataset(data, transform, **kwargs)