예제 #1
0
 def _wrap_dataset(self, input_type, dataset, input_workers,
                   split_batch_by, enable_get_next_as_optional):
   """Wraps `dataset` in the matching `input_lib` distributed dataset class.

   Instances of `dataset_ops.Dataset` are wrapped in
   `input_lib.DistributedDatasetV1`; any other value is wrapped in
   `input_lib.DistributedDataset`.

   Args:
     input_type: Accepted for interface compatibility; not referenced here.
     dataset: The dataset to distribute.
     input_workers: Forwarded positionally to the wrapper constructor.
     split_batch_by: Forwarded positionally to the wrapper constructor.
     enable_get_next_as_optional: Forwarded as the private
       `_enable_get_next_as_optional` keyword argument.

   Returns:
     The constructed distributed dataset wrapper.
   """
   wrapper_cls = (input_lib.DistributedDatasetV1
                  if isinstance(dataset, dataset_ops.Dataset)
                  else input_lib.DistributedDataset)
   return wrapper_cls(
       dataset,
       input_workers,
       split_batch_by,
       _enable_get_next_as_optional=enable_get_next_as_optional)
예제 #2
0
    def testGraphModeError(self):
        """Iterating a DistributedDatasetV1 in graph mode must raise.

        Builds a single-worker, CPU-only `InputWorkers` setup inside
        `context.graph_mode()`, wraps a small batched range dataset in
        `input_lib.DistributedDatasetV1`, and checks that calling `iter()`
        on it raises a `RuntimeError` whose message matches the expected
        eager-only text.
        """
        with context.graph_mode():
            worker_device_pairs = [("", ["/device:CPU:0"])]
            devices = nest.flatten([ds for _, ds in worker_device_pairs])
            device_map = values.ReplicaDeviceMap(devices)
            input_workers = input_lib.InputWorkers(device_map,
                                                   worker_device_pairs)
            dataset = dataset_ops.Dataset.range(10).batch(2)

            # Construct the wrapper outside the assertion context so that
            # only the iter() call can satisfy the expected RuntimeError; a
            # failure during construction now surfaces as a test error
            # instead of a false pass.
            dist_dataset = input_lib.DistributedDatasetV1(
                dataset, input_workers)

            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which was removed from unittest.TestCase in Python 3.12.
            with self.assertRaisesRegex(
                    RuntimeError, "__iter__ is only "
                    "supported when eager execution is "
                    "enabled."):
                iter(dist_dataset)
예제 #3
0
 def _wrap_dataset(self,
                   input_type,
                   dataset,
                   input_workers,
                   split_batch_by,
                   strategy,
                   input_context=None):
   """Wraps `dataset` in a strategy-aware `input_lib` distributed dataset.

   Instances of `dataset_ops.Dataset` are wrapped in
   `input_lib.DistributedDatasetV1`; any other value is wrapped in
   `input_lib.DistributedDataset`. Both receive the same arguments.

   Args:
     input_type: Accepted for interface compatibility; not referenced here.
     dataset: The dataset to distribute.
     input_workers: Forwarded positionally to the wrapper constructor.
     split_batch_by: Forwarded as the `split_batch_by` keyword argument.
     strategy: Forwarded positionally to the wrapper constructor.
     input_context: Optional; forwarded as the `input_context` keyword.

   Returns:
     The constructed distributed dataset wrapper.
   """
   if isinstance(dataset, dataset_ops.Dataset):
     wrapper_cls = input_lib.DistributedDatasetV1
   else:
     wrapper_cls = input_lib.DistributedDataset
   return wrapper_cls(
       dataset,
       input_workers,
       strategy,
       split_batch_by=split_batch_by,
       input_context=input_context)
예제 #4
0
 def _wrap_dataset(self,
                   input_type,
                   dataset,
                   input_workers,
                   split_batch_by,
                   strategy,
                   input_context=None):
   """Distributes `dataset` according to its type and `input_type`.

   * `dataset_ops.Dataset` / `dataset_ops.DatasetV1Adapter` instances are
     wrapped in `input_lib.DistributedDatasetV1`.
   * Otherwise, when `input_type == "dataset"`, the value is wrapped in
     `input_lib.DistributedDataset`.
   * In every remaining case `dataset` is handed to
     `strategy.experimental_distribute_datasets_from_function` (presumably a
     dataset-producing function in that case -- confirm against callers).

   Args:
     input_type: Selects the non-V1 path; compared against "dataset".
     dataset: A dataset, or whatever the function-based path expects.
     input_workers: Forwarded positionally to the wrapper constructor.
     split_batch_by: Forwarded as the `split_batch_by` keyword argument.
     strategy: Forwarded to the wrapper, or used for the function path.
     input_context: Optional; forwarded as the `input_context` keyword.

   Returns:
     The distributed dataset produced by the selected path.
   """
   is_v1_dataset = isinstance(
       dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter))
   if not is_v1_dataset and input_type != "dataset":
     return strategy.experimental_distribute_datasets_from_function(dataset)

   wrapper_cls = (input_lib.DistributedDatasetV1 if is_v1_dataset
                  else input_lib.DistributedDataset)
   return wrapper_cls(
       dataset,
       input_workers,
       strategy,
       split_batch_by=split_batch_by,
       input_context=input_context)
예제 #5
0
 def _wrap_dataset(self,
                   input_type,
                   dataset,
                   input_workers,
                   split_batch_by,
                   strategy,
                   input_context=None):
   """Distributes `dataset`, choosing the wrapper by TF2 mode.

   When `input_type` is not "dataset", `dataset` is passed directly to
   `strategy.experimental_distribute_datasets_from_function` (presumably a
   dataset-producing function in that case -- confirm against callers).
   Otherwise it is wrapped in `input_lib.DistributedDataset` if
   `tf2.enabled()` returns true, else in `input_lib.DistributedDatasetV1`.

   Args:
     input_type: "dataset" selects the wrapper path; anything else selects
       the function-based path.
     dataset: A dataset, or whatever the function-based path expects.
     input_workers: Forwarded positionally to the wrapper constructor.
     split_batch_by: Forwarded as the `split_batch_by` keyword argument.
     strategy: Forwarded to the wrapper, or used for the function path.
     input_context: Optional; forwarded as the `input_context` keyword.

   Returns:
     The distributed dataset produced by the selected path.
   """
   if input_type != "dataset":
     return strategy.experimental_distribute_datasets_from_function(dataset)

   wrapper_cls = (input_lib.DistributedDataset if tf2.enabled()
                  else input_lib.DistributedDatasetV1)
   return wrapper_cls(
       dataset,
       input_workers,
       strategy,
       split_batch_by=split_batch_by,
       input_context=input_context)