Example #1
    def get_iterator(self, inputs, num_workers: int, batch_size: int,
                     preprocess_params, forward_params, postprocess_params):
        # Sized inputs (anything with __len__) can back a map-style
        # PipelineDataset; everything else is wrapped as an iterable.
        if isinstance(inputs, collections.abc.Sized):
            dataset = PipelineDataset(inputs, self.preprocess,
                                      preprocess_params)
        else:
            if num_workers > 1:
                logger.warning(
                    "For iterable dataset using num_workers>1 is likely to result"
                    " in errors since everything is iterable, setting `num_workers=1`"
                    " to guarantee correctness.")
                num_workers = 1
            dataset = PipelineIterator(inputs, self.preprocess,
                                       preprocess_params)
        if "TOKENIZERS_PARALLELISM" not in os.environ:
            logger.info(
                "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
            )
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # With batch_size > 1 the inputs have to be padded to a common shape
        # before they can be stacked into a batch.
        collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(
            self.tokenizer, self.feature_extractor)
        dataloader = DataLoader(dataset,
                                num_workers=num_workers,
                                batch_size=batch_size,
                                collate_fn=collate_fn)
        # Chain preprocess -> forward -> postprocess lazily; loader_batch_size
        # lets the iterator slice batched model outputs back into per-item
        # results before postprocessing.
        model_iterator = PipelineIterator(dataloader,
                                          self.forward,
                                          forward_params,
                                          loader_batch_size=batch_size)
        final_iterator = PipelineIterator(model_iterator, self.postprocess,
                                          postprocess_params)
        return final_iterator
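
# A minimal, self-contained sketch of the same chaining idea outside the
# Pipeline class. It only relies on the PipelineDataset/PipelineIterator usage
# shown above and in the tests below; the preprocess/forward/postprocess
# functions here are hypothetical stand-ins for the real pipeline steps, not
# transformers APIs.
from torch.utils.data import DataLoader
from transformers.pipelines.pt_utils import PipelineDataset, PipelineIterator

raw_inputs = [0, 1, 2, 3]

def preprocess(n, offset=0):
    # stands in for Pipeline.preprocess
    return {"value": n + offset}

def forward(batch):
    # stands in for Pipeline.forward; receives the collated DataLoader batch
    return {"value": batch["value"] * 10}

def postprocess(item):
    # stands in for Pipeline.postprocess
    return item["value"].item()

dataset = PipelineDataset(raw_inputs, preprocess, {"offset": 1})
dataloader = DataLoader(dataset, num_workers=0, batch_size=1)
model_iterator = PipelineIterator(dataloader, forward, {}, loader_batch_size=1)
final_iterator = PipelineIterator(model_iterator, postprocess, {})
print(list(final_iterator))  # expected: [10, 20, 30, 40]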

    def test_pipeline_batch_unbatch_iterator_tensors(self):
        import torch

        from transformers.pipelines.pt_utils import PipelineIterator
        from transformers.testing_utils import nested_simplify

        dummy_dataset = [
            {"id": torch.LongTensor([[10, 20], [0, 1], [0, 2]])},
            {"id": torch.LongTensor([[3]])},
        ]

        def add(number, extra=0):
            return {"id": number["id"] + extra}

        # loader_batch_size=3 makes the iterator treat each output as a batch
        # and slice the leading tensor dimension back into single items, each
        # keeping a batch dimension of 1.
        dataset = PipelineIterator(dummy_dataset,
                                   add, {"extra": 2},
                                   loader_batch_size=3)

        outputs = [item for item in dataset]
        self.assertEqual(
            nested_simplify(outputs),
            [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}],
        )

    def test_pipeline_batch_unbatch_iterator(self):
        from transformers.pipelines.pt_utils import PipelineIterator

        dummy_dataset = [{"id": [0, 1, 2]}, {"id": [3]}]

        def add(number, extra=0):
            return {"id": [i + extra for i in number["id"]]}

        # Outputs whose values are plain lists are un-batched element by
        # element, so each yielded item carries a single scalar id.
        dataset = PipelineIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)

        outputs = [item for item in dataset]
        self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])
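
# Sketch (not part of the test class): the same un-batching driven by a real
# DataLoader batch, mirroring what get_iterator sets up above. Only DataLoader
# and PipelineIterator are library APIs here; the forward function is a
# hypothetical stand-in for a batched model call.
import torch
from torch.utils.data import DataLoader
from transformers.pipelines.pt_utils import PipelineIterator

items = [torch.tensor([i]) for i in range(4)]
dataloader = DataLoader(items, batch_size=2)  # yields batches of shape [2, 1]

def forward(batch):
    return {"id": batch + 100}  # batched "model" output, shape [2, 1]

unbatched = PipelineIterator(dataloader, forward, {}, loader_batch_size=2)
print([int(out["id"][0]) for out in unbatched])  # [100, 101, 102, 103]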

    def test_pipeline_iterator(self):
        from transformers.pipelines.pt_utils import PipelineIterator

        dummy_dataset = [0, 1, 2, 3]

        def add(number, extra=0):
            return number + extra

        # The wrapped function is applied lazily to each element, and len()
        # is forwarded to the underlying sized dataset.
        dataset = PipelineIterator(dummy_dataset, add, {"extra": 2})
        self.assertEqual(len(dataset), 4)

        outputs = [item for item in dataset]
        self.assertEqual(outputs, [2, 3, 4, 5])

    def test_pipeline_iterator_no_len(self):
        from transformers.pipelines.pt_utils import PipelineIterator

        def dummy_dataset():
            for i in range(4):
                yield i

        def add(number, extra=0):
            return number + extra

        # A generator has no __len__, so len() raises, but iteration (and the
        # lazy application of `add`) still works.
        dataset = PipelineIterator(dummy_dataset(), add, {"extra": 2})
        with self.assertRaises(TypeError):
            len(dataset)

        outputs = [item for item in dataset]
        self.assertEqual(outputs, [2, 3, 4, 5])