def _wrap_dataset(self, input_type, dataset, input_workers, split_batch_by,
                  enable_get_next_as_optional):
    """Wrap `dataset` in the `input_lib` distributed-dataset class for its type.

    A `dataset_ops.Dataset` instance is wrapped with
    `input_lib.DistributedDatasetV1`; anything else is wrapped with
    `input_lib.DistributedDataset`. Both receive the same positional
    arguments and the same `_enable_get_next_as_optional` keyword.

    Args:
      input_type: Accepted for signature compatibility; not used here.
      dataset: The dataset to wrap.
      input_workers: Forwarded positionally to the wrapper constructor.
      split_batch_by: Forwarded positionally to the wrapper constructor.
      enable_get_next_as_optional: Forwarded as the
        `_enable_get_next_as_optional` keyword argument.

    Returns:
      The constructed distributed-dataset wrapper.
    """
    del input_type  # Unused.
    wrapper_cls = (input_lib.DistributedDatasetV1
                   if isinstance(dataset, dataset_ops.Dataset)
                   else input_lib.DistributedDataset)
    return wrapper_cls(
        dataset,
        input_workers,
        split_batch_by,
        _enable_get_next_as_optional=enable_get_next_as_optional)
def testGraphModeError(self):
    """Iterating a `DistributedDatasetV1` in graph mode raises `RuntimeError`.

    Builds a single-worker, single-device (`/device:CPU:0`) input pipeline
    inside `context.graph_mode()`. It then checks that calling `iter()` on the
    wrapped `range(10).batch(2)` dataset raises a `RuntimeError` whose message
    says `__iter__` requires eager execution.
    """
    with context.graph_mode():
        worker_device_pairs = [("", ["/device:CPU:0"])]
        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
        dataset = dataset_ops.Dataset.range(10).batch(2)
        # Construct outside the assertion context so that only the `iter()`
        # call can satisfy the expected error; a failure while constructing
        # should fail the test rather than pass it.
        dist_dataset = input_lib.DistributedDatasetV1(dataset, input_workers)
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias (removed in Python 3.12). The final period is escaped so it
        # matches literally.
        with self.assertRaisesRegex(
                RuntimeError,
                r"__iter__ is only supported when eager execution is "
                r"enabled\."):
            iter(dist_dataset)
def _wrap_dataset(self, input_type, dataset, input_workers, split_batch_by,
                  strategy, input_context=None):
    """Wrap `dataset` in the `input_lib` distributed-dataset class for its type.

    A `dataset_ops.Dataset` instance is wrapped with
    `input_lib.DistributedDatasetV1`; anything else is wrapped with
    `input_lib.DistributedDataset`.

    Args:
      input_type: Accepted for signature compatibility; not used here.
      dataset: The dataset to wrap.
      input_workers: Forwarded positionally to the wrapper constructor.
      strategy: Forwarded positionally to the wrapper constructor.
      split_batch_by: Forwarded as the `split_batch_by` keyword.
      input_context: Forwarded as the `input_context` keyword.

    Returns:
      The constructed distributed-dataset wrapper.
    """
    del input_type  # Unused.
    is_v1_dataset = isinstance(dataset, dataset_ops.Dataset)
    wrapper_cls = (input_lib.DistributedDatasetV1 if is_v1_dataset
                   else input_lib.DistributedDataset)
    return wrapper_cls(dataset, input_workers, strategy,
                       split_batch_by=split_batch_by,
                       input_context=input_context)
def _wrap_dataset(self, input_type, dataset, input_workers, split_batch_by,
                  strategy, input_context=None):
    """Distribute `dataset` according to its type and `input_type`.

    Dispatch order:
      1. `dataset_ops.Dataset` or `dataset_ops.DatasetV1Adapter` instance:
         wrapped with `input_lib.DistributedDatasetV1`.
      2. Otherwise, if `input_type == "dataset"`: wrapped with
         `input_lib.DistributedDataset`.
      3. Otherwise, `dataset` is handed to
         `strategy.experimental_distribute_datasets_from_function`
         (presumably it is a dataset-producing function in that case).

    Args:
      input_type: Selects between cases 2 and 3 above.
      dataset: The dataset (or, in case 3, the object passed to the strategy).
      input_workers: Forwarded positionally to the wrapper constructor.
      split_batch_by: Forwarded as the `split_batch_by` keyword.
      strategy: Forwarded to the wrapper constructor, or used directly in
        case 3.
      input_context: Forwarded as the `input_context` keyword.

    Returns:
      The distributed dataset produced by whichever path was taken.
    """
    is_v1_dataset = isinstance(
        dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter))
    if is_v1_dataset or input_type == "dataset":
        wrapper_cls = (input_lib.DistributedDatasetV1 if is_v1_dataset
                       else input_lib.DistributedDataset)
        return wrapper_cls(dataset, input_workers, strategy,
                           split_batch_by=split_batch_by,
                           input_context=input_context)
    return strategy.experimental_distribute_datasets_from_function(dataset)
def _wrap_dataset(self, input_type, dataset, input_workers, split_batch_by,
                  strategy, input_context=None):
    """Distribute `dataset` according to `input_type` and the TF2 flag.

    When `input_type == "dataset"`, the dataset is wrapped with
    `input_lib.DistributedDataset` if `tf2.enabled()` is true, and with
    `input_lib.DistributedDatasetV1` otherwise. Any other `input_type` hands
    `dataset` to `strategy.experimental_distribute_datasets_from_function`.

    Args:
      input_type: `"dataset"` selects wrapping; anything else selects the
        strategy's dataset-function path.
      dataset: The dataset (or the object passed to the strategy).
      input_workers: Forwarded positionally to the wrapper constructor.
      split_batch_by: Forwarded as the `split_batch_by` keyword.
      strategy: Forwarded to the wrapper constructor, or used directly for
        the dataset-function path.
      input_context: Forwarded as the `input_context` keyword.

    Returns:
      The distributed dataset produced by whichever path was taken.
    """
    if not input_type == "dataset":
        return strategy.experimental_distribute_datasets_from_function(dataset)
    wrapper_cls = (input_lib.DistributedDataset if tf2.enabled()
                   else input_lib.DistributedDatasetV1)
    return wrapper_cls(dataset, input_workers, strategy,
                       split_batch_by=split_batch_by,
                       input_context=input_context)