Example #1
def build_dataloader(self,
                     data,
                     transform: Callable = None,
                     training=False,
                     device=None,
                     logger: logging.Logger = None,
                     cache=False,
                     gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    # Assumed context: a method of a HanLP UniversalDependenciesParser
    # subclass; generate_lemma_rule, append_bos, PunctuationMask and
    # PadSequenceDataLoader come from HanLP, DataLoader from torch.
    # Pipeline: derive lemma rules, prepend BOS, run vocabulary lookups,
    # then apply any user-supplied transform.
    _transform = [generate_lemma_rule, append_bos, self.vocabs, transform]
    # When loading a corpus file with punctuation excluded from scoring,
    # mark punctuation tokens so evaluation can skip them.
    if isinstance(data, str) and not self.config.punct:
        _transform.append(PunctuationMask('token', 'punct_mask'))
    dataset = UniversalDependenciesParser.build_dataset(
        self, data, _transform)
    # Build vocabularies on the first pass, while they are still mutable.
    if self.vocabs.mutable:
        UniversalDependenciesParser.build_vocabs(self,
                                                 dataset,
                                                 logger,
                                                 transformer=True)
    # Drop sentences whose subword sequence exceeds the transformer's
    # maximum input length (file input only).
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        dataset.prune(lambda x: len(x['token_input_ids']) > max_seq_len,
                      logger)
    # Batch by token length; pad variable-length fields, 0 for 'arc'.
    return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
        self.compute_lens(data, dataset, length_field='token'),
        shuffle=training,
        gradient_accumulation=gradient_accumulation),
                                 device=device,
                                 dataset=dataset,
                                 pad={'arc': 0})
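
The _transform list above acts as a pipeline: each entry is a callable that takes a sample dict and returns it, possibly augmented with new fields. Below is a minimal self-contained sketch of that pattern; the append_bos here is a simplified stand-in, not HanLP's actual implementation.

def append_bos(sample: dict) -> dict:
    # Simplified stand-in: prepend a beginning-of-sentence marker.
    sample['token'] = ['[BOS]'] + sample['token']
    return sample

def apply_transforms(sample: dict, transforms) -> dict:
    # Apply each transform in order; None entries (e.g. an omitted
    # user transform) are skipped.
    for t in transforms:
        if t is not None:
            sample = t(sample)
    return sample

print(apply_transforms({'token': ['Hello', 'world', '.']}, [append_bos, None]))
# {'token': ['[BOS]', 'Hello', 'world', '.']}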
Example #2
def build_dataset(self, data, first_transform=None):
    # Default first step: prepend a BOS token to each sentence.
    if not first_transform:
        first_transform = append_bos
    # get_sibs adds sibling information used by the parser.
    transform = [first_transform, get_sibs]
    # Optionally lowercase tokens before vocabulary lookup.
    if self.config.get('lowercase', False):
        transform.append(LowerCase('token'))
    transform.append(self.vocabs)
    # Mark punctuation so evaluation can exclude it.
    if not self.config.punct:
        transform.append(PunctuationMask('token', 'punct_mask'))
    return CoNLLParsingDataset(data, transform=transform)
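
PunctuationMask('token', 'punct_mask') adds a boolean field marking punctuation tokens so that evaluation can ignore them, as is conventional for dependency-parsing metrics. A rough re-implementation for illustration only; HanLP's actual class may use different punctuation rules.

import unicodedata

class PunctuationMask:
    # Hypothetical re-implementation for illustration.
    def __init__(self, src: str, dst: str):
        self.src, self.dst = src, dst

    def __call__(self, sample: dict) -> dict:
        # True for tokens made up entirely of Unicode punctuation.
        sample[self.dst] = [
            all(unicodedata.category(c).startswith('P') for c in tok)
            for tok in sample[self.src]
        ]
        return sample

sample = {'token': ['Hello', ',', 'world', '!']}
print(PunctuationMask('token', 'punct_mask')(sample)['punct_mask'])
# [False, True, False, True]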
Example #3
def build_dataloader(self,
                     data,
                     batch_size,
                     shuffle=False,
                     device=None,
                     logger: logging.Logger = None,
                     sampler_builder=None,
                     gradient_accumulation=1,
                     transformer: ContextualWordEmbedding = None,
                     **kwargs) -> DataLoader:
    # Pipeline: lemma rules, BOS, vocab lookup, subword encoding via the
    # contextual embedding, then record each sentence's token count.
    # Note: transformer defaults to None but is effectively required,
    # since transformer.transform() is called unconditionally.
    transform = [
        generate_lemma_rule, append_bos, self.vocabs,
        transformer.transform(),
        FieldLength('token')
    ]
    # Mark punctuation so evaluation can exclude it.
    if not self.config.punct:
        transform.append(PunctuationMask('token', 'punct_mask'))
    dataset = self.build_dataset(data, transform)
    if self.vocabs.mutable:
        # noinspection PyTypeChecker
        self.build_vocabs(dataset, logger)
    # Batch by subword length to minimize padding.
    lens = [len(x['token_input_ids']) for x in dataset]
    if sampler_builder:
        sampler = sampler_builder.build(lens, shuffle,
                                        gradient_accumulation)
    else:
        # Fall back to a sampler that sorts by length, capped at batch_size.
        sampler = SortingSamplerBuilder(batch_size).build(
            lens, shuffle, gradient_accumulation)
    # The batch_sampler governs batching; batch_size and shuffle then
    # have no further effect.
    return PadSequenceDataLoader(
        dataset,
        batch_size,
        shuffle,
        device=device,
        batch_sampler=sampler,
        pad={'arc': 0},
    )
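
SortingSamplerBuilder groups samples of similar length into the same batch, which keeps padding overhead low. A simplified sketch of the idea, assuming a fixed batch size in samples rather than HanLP's exact batching logic:

import random

def length_sorted_batches(lens, batch_size, shuffle=False):
    # Sort sample indices by sequence length, then cut into batches,
    # so sequences in a batch need minimal padding.
    order = sorted(range(len(lens)), key=lambda i: lens[i])
    batches = [order[i:i + batch_size]
               for i in range(0, len(order), batch_size)]
    if shuffle:
        random.shuffle(batches)  # shuffle batch order, not contents
    return batches

lens = [5, 12, 7, 3, 9, 11, 4, 8]
for batch in length_sorted_batches(lens, batch_size=3):
    print([lens[i] for i in batch])
# [3, 4, 5] / [7, 8, 9] / [11, 12]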