def get_iterator(
    self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
):
    # Sized inputs (lists, map-style datasets) can be indexed lazily;
    # anything else is treated as a pure iterable.
    if isinstance(inputs, collections.abc.Sized):
        dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
    else:
        if num_workers > 1:
            logger.warning(
                "For iterable dataset using num_workers>1 is likely to result"
                " in errors since everything is iterable, setting `num_workers=1`"
                " to guarantee correctness."
            )
            num_workers = 1
        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)

    if "TOKENIZERS_PARALLELISM" not in os.environ:
        logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Batches of size 1 need no padding; larger batches are padded to a
    # common shape by pad_collate_fn.
    collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, self.feature_extractor)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)

    # Chain the three stages: preprocess (inside the DataLoader), forward
    # (model inference), then postprocess, all streamed lazily.
    model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
    final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
    return final_iterator
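# A minimal, torch-free sketch of the chaining idea above. TinyIterator is a
# hypothetical stand-in for PipelineIterator, not the real class: each stage
# lazily applies its function to items from the upstream iterable, so
# preprocess -> forward -> postprocess all stream one item at a time.
class TinyIterator:
    def __init__(self, loader, infer, params):
        self.loader = loader  # upstream iterable (raw inputs or another stage)
        self.infer = infer    # function applied to each item
        self.params = params  # extra keyword arguments for `infer`

    def __iter__(self):
        for item in self.loader:
            yield self.infer(item, **self.params)


# Usage: three chained stages mirroring preprocess/forward/postprocess.
texts = ["a", "bb", "ccc"]
preprocessed = TinyIterator(texts, lambda x: {"length": len(x)}, {})
forwarded = TinyIterator(preprocessed, lambda d, scale: d["length"] * scale, {"scale": 10})
postprocessed = TinyIterator(forwarded, lambda y: f"score={y}", {})
assert list(postprocessed) == ["score=10", "score=20", "score=30"]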
def test_pipeline_dataset(self):
    from transformers.pipelines.pt_utils import PipelineDataset

    dummy_dataset = [0, 1, 2, 3]

    def add(number, extra=0):
        return number + extra

    dataset = PipelineDataset(dummy_dataset, add, {"extra": 2})
    self.assertEqual(len(dataset), 4)
    outputs = [dataset[i] for i in range(4)]
    self.assertEqual(outputs, [2, 3, 4, 5])
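# A minimal sketch of what the test above exercises. This is an assumption of
# how PipelineDataset behaves, inferred from the assertions (see
# transformers.pipelines.pt_utils for the real class): it wraps a sized
# dataset and applies `process(item, **params)` on each indexed access.
from torch.utils.data import Dataset


class SketchPipelineDataset(Dataset):  # hypothetical name, not the real class
    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Apply the processing function lazily, one item per access.
        return self.process(self.dataset[i], **self.params)


# Matches the test: SketchPipelineDataset([0, 1, 2, 3], add, {"extra": 2})[0] == 2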