Example #1
 def build_dataloader(self,
                      data,
                      transform: TransformList = None,
                      training=False,
                      device=None,
                      logger: logging.Logger = None,
                      tokenizer: PreTrainedTokenizer = None,
                      **kwargs) -> DataLoader:
     assert tokenizer
     dataset = TextTokenizingDataset(data, cache=isinstance(data, str), delimiter=self.config.sent_delimiter,
                                     generate_idx=isinstance(data, list),
                                     max_seq_len=self.config.max_seq_len,
                                     sent_delimiter=self.config.sent_delimiter,
                                     transform=[
                                         TransformerSequenceTokenizer(tokenizer,
                                                                      'text',
                                                                      ret_prefix_mask=True,
                                                                      ret_subtokens=True,
                                                                      ),
                                         FieldLength('text_input_ids', 'text_input_ids_length', delta=-2),
                                         generate_token_span_tuple])
     return PadSequenceDataLoader(
         batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'text_input_ids', 'text'),
                                                  shuffle=training),
         device=device,
         dataset=dataset)
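
The build_dataloader implementations collected here all follow the same recipe: wrap the input in a transformable dataset, record a per-sample length, build a length-aware batch sampler from those lengths, and hand everything to PadSequenceDataLoader for padding and collation. As a rough illustration of what the sampler step buys, here is a minimal sketch of grouping samples by length under a token budget; the function name sorted_token_batches and the max_tokens budget are made up for this sketch and are not HanLP's SamplerBuilder API.

from typing import List

def sorted_token_batches(lens: List[int], max_tokens: int = 5000) -> List[List[int]]:
    # Sort sample indices by length so each batch holds similarly sized samples,
    # which keeps padding waste low once the batch is padded to its longest member.
    order = sorted(range(len(lens)), key=lens.__getitem__)
    batches, current, budget = [], [], 0
    for i in order:
        if current and budget + lens[i] > max_tokens:
            batches.append(current)
            current, budget = [], 0
        current.append(i)
        budget += lens[i]
    if current:
        batches.append(current)
    return batches

# e.g. sorted_token_batches([12, 7, 30, 9], max_tokens=20) -> [[1, 3], [0], [2]]
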
Example #2
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle,
                      device,
                      logger: logging.Logger = None,
                      vocabs=None,
                      sampler_builder=None,
                      gradient_accumulation=1,
                      **kwargs) -> DataLoader:
     if vocabs is None:
         vocabs = self.vocabs
     transform = TransformList(unpack_ner, FieldLength('token'))
     if isinstance(self.config.embed, Embedding):
         transform.append(self.config.embed.transform(vocabs=vocabs))
     transform.append(vocabs)
     dataset = self.build_dataset(data, vocabs, transform)
     if vocabs.mutable:
         self.build_vocabs(dataset, logger, vocabs)
     if 'token' in vocabs:
         lens = [len(x['token']) for x in dataset]
     else:
         lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = None
     return PadSequenceDataLoader(batch_sampler=sampler,
                                  device=device,
                                  dataset=dataset)
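
FieldLength('token') above is a transform over sample dicts: it reads one field and stores its length under a destination key (here something like 'token_length'), so later code can collect lens without re-touching the raw field. A minimal sketch of such a transform, written from scratch for illustration rather than taken from HanLP, might look like this:

class FieldLengthSketch:
    # Hypothetical re-implementation for illustration only; not the HanLP class.
    def __init__(self, src: str, dst: str = None, delta: int = 0):
        self.src = src
        self.dst = dst or src + '_length'
        self.delta = delta  # e.g. -2 to discount special tokens such as [CLS]/[SEP]

    def __call__(self, sample: dict) -> dict:
        sample[self.dst] = len(sample[self.src]) + self.delta
        return sample

# sample = {'token': ['He', 'went', 'home']}
# FieldLengthSketch('token')(sample)['token_length'] == 3
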
Example #3
 def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, sampler_builder,
                      gradient_accumulation,
                      **kwargs) -> DataLoader:
     # shuffle = False  # We need to find the smallest grad_acc
     dataset = HeadDrivenPhraseStructureDataset(data, transform=[append_bos_eos])
     if self.config.get('transform', None):
         dataset.append_transform(self.config.transform)
     dataset.append_transform(self.vocabs)
     if isinstance(self.config.embed, Embedding):
         transform = self.config.embed.transform(vocabs=self.vocabs)
         if transform:
             dataset.append_transform(transform)
     dataset.append_transform(self.vocabs)
     field_length = FieldLength('token')
     dataset.append_transform(field_length)
     if isinstance(data, str):
         dataset.purge_cache()  # Enable cache
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger)
     if 'token' in self.vocabs:
         lens = [x[field_length.dst] for x in dataset]
     else:
         lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
     else:
         sampler = None
     return PadSequenceDataLoader(batch_sampler=sampler,
                                  batch_size=batch_size,
                                  device=device,
                                  dataset=dataset)
Example #4
 def build_dataloader(self,
                      data,
                      batch_size,
                      sampler_builder: SamplerBuilder = None,
                      gradient_accumulation=1,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      **kwargs) -> DataLoader:
     if isinstance(data, TransformableDataset):
         dataset = data
     else:
         dataset = self.build_dataset(data, [
             self.config.embed.transform(vocabs=self.vocabs), self.vocabs,
             FieldLength('token')
         ])
     if self.vocabs.mutable:
         # noinspection PyTypeChecker
         self.build_vocabs(dataset, logger)
     lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = None
     return PadSequenceDataLoader(dataset,
                                  batch_size,
                                  shuffle,
                                  device=device,
                                  batch_sampler=sampler)
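
The recurring guard "if self.vocabs.mutable: self.build_vocabs(...)" captures a two-phase vocabulary: during the first (training) pass the vocab grows as tokens arrive, then it is locked so evaluation and prediction map unseen tokens to an unknown id instead of silently extending the vocabulary. A bare-bones sketch of that behaviour (a hypothetical class, not HanLP's Vocab):

class GrowableVocab:
    def __init__(self, unk='<unk>'):
        self.mutable = True
        self.token_to_id = {unk: 0}
        self.unk_id = 0

    def lock(self):
        self.mutable = False

    def __call__(self, token: str) -> int:
        if token not in self.token_to_id:
            if self.mutable:
                self.token_to_id[token] = len(self.token_to_id)
            else:
                return self.unk_id
        return self.token_to_id[token]

# vocab = GrowableVocab()
# [vocab(t) for t in ['a', 'b', 'a']]  # -> [1, 2, 1] while mutable
# vocab.lock(); vocab('c')             # -> 0 (unknown) once locked
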
Example #5
 def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
     encoder: ContextualWordEmbedding = self.config.encoder
     encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
     length_transform = FieldLength('token', 'token_length')
     transform = TransformList(encoder_transform, length_transform)
     extra_transform = self.config.get('transform', None)
     if extra_transform:
         transform.insert(0, extra_transform)
     return encoder_transform, transform
 def build_dataset(self, data, transform, logger=None):
     _transform = [
         unpack_tree_to_features, self.vocabs,
         FieldLength('token'), transform
     ]
     if self.config.get('no_subcategory', True):
         _transform.insert(0, remove_subcategory)
     dataset = ConstituencyDataset(data,
                                   transform=_transform,
                                   cache=isinstance(data, str))
     return dataset
Example #7
    def build_dataloader(self,
                         data,
                         shuffle,
                         device,
                         embed: Embedding,
                         training=False,
                         logger=None,
                         gradient_accumulation=1,
                         sampler_builder=None,
                         batch_size=None,
                         bos='\0',
                         **kwargs) -> DataLoader:
        first_transform = TransformList(functools.partial(append_bos, bos=bos))
        embed_transform = embed.transform(vocabs=self.vocabs)
        transformer_transform = self._get_transformer_transform_from_transforms(
            embed_transform)
        if embed_transform:
            if transformer_transform and isinstance(embed_transform,
                                                    TransformList):
                # Temporarily detach the transformer tokenization; it is
                # re-appended after build_vocabs() below.
                embed_transform.remove(transformer_transform)

            first_transform.append(embed_transform)
        dataset = self.build_dataset(data, first_transform=first_transform)
        if self.config.get('transform', None):
            dataset.append_transform(self.config.transform)

        if self.vocabs.mutable:
            self.build_vocabs(dataset, logger, self._transformer_trainable())
        if transformer_transform and isinstance(embed_transform,
                                                TransformList):
            # Re-append the transformer tokenization removed above.
            embed_transform.append(transformer_transform)

        dataset.append_transform(FieldLength('token', 'sent_length'))
        if isinstance(data, str):
            dataset.purge_cache()
        if len(dataset) > 1000 and isinstance(data, str):
            timer = CountdownTimer(len(dataset))
            self.cache_dataset(dataset, timer, training, logger)
        if sampler_builder:
            lens = [sample['sent_length'] for sample in dataset]
            sampler = sampler_builder.build(lens, shuffle,
                                            gradient_accumulation)
        else:
            sampler = None
        loader = PadSequenceDataLoader(dataset=dataset,
                                       batch_sampler=sampler,
                                       batch_size=batch_size,
                                       pad=self.get_pad_dict(),
                                       device=device,
                                       vocabs=self.vocabs)
        return loader
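
Example #7 (like #10 below) eagerly runs every sample through the transform pipeline when the input is a file path and the dataset is large, so the potentially expensive tokenization cost is paid once up front rather than on every epoch. The actual cache_dataset/CountdownTimer details are HanLP internals; a rough equivalent of the idea, with a hypothetical helper name, is simply:

def warm_cache(dataset, log_every: int = 1000):
    # Iterating once forces each sample through the transform chain; if the
    # dataset memoizes transformed samples, later epochs reuse the cached dicts.
    for i, _ in enumerate(dataset, 1):
        if i % log_every == 0:
            print(f'cached {i}/{len(dataset)} samples')
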
Example #8
 def build_dataset(self, data, generate_idx, logger, transform=None):
     dataset = CoNLL2012SRLDataset(data, transform=[filter_v_args, unpack_srl, group_pa_by_p],
                                   doc_level_offset=self.config.doc_level_offset, generate_idx=generate_idx)
     if transform:
         dataset.append_transform(transform)
     if isinstance(self.config.get('embed', None), Embedding):
         transform = self.config.embed.transform(vocabs=self.vocabs)
         if transform:
             dataset.append_transform(transform)
     dataset.append_transform(self.vocabs)
     dataset.append_transform(FieldLength('token'))
     if isinstance(data, str):
         dataset.purge_cache()  # Enable cache
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger)
     return dataset
Example #9
 def build_dataloader(self, data, batch_size, shuffle, device,
                      logger: logging.Logger, **kwargs) -> DataLoader:
     dataset = CONLL12CorefDataset(data, [FieldLength('text')])
     if isinstance(self.config.embed, Embedding):
         transform = self.config.embed.transform(vocabs=self.vocabs)
         if transform:
             dataset.append_transform(transform)
     dataset.append_transform(self.vocabs)
     if isinstance(data, str):
         dataset.purge_cache()  # Enable cache
     if self.vocabs.mutable:
         self.build_vocabs(dataset)
     return PadSequenceDataLoader(batch_size=batch_size,
                                  shuffle=shuffle,
                                  device=device,
                                  dataset=dataset,
                                  pad={
                                      'spans': 0,
                                      'span_labels': -1
                                  })
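
The pad argument in Example #9 assigns a padding value per field: span indices are padded with 0, while span labels are padded with -1, a value a loss function can be told to ignore. The core padding step amounts to something like the following plain-Python sketch (not the PadSequenceDataLoader implementation):

def pad_field(values, pad_value=0):
    # values: one batch of variable-length lists for a single field.
    width = max(len(v) for v in values)
    return [list(v) + [pad_value] * (width - len(v)) for v in values]

# pad_field([[3, 1], [5, 2, 7]], pad_value=-1) -> [[3, 1, -1], [5, 2, 7]]
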
Example #10
 def build_dataloader(self,
                      data,
                      shuffle,
                      device,
                      training=False,
                      logger=None,
                      gradient_accumulation=1,
                      sampler_builder=None,
                      batch_size=None,
                      **kwargs) -> DataLoader:
     dataset = self.build_dataset(data)
     if self.vocabs.mutable:
         self.build_vocabs(dataset, logger, self.config.transformer)
     transformer_tokenizer = self.transformer_tokenizer
     if transformer_tokenizer:
         dataset.transform.append(self.build_tokenizer_transform())
     dataset.append_transform(FieldLength('token', 'sent_length'))
     if isinstance(data, str):
         dataset.purge_cache()
     if len(dataset) > 1000 and isinstance(data, str):
         timer = CountdownTimer(len(dataset))
         self.cache_dataset(dataset, timer, training, logger)
     if self.config.transformer:
         lens = [len(sample['input_ids']) for sample in dataset]
     else:
         lens = [sample['sent_length'] for sample in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = None
     loader = PadSequenceDataLoader(dataset=dataset,
                                    batch_sampler=sampler,
                                    batch_size=batch_size,
                                    num_workers=0 if isdebugging() else 2,
                                    pad=self.get_pad_dict(),
                                    device=device,
                                    vocabs=self.vocabs)
     return loader
Example #11
 def build_dataloader(self,
                      data,
                      batch_size,
                      shuffle=False,
                      device=None,
                      logger: logging.Logger = None,
                      sampler_builder=None,
                      gradient_accumulation=1,
                      transformer: ContextualWordEmbedding = None,
                      **kwargs) -> DataLoader:
     transform = [
         generate_lemma_rule, append_bos, self.vocabs,
         transformer.transform(),
         FieldLength('token')
     ]
     if not self.config.punct:
         transform.append(PunctuationMask('token', 'punct_mask'))
     dataset = self.build_dataset(data, transform)
     if self.vocabs.mutable:
         # noinspection PyTypeChecker
         self.build_vocabs(dataset, logger)
     lens = [len(x['token_input_ids']) for x in dataset]
     if sampler_builder:
         sampler = sampler_builder.build(lens, shuffle,
                                         gradient_accumulation)
     else:
         sampler = SortingSamplerBuilder(batch_size).build(
             lens, shuffle, gradient_accumulation)
     return PadSequenceDataLoader(
         dataset,
         batch_size,
         shuffle,
         device=device,
         batch_sampler=sampler,
         pad={'arc': 0},
     )
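
When self.config.punct is disabled, Example #11 appends a PunctuationMask('token', 'punct_mask') transform; the resulting mask lets the parser exclude punctuation tokens from scoring. A tiny sketch of such a transform, hypothetical and based on Unicode punctuation categories rather than whatever rule HanLP applies:

import unicodedata

def punct_mask(sample: dict, src: str = 'token', dst: str = 'punct_mask') -> dict:
    # True for tokens made up entirely of Unicode punctuation characters.
    sample[dst] = [
        all(unicodedata.category(ch).startswith('P') for ch in tok)
        for tok in sample[src]
    ]
    return sample

# punct_mask({'token': ['Hello', ',', 'world', '!']})['punct_mask']
# -> [False, True, False, True]
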
Example #12
 def last_transform(self):
     return TransformList(self.vocabs, FieldLength(self.config.token_key))
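
TransformList is essentially an ordered composition of per-sample callables: each transform receives the sample dict and hands it on, possibly enriched with new fields, to the next one. Under that reading, Example #12's last_transform appends vocabulary lookup and length computation as the final two steps of the pipeline. A self-contained sketch of the composition idea (again a stand-in class, not the HanLP implementation):

class TransformChain:
    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def append(self, transform):
        self.transforms.append(transform)

    def __call__(self, sample: dict) -> dict:
        for t in self.transforms:
            out = t(sample)
            sample = out if out is not None else sample  # tolerate in-place transforms
        return sample

# chain = TransformChain(lambda s: {**s, 'n': len(s['token'])})
# chain({'token': ['a', 'b']}) -> {'token': ['a', 'b'], 'n': 2}
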