コード例 #1
0
    def __init__(self, input_dataset, num_workers, index):
        """Wraps `input_dataset` in an `auto_shard_dataset` op.

        Args:
          input_dataset: Source dataset. Its `element_spec` is adopted as this
            dataset's element spec, and its variant tensor is the op's input.
          num_workers: Forwarded unchanged to the op's `num_workers` argument
            (presumably the total number of shards — confirm against op docs).
          index: Forwarded unchanged to the op's `index` argument (presumably
            which shard this dataset yields).
        """
        self._input_dataset = input_dataset
        # Must be set before `self._flat_structure` is read below; that
        # attribute is presumably derived from the element spec by the base
        # class (NOTE(review): defined outside this view).
        self._element_spec = input_dataset.element_spec

        upstream_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        sharded_variant = ged_ops.auto_shard_dataset(
            upstream_variant,
            num_workers=num_workers,
            index=index,
            **self._flat_structure)
        super(_AutoShardDataset, self).__init__(input_dataset, sharded_variant)
コード例 #2
0
ファイル: distribute.py プロジェクト: Harryi0/tinyML
    def __init__(self, input_dataset, num_workers, index):
        """Wraps `input_dataset` in an `auto_shard_dataset` op.

        The op's `auto_shard_policy` attribute is passed only when
        `compat.forward_compatible(2019, 11, 25)` is true, or when the dataset's
        options request a policy other than `AutoShardPolicy.AUTO`. Otherwise
        the attribute is omitted (presumably so that graphs remain consumable by
        runtimes that predate the attribute — see the forward-compatibility
        docs).

        Args:
          input_dataset: Source dataset. Its `element_spec` is adopted as this
            dataset's element spec, its variant tensor is the op's input, and
            its `options().experimental_distribute.auto_shard_policy` selects
            the sharding policy.
          num_workers: Forwarded unchanged to the op's `num_workers` argument.
          index: Forwarded unchanged to the op's `index` argument.
        """
        self._input_dataset = input_dataset

        # Must be set before `self._flat_structure` is read below (that
        # attribute is defined outside this view, presumably from the spec).
        self._element_spec = input_dataset.element_spec

        # Read the policy once. Previously `options()` could be evaluated both
        # in the condition and again in the op call.
        auto_shard_policy = (
            input_dataset.options().experimental_distribute.auto_shard_policy)

        op_kwargs = dict(self._flat_structure)
        if (compat.forward_compatible(2019, 11, 25) or
                auto_shard_policy != AutoShardPolicy.AUTO):
            # The op attribute takes the integer value of the enum.
            op_kwargs["auto_shard_policy"] = int(auto_shard_policy)

        # A single op call replaces two branches that differed only in
        # whether `auto_shard_policy` was passed.
        variant_tensor = ged_ops.auto_shard_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_workers=num_workers,
            index=index,
            **op_kwargs)
        super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)
コード例 #3
0
    def __init__(self, input_dataset, num_workers, index, num_replicas=None):
        """Wraps `input_dataset` in an `auto_shard_dataset` op.

        Args:
          input_dataset: Source dataset. Its `element_spec` is adopted as this
            dataset's element spec, its variant tensor is the op's input, and
            its `options().experimental_distribute.auto_shard_policy` is passed
            to the op as an int.
          num_workers: Forwarded unchanged to the op's `num_workers` argument.
          index: Forwarded unchanged to the op's `index` argument.
          num_replicas: Forwarded unchanged to the op's `num_replicas`
            argument; defaults to `None` (how the op treats `None` is not
            visible here — presumably the generated wrapper's default).
        """
        self._input_dataset = input_dataset
        # Set before `self._flat_structure` is read (defined outside this view).
        self._element_spec = input_dataset.element_spec

        upstream_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        distribute_options = input_dataset.options().experimental_distribute
        policy_value = int(distribute_options.auto_shard_policy)

        sharded_variant = ged_ops.auto_shard_dataset(
            upstream_variant,
            num_workers=num_workers,
            index=index,
            auto_shard_policy=policy_value,
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_AutoShardDataset, self).__init__(input_dataset, sharded_variant)
コード例 #4
0
    def __init__(self, input_dataset, num_workers, index):
        """Wraps `input_dataset` in an auto-shard dataset op.

        Uses `ged_ops.auto_shard_dataset` once
        `compat.forward_compatible(2019, 8, 3)` is true, and
        `ged_ops.experimental_auto_shard_dataset` before that. Both ops receive
        identical arguments.

        Args:
          input_dataset: Source dataset. Its `_element_structure` is adopted
            as this dataset's structure and its variant tensor is the op input.
          num_workers: Forwarded unchanged to the op's `num_workers` argument.
          index: Forwarded unchanged to the op's `index` argument.
        """
        self._input_dataset = input_dataset

        # Set before `self._flat_structure` is read (defined outside this view).
        self._structure = input_dataset._element_structure  # pylint: disable=protected-access

        # Choose the op by forward-compatibility window; the call is the same.
        shard_op = (ged_ops.auto_shard_dataset
                    if compat.forward_compatible(2019, 8, 3)
                    else ged_ops.experimental_auto_shard_dataset)
        variant_tensor = shard_op(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_workers=num_workers,
            index=index,
            **self._flat_structure)
        super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)