def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    """Give every ``Shuffler`` reachable from ``datapipe`` a seed drawn from ``rng``.

    The graph of ``datapipe`` is traversed, every iter-style or map-style
    ``Shuffler`` found in it is collected once, and each one receives its own
    int64 seed through ``set_seed``.

    Args:
        datapipe: The datapipe whose graph is searched for shufflers.
        rng: Generator forwarded as ``generator=`` to ``Tensor.random_``
            (presumably a ``torch.Generator`` -- confirm against callers).

    Returns:
        The same ``datapipe`` object that was passed in.
    """
    graph = traverse(datapipe, only_datapipe=True)
    all_pipes = get_all_graph_pipes(graph)

    # De-duplicate by identity and keep the order of ``all_pipes``. Collecting
    # the pipes themselves in a set would iterate in hash order, which is
    # address-based for objects using the default hash, so the pairing between
    # shufflers and draws from ``rng`` could change from run to run. It would
    # also raise for pipes that are not hashable.
    seen_ids = set()
    shufflers = []
    for pipe in all_pipes:
        if not isinstance(pipe, (dp.iter.Shuffler, dp.map.Shuffler)):
            continue
        if id(pipe) in seen_ids:
            continue
        seen_ids.add(id(pipe))
        shufflers.append(pipe)

    for shuffler in shufflers:
        # One fresh draw per shuffler, so each gets its own seed.
        shuffle_seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
        shuffler.set_seed(shuffle_seed)

    return datapipe
def apply_sharding(datapipe: DataPipe, num_of_instances: int, instance_id: int) -> DataPipe:
    """Apply sharding to the single shardable pipe in the graph of ``datapipe``.

    A pipe qualifies when it has an ``is_shardable`` method that returns a
    truthy value and also has an ``apply_sharding`` method. The qualifying
    pipe is called with ``(num_of_instances, instance_id)``.

    Args:
        datapipe: The datapipe whose graph is searched.
        num_of_instances: Forwarded to the pipe's ``apply_sharding``.
        instance_id: Forwarded to the pipe's ``apply_sharding``.

    Returns:
        The same ``datapipe`` object that was passed in.

    Raises:
        RuntimeError: If a second qualifying pipe is found after sharding has
            already been applied to one.
    """
    sharded_pipe = None
    for candidate in get_all_graph_pipes(traverse(datapipe, only_datapipe=True)):
        # Guard clauses: skip anything that does not opt in to sharding.
        if not hasattr(candidate, 'is_shardable'):
            continue
        if not candidate.is_shardable():
            continue
        if not hasattr(candidate, 'apply_sharding'):
            continue

        if sharded_pipe is not None:
            raise RuntimeError(
                'This implementation of sharding can be only applied once per instance of DataPipeline.',
                'Already applied to',
                sharded_pipe,
                'while trying to apply to',
                candidate,
            )

        candidate.apply_sharding(num_of_instances, instance_id)
        sharded_pipe = candidate

    return datapipe
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
    """Check that the loaded dataset's graph holds a datapipe of exactly ``annotation_dp_type``.

    The dataset is the first item of the pair returned by
    ``dataset_mock.load(config)``. Its traversal graph is walked as a nested
    mapping of node to sub-graph, and an ``AssertionError`` is raised when no
    node's type is exactly ``annotation_dp_type`` (subclasses do not count).
    """

    def walk(tree):
        # Pre-order walk: each node first, then everything below it.
        for current, children in tree.items():
            yield current
            yield from walk(children)

    dataset, _ = dataset_mock.load(config)

    found = any(type(node) is annotation_dp_type for node in walk(traverse(dataset)))
    if not found:
        raise AssertionError(
            f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe."
        )
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
    """Check that the loaded dataset's graph holds a datapipe of exactly ``annotation_dp_type``.

    The mock is prepared with ``test_home`` and ``config``, then the dataset
    is loaded by the mock's name with ``config`` expanded as keyword
    arguments. Its traversal graph is walked as a nested mapping of node to
    sub-graph, and an ``AssertionError`` is raised when no node's type is
    exactly ``annotation_dp_type`` (subclasses do not count).
    """

    def walk(tree):
        # Pre-order walk: each node first, then everything below it.
        for current, children in tree.items():
            yield current
            yield from walk(children)

    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    for node in walk(traverse(dataset)):
        if type(node) is annotation_dp_type:
            break
    else:
        raise AssertionError(
            f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe."
        )
def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool]) -> DataPipe:
    """Switch shuffling on or off for every ``Shuffler`` in the graph of ``datapipe``.

    Args:
        datapipe: The datapipe whose graph is searched for shufflers.
        shuffle: ``None`` leaves the datapipe untouched. Otherwise the value
            is passed to ``set_shuffle`` on each shuffler found.

    Returns:
        ``datapipe`` itself, or, when ``shuffle`` is truthy and the graph has
        no shuffler, the result of ``datapipe.shuffle()`` (a warning is issued
        in that case).
    """
    if shuffle is None:
        return datapipe

    graph_pipes = get_all_graph_pipes(traverse(datapipe, only_datapipe=True))

    found = []
    for candidate in graph_pipes:
        if isinstance(candidate, (dp.iter.Shuffler, dp.map.Shuffler)):
            found.append(candidate)

    if shuffle and not found:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        # The appended shuffler becomes the new tail and the only one to configure.
        datapipe = datapipe.shuffle()
        found = [datapipe]  # type: ignore[list-item]

    for shuffler in found:
        shuffler.set_shuffle(shuffle)

    return datapipe
def test_traversable(self, dataset_mock, config):
    """Check that the dataset loaded from the mock can be traversed without raising.

    ``dataset_mock.load(config)`` must return a pair; only its first item,
    the dataset, is used.
    """
    dataset, _unused = dataset_mock.load(config)

    # Passing means ``traverse`` completes; its return value is not inspected.
    traverse(dataset)
def test_traversable(self, test_home, dataset_mock, config):
    """Check that the dataset loaded by name can be traversed without raising.

    The mock is prepared with ``test_home`` and ``config`` before the dataset
    is loaded by the mock's name with ``config`` expanded as keyword arguments.
    """
    dataset_mock.prepare(test_home, config)

    # Passing means ``traverse`` completes; its return value is not inspected.
    traverse(datasets.load(dataset_mock.name, **config))
def extract_datapipes(dp):
    """Return all datapipes in the graph of ``dp``.

    The graph is built with ``traverse(dp, only_datapipe=True)`` and
    flattened with ``get_all_graph_pipes``.
    """
    graph = traverse(dp, only_datapipe=True)
    pipes = get_all_graph_pipes(graph)
    return pipes
def test_traversable(self, dataset_mock, config, only_datapipe):
    """Check that the mock-loaded dataset can be traversed for the given ``only_datapipe``.

    ``dataset_mock.load(config)`` must return a pair; only its first item,
    the dataset, is used. ``only_datapipe`` is forwarded unchanged to
    ``traverse``.
    """
    dataset, _unused = dataset_mock.load(config)

    # Passing means ``traverse`` completes; its return value is not inspected.
    traverse(dataset, only_datapipe=only_datapipe)