Ejemplo n.º 1
0
 def get_iterator(self, inputs, num_workers: int, batch_size: int,
                  preprocess_params, forward_params, postprocess_params):
     """Chain preprocess -> forward -> postprocess over ``inputs`` via a DataLoader.

     Sized ``inputs`` are wrapped in a ``PipelineDataset``; any other iterable
     is wrapped in a ``PipelineIterator`` and ``num_workers`` is capped at 1.

     Side effect: if ``TOKENIZERS_PARALLELISM`` is absent from ``os.environ``
     it is set to ``"false"`` for the whole process.

     Args:
         inputs: Items to run through the pipeline (sized or plain iterable).
         num_workers: ``DataLoader`` worker count; forced to 1 (with a
             warning) when ``inputs`` is not ``collections.abc.Sized``.
         batch_size: ``DataLoader`` batch size, also passed as
             ``loader_batch_size`` to the forward-stage iterator. A value of 1
             selects ``no_collate_fn``; any other value uses
             ``pad_collate_fn(self.tokenizer, self.feature_extractor)``.
         preprocess_params: Handed to the wrapper around ``self.preprocess``.
         forward_params: Handed to the wrapper around ``self.forward``.
         postprocess_params: Handed to the wrapper around ``self.postprocess``.

     Returns:
         The outermost ``PipelineIterator``, which wraps ``self.postprocess``
         around the forward-stage iterator.
     """
     if not isinstance(inputs, collections.abc.Sized):
         # Per the warning text, several workers over a plain iterable are
         # likely to error, so fall back to a single worker.
         if num_workers > 1:
             logger.warning(
                 "For iterable dataset using num_workers>1 is likely to result"
                 " in errors since everything is iterable, setting `num_workers=1`"
                 " to guarantee correctness.")
             num_workers = 1
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
     else:
         dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)

     # Only touch the variable when the user has not chosen a value.
     if "TOKENIZERS_PARALLELISM" not in os.environ:
         logger.info(
             "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
         )
         os.environ["TOKENIZERS_PARALLELISM"] = "false"

     # pad_collate_fn is only built when batching more than one item.
     if batch_size == 1:
         collate_fn = no_collate_fn
     else:
         collate_fn = pad_collate_fn(self.tokenizer, self.feature_extractor)

     loader = DataLoader(
         dataset,
         num_workers=num_workers,
         batch_size=batch_size,
         collate_fn=collate_fn,
     )
     forward_stage = PipelineIterator(
         loader, self.forward, forward_params, loader_batch_size=batch_size
     )
     return PipelineIterator(forward_stage, self.postprocess, postprocess_params)
    def test_pipeline_dataset(self):
        """Check PipelineDataset's length and that indexing applies ``add`` with its params."""
        from transformers.pipelines.pt_utils import PipelineDataset

        source = [0, 1, 2, 3]

        def add(number, extra=0):
            return number + extra

        wrapped = PipelineDataset(source, add, {"extra": 2})
        self.assertEqual(len(wrapped), 4)

        # Every source item should come back shifted by the "extra" param.
        results = [wrapped[index] for index in range(len(source))]
        self.assertEqual(results, [2, 3, 4, 5])