コード例 #1
0
ファイル: graph_settings.py プロジェクト: timgates42/pytorch
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    """Seed every shuffler in the graph of ``datapipe`` from ``rng``.

    The graph is built with ``traverse(datapipe, only_datapipe=True)``. Every
    node that is a ``dp.iter.Shuffler`` or ``dp.map.Shuffler`` gets its own
    seed via ``set_seed``. Each distinct shuffler object is seeded exactly once,
    even if ``get_all_graph_pipes`` reports it more than once.

    Args:
        datapipe: Root of the DataPipe graph. Its shufflers are mutated in place.
        rng: Passed as ``generator=`` to ``Tensor.random_``. Presumably a
            ``torch.Generator``; confirm against callers.

    Returns:
        The same ``datapipe`` object that was passed in.
    """
    graph = traverse(datapipe, only_datapipe=True)
    all_pipes = get_all_graph_pipes(graph)

    # Deduplicate by identity while preserving traversal order. A set
    # comprehension iterates in hash order (memory-address order for default
    # object hashing), so the shuffler -> seed assignment could differ between
    # runs even with an identically seeded ``rng``. It also required every
    # pipe to be hashable. The result is only as ordered as
    # ``get_all_graph_pipes``' output.
    seen_ids = set()
    shufflers = []
    for pipe in all_pipes:
        if id(pipe) in seen_ids:
            continue
        if isinstance(pipe, (dp.iter.Shuffler, dp.map.Shuffler)):
            seen_ids.add(id(pipe))
            shufflers.append(pipe)

    for shuffler in shufflers:
        # Draw a fresh int64 seed from ``rng`` for each shuffler, in order.
        shuffle_seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
        shuffler.set_seed(shuffle_seed)

    return datapipe
コード例 #2
0
ファイル: graph_settings.py プロジェクト: timgates42/pytorch
def apply_sharding(datapipe: DataPipe, num_of_instances: int, instance_id: int) -> DataPipe:
    """Apply sharding to the one shardable DataPipe in the graph of ``datapipe``.

    A pipe is sharded when it has an ``is_shardable`` attribute, calling it
    returns truthy, and the pipe also has an ``apply_sharding`` attribute.
    That pipe's ``apply_sharding(num_of_instances, instance_id)`` is then called.

    Args:
        datapipe: Root of the DataPipe graph.
        num_of_instances: Forwarded unchanged to ``pipe.apply_sharding``.
        instance_id: Forwarded unchanged to ``pipe.apply_sharding``.

    Returns:
        The same ``datapipe`` object that was passed in.

    Raises:
        RuntimeError: If a second eligible pipe is found. The first pipe has
            already been sharded by the time this is raised.
    """
    graph = traverse(datapipe, only_datapipe=True)
    all_pipes = get_all_graph_pipes(graph)
    already_applied_to = None
    for pipe in all_pipes:
        # Same short-circuit order as the nested checks: is_shardable() is only
        # called when the attribute exists.
        eligible = (
            hasattr(pipe, 'is_shardable')
            and pipe.is_shardable()
            and hasattr(pipe, 'apply_sharding')
        )
        if not eligible:
            continue
        if already_applied_to is not None:
            # Build a single message string. Passing several positional args
            # to RuntimeError made the message render as a tuple repr.
            raise RuntimeError(
                'This implementation of sharding can be only applied once per instance of DataPipeline. '
                f'Already applied to {already_applied_to} while trying to apply to {pipe}'
            )
        pipe.apply_sharding(num_of_instances, instance_id)
        already_applied_to = pipe
    return datapipe
コード例 #3
0
    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
        """Fail unless the mocked dataset's graph contains a node whose type is exactly ``annotation_dp_type``."""

        def iter_nodes(tree):
            # Pre-order depth-first walk over the nested {node: sub_graph} mapping.
            for node, children in tree.items():
                yield node
                yield from iter_nodes(children)

        dataset, _ = dataset_mock.load(config)

        # Exact type match on purpose: subclasses do not count.
        has_annotation_dp = any(
            type(node) is annotation_dp_type for node in iter_nodes(traverse(dataset))
        )
        if not has_annotation_dp:
            raise AssertionError(
                f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe."
            )
コード例 #4
0
    def test_has_annotations(self, test_home, dataset_mock, config,
                             annotation_dp_type):
        """Fail unless the loaded dataset's graph contains a node of exactly ``annotation_dp_type``."""

        def walk(tree):
            # Yield each node before recursing into its sub-graph (pre-order DFS).
            for node, children in tree.items():
                yield node
                yield from walk(children)

        dataset_mock.prepare(test_home, config)
        dataset = datasets.load(dataset_mock.name, **config)

        for node in walk(traverse(dataset)):
            if type(node) is annotation_dp_type:
                return
        raise AssertionError(
            f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe."
        )
コード例 #5
0
ファイル: graph_settings.py プロジェクト: timgates42/pytorch
def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool]) -> DataPipe:
    """Turn shuffling on or off for every shuffler in the graph of ``datapipe``.

    Args:
        datapipe: Root of the DataPipe graph.
        shuffle: ``None`` leaves everything untouched. Otherwise the value is
            passed to ``set_shuffle`` on each ``dp.iter.Shuffler`` /
            ``dp.map.Shuffler`` in the graph. If it is truthy and the graph has
            no shuffler, a warning is emitted and ``datapipe.shuffle()`` is
            appended.

    Returns:
        ``datapipe`` itself, or the newly appended shuffler wrapping it.
    """
    if shuffle is not None:
        pipes = get_all_graph_pipes(traverse(datapipe, only_datapipe=True))
        shuffler_types = (dp.iter.Shuffler, dp.map.Shuffler)
        targets = [candidate for candidate in pipes if isinstance(candidate, shuffler_types)]

        if shuffle and not targets:
            # Shuffling was requested but nothing can shuffle: append a
            # Shuffler at the end of the pipeline instead of ignoring it.
            warnings.warn(
                "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
                "Be aware that the default buffer size might not be sufficient for your task."
            )
            datapipe = datapipe.shuffle()
            targets = [datapipe]  # type: ignore[list-item]

        for target in targets:
            target.set_shuffle(shuffle)

    return datapipe
コード例 #6
0
    def test_traversable(self, dataset_mock, config):
        """Smoke test: ``traverse`` must walk the mocked dataset's graph without raising."""
        dataset, _unused = dataset_mock.load(config)
        traverse(dataset)  # any exception raised here fails the test
コード例 #7
0
    def test_traversable(self, test_home, dataset_mock, config):
        """Smoke test: a dataset loaded from prepared mock data must be traversable."""
        dataset_mock.prepare(test_home, config)
        # any exception raised while loading or traversing fails the test
        traverse(datasets.load(dataset_mock.name, **config))
コード例 #8
0
def extract_datapipes(dp):
    """Return every DataPipe in the graph rooted at ``dp``, as collected by ``get_all_graph_pipes``.

    The graph is built with ``traverse(dp, only_datapipe=True)``.
    """
    datapipe_graph = traverse(dp, only_datapipe=True)
    return get_all_graph_pipes(datapipe_graph)
コード例 #9
0
    def test_traversable(self, dataset_mock, config, only_datapipe):
        """Smoke test: ``traverse`` must succeed for the given ``only_datapipe`` setting."""
        dataset, _unused = dataset_mock.load(config)
        # any exception raised here fails the test
        traverse(dataset, only_datapipe=only_datapipe)