def get_iterator(
    self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
):
    if isinstance(inputs, collections.abc.Sized):
        # Sized inputs can back a map-style dataset that the DataLoader can index.
        dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
    else:
        if num_workers > 1:
            logger.warning(
                "For iterable dataset using num_workers>1 is likely to result"
                " in errors since everything is iterable, setting `num_workers=1`"
                " to guarantee correctness."
            )
            num_workers = 1
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
    if "TOKENIZERS_PARALLELISM" not in os.environ:
        logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # With batch_size == 1 no collation is needed; otherwise pad items so they can be stacked.
    collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, self.feature_extractor)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
    # Chain forward and postprocess lazily; loader_batch_size lets the iterator
    # split batched model outputs back into one result per input item.
    model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
    final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
    return final_iterator
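The three stages compose as plain Python iterators, so the chaining can be exercised without a model. A minimal sketch, assuming only `PipelineIterator` from `transformers.pipelines.pt_utils`; the `preprocess`/`postprocess` callables and their `scale`/`offset` parameters are hypothetical stand-ins for the pipeline's real stages:

from transformers.pipelines.pt_utils import PipelineIterator

def preprocess(x, scale=1):  # hypothetical stand-in for Pipeline.preprocess
    return x * scale

def postprocess(x, offset=0):  # hypothetical stand-in for Pipeline.postprocess
    return x + offset

# Each stage wraps the previous one; items flow through lazily, one at a time.
stage1 = PipelineIterator([1, 2, 3], preprocess, {"scale": 10})
stage2 = PipelineIterator(stage1, postprocess, {"offset": 5})
assert list(stage2) == [15, 25, 35]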
def test_pipeline_batch_unbatch_iterator_tensors(self):
    import torch

    from transformers.pipelines.pt_utils import PipelineIterator

    dummy_dataset = [{"id": torch.LongTensor([[10, 20], [0, 1], [0, 2]])}, {"id": torch.LongTensor([[3]])}]

    def add(number, extra=0):
        return {"id": number["id"] + extra}

    # loader_batch_size=3 makes the iterator slice the batched tensor output
    # along its first dimension and yield one dict per row.
    dataset = PipelineIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)

    outputs = list(dataset)
    self.assertEqual(
        nested_simplify(outputs), [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}]
    )
def test_pipeline_batch_unbatch_iterator(self):
    from transformers.pipelines.pt_utils import PipelineIterator

    dummy_dataset = [{"id": [0, 1, 2]}, {"id": [3]}]

    def add(number, extra=0):
        return {"id": [i + extra for i in number["id"]]}

    # The first item is a batch of three, so it is unbatched into three
    # separate {"id": ...} results before the final singleton batch.
    dataset = PipelineIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)

    outputs = list(dataset)
    self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])
def test_pipeline_iterator(self):
    from transformers.pipelines.pt_utils import PipelineIterator

    dummy_dataset = [0, 1, 2, 3]

    def add(number, extra=0):
        return number + extra

    dataset = PipelineIterator(dummy_dataset, add, {"extra": 2})
    # Length is forwarded from the underlying sized dataset.
    self.assertEqual(len(dataset), 4)

    outputs = list(dataset)
    self.assertEqual(outputs, [2, 3, 4, 5])
def test_pipeline_iterator_no_len(self):
    from transformers.pipelines.pt_utils import PipelineIterator

    def dummy_dataset():
        for i in range(4):
            yield i

    def add(number, extra=0):
        return number + extra

    dataset = PipelineIterator(dummy_dataset(), add, {"extra": 2})
    # A generator has no length, so len() on the iterator must raise.
    with self.assertRaises(TypeError):
        len(dataset)

    outputs = list(dataset)
    self.assertEqual(outputs, [2, 3, 4, 5])