Example #1
    def testFromDatasetId(self):
        """Tests cross-trainer cache with `register_dataset`/`from_dataset_id`."""
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset_id1 = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset, dataset_id="dataset_id")
        dataset1 = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher.target,
            dataset_id=dataset_id1,
            element_spec=dataset.element_spec,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        dataset_id2 = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset, dataset_id="dataset_id")
        dataset2 = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher.target,
            dataset_id=dataset_id2,
            element_spec=dataset.element_spec,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        self.assertDatasetProduces(dataset2.take(10), list(range(10)))
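The examples on this page drive the service through TensorFlow's internal test harnesses (TestCluster, data_service_ops). For orientation, the sketch below performs the same register/read round trip against the public tf.data.experimental.service API, standing up a single in-process dispatcher and worker; it assumes TF 2.5+ and follows the documented DispatchServer/WorkerServer pattern, not anything taken from the tests themselves.

import tensorflow as tf

# Minimal in-process service: one dispatcher, one worker. This is a sketch
# using the public API, not the tests' TestCluster harness.
dispatcher = tf.data.experimental.service.DispatchServer()
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher.target.split("://")[1]))

ds = tf.data.Dataset.range(10)

# Register once; any number of readers can then attach by dataset id.
dataset_id = tf.data.experimental.service.register_dataset(
    service=dispatcher.target, dataset=ds)
read_ds = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher.target,
    dataset_id=dataset_id,
    element_spec=ds.element_spec)

print(sorted(int(x) for x in read_ds))  # [0, 1, ..., 9]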
Example #2
 def testEmptyJobNameFromDatasetId(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target, dataset_ops.Dataset.range(10))
     with self.assertRaisesRegex(ValueError, "job_name must not be empty"):
         data_service_ops.from_dataset_id(dataset_id=dataset_id,
                                          processing_mode="parallel_epochs",
                                          service=cluster.dispatcher.target,
                                          job_name="")
Example #3
 def testNonStringJobNameFromDatasetId(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target, dataset_ops.Dataset.range(10))
     with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
         data_service_ops.from_dataset_id(
             dataset_id=dataset_id,
             processing_mode="parallel_epochs",
             service=cluster.dispatcher.target,
             job_name=constant_op.constant("foo"))
Example #4
    def testElementSpecMixedMode(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False)
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)

        @def_function.function
        def get_dataset_id():
            return data_service_ops.register_dataset(
                cluster.dispatcher_address(), dataset)

        dataset_id = get_dataset_id()
        dataset_id_val = tensor_util.constant_value(dataset_id)

        with self.assertRaisesRegex(
                ValueError,
                f"Failed to fetch element spec for dataset id {dataset_id_val} from "
                "tf.data service. If the dataset was registered in graph mode or "
                "inside a tf.function, the `element_spec` must be specified as an "
                "argument to `from_dataset_id`."):
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher_address(),
                dataset_id=dataset_id)
Example #5
    def testFromDatasetIdOmitsCompression(self, compression):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        dataset = dataset_ops.Dataset.from_tensor_slices(
            list("abcdefghijklmnopqrstuvwxyz"))

        def to_upper(x):
            return script_ops.numpy_function(
                func=lambda x: x.decode("utf-8").upper(),
                inp=[x],
                Tout=dtypes.string)

        dataset = dataset.map(to_upper,
                              num_parallel_calls=dataset_ops.AUTOTUNE)
        with mock.patch.object(compat, "forward_compatible",
                               return_value=True):
            dataset_id = data_service_ops.register_dataset(
                cluster.dispatcher.target,
                dataset=dataset,
                compression=compression)
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher.target,
                dataset_id=dataset_id,
                element_spec=dataset.element_spec)
            self.assertDatasetProduces(dataset,
                                       list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #6
    def testFromDatasetIdSharedJobs(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)

        datasets = [
            dataset_ops.Dataset.range(20, output_type=dtypes.int32),
            dataset_ops.Dataset.from_tensor_slices(list(range(20, 40)))
        ]
        dataset_ids = [
            data_service_ops.register_dataset(cluster.dispatcher_address(), ds)
            for ds in datasets
        ]

        # Read from both jobs in parallel, with 2 consumers for each job.
        data_service_datasets = []
        for _ in range(2):
            for dataset, dataset_id in zip(datasets, dataset_ids):
                ds = data_service_ops.from_dataset_id(
                    "distributed_epoch",
                    cluster.dispatcher_address(),
                    dataset_id,
                    dataset.element_spec,
                    job_name="shared_job")
                data_service_datasets.append(ds)
        ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets)
        ds = ds.interleave(lambda x: x,
                           cycle_length=len(data_service_datasets))

        self.assertDatasetProduces(ds,
                                   list(range(40)),
                                   assert_items_equal=True)
Example #7
  def testFromDatasetId(self):
    num_elements = 10
    service = self.create_cluster(1)

    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(service, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", service, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
Example #8
    def testFromDatasetIdNotRegistered(self):
        dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", dispatcher.target, dataset_id, element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
Example #9
  def testFromDatasetId(self):
    dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", dispatcher.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
Example #10
    def testFromDatasetIdNotRegistered(self):
        cluster = self.create_cluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
Example #11
  def testFromDatasetId(self):
    cluster = self.create_cluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
Example #12
  def testFromDatasetId(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                   ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
Example #13
  def testElementSpecEagerMode(self):
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                   ds)
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
    self.assertDatasetProduces(ds, list(range(num_elements)))
Example #14
    def testFromDatasetIdWrongElementSpec(self):
        cluster = self.create_cluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(cluster.target, ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
Example #15
 def testElementSpecGraphMode(self):
   cluster = data_service_test_base.TestCluster(
       num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
   num_elements = 10
   ds = dataset_ops.Dataset.range(num_elements)
   dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                  ds)
   with self.assertRaisesRegex(
       ValueError, "In graph mode element_spec must be provided manually."):
     ds = data_service_ops.from_dataset_id("parallel_epochs",
                                           cluster.dispatcher_address(),
                                           dataset_id)
Example #16
 def testFromDatasetIdCardinality(self, dataset_fn, sharding_policy,
                                  expected_result):
     cluster = data_service_test_base.TestCluster(num_workers=2)
     dataset = dataset_fn()
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target, dataset=dataset)
     dataset = data_service_ops.from_dataset_id(
         processing_mode=sharding_policy,
         service=cluster.dispatcher.target,
         dataset_id=dataset_id,
         element_spec=dataset.element_spec)
     self.assertEqual(self.evaluate(dataset.cardinality()), expected_result)
Example #17
    def testFromDatasetIdWrongElementSpec(self):
        dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", dispatcher.target, dataset_id, wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
Example #18
 def testExplicitProtocolFromDatasetId(self):
   cluster = data_service_test_base.TestCluster(num_workers=1)
   range_ds = dataset_ops.Dataset.range(10)
   dataset_id = data_service_ops.register_dataset(cluster.dispatcher.target,
                                                  range_ds)
   ds = data_service_ops.from_dataset_id(
       dataset_id=dataset_id,
       processing_mode="parallel_epochs",
       element_spec=range_ds.element_spec,
       service=cluster.dispatcher.target,
       data_transfer_protocol="grpc")
   self.assertDatasetProduces(ds, list(range(10)))
Example #19
  def testRegisteringDatasetAsTfFunction(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    register_func = def_function.function(data_service_ops.register_dataset)
    dataset_id = register_func(
        (constant_op.constant("grpc"),
         constant_op.constant(cluster.dispatcher_address())), ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), dataset_id,
        ds.element_spec)
    self.assertDatasetProduces(from_dataset_id_ds, list(range(num_elements)))
Example #20
    def testReadDatasetOnDifferentDevices(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        with ops.device(self._devices[0]):
            dataset = dataset_ops.Dataset.range(num_elements)
            element_spec = dataset.element_spec
            dataset_id = data_service_ops.register_dataset(
                cluster.dispatcher_address(), dataset)
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher_address(),
                dataset_id=dataset_id,
                element_spec=element_spec)
            self.assertDatasetProduces(dataset, list(range(num_elements)))

        with ops.device(self._devices[1]):
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher_address(),
                dataset_id=dataset_id,
                element_spec=dataset.element_spec)
            self.assertDatasetProduces(dataset, list(range(num_elements)))
Example #21
 def from_dataset_id(self,
                     processing_mode,
                     cluster,
                     dataset_id,
                     element_spec,
                     job_name=None):
     return data_service_ops.from_dataset_id(
         processing_mode,
         cluster.dispatcher_address(),
         dataset_id,
         element_spec,
         data_transfer_protocol=TRANSFER_PROTOCOL.value,
         job_name=job_name)
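Example #21 is a helper rather than a test: it forwards the cluster's dispatcher address and the suite's transfer-protocol flag to data_service_ops.from_dataset_id, so individual tests only pass the job-specific arguments. A hypothetical call site, assuming cluster, dataset, and dataset_id are set up via TestCluster and register_dataset as in the neighboring examples:

 # Hypothetical usage of the helper above; cluster, dataset, and dataset_id
 # are assumed to follow the TestCluster + register_dataset setup shown in
 # the other examples.
 ds = self.from_dataset_id(
     processing_mode="parallel_epochs",
     cluster=cluster,
     dataset_id=dataset_id,
     element_spec=dataset.element_spec,
     job_name="my_job")
 self.assertDatasetProduces(ds, list(range(10)))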
Example #22
 def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
     cluster = data_service_test_base.TestCluster(
         num_workers=1, data_transfer_protocol="grpc")
     dataset = dataset_ops.Dataset.from_tensor_slices(
         list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target,
         dataset=dataset,
         compression=compression)
     dataset = data_service_ops.from_dataset_id(
         processing_mode=data_service_ops.ShardingPolicy.OFF,
         service=cluster.dispatcher.target,
         dataset_id=dataset_id)
     self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #23
    def testFromDatasetIdMultipleComponents(self):
        dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", dispatcher.target, dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])
Example #24
    def testFromDatasetIdMultipleComponents(self):
        cluster = self.create_cluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(cluster.target, ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])
Example #25
  def testFromDatasetIdDoesntRequireElementSpec(self):
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=NO_WORK_DIR,
        fault_tolerant_mode=False,
        data_transfer_protocol="grpc")
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                   ds)
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
    self.assertDatasetProduces(ds, list(range(num_elements)))
Example #26
    def testFromDatasetIdDoesntRequireElementSpec(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False,
            data_transfer_protocol="grpc")
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)

        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), dataset)
        dataset = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher_address(),
            dataset_id=dataset_id)
        self.assertDatasetProduces(dataset, list(range(num_elements)))
Example #27
 def testElementSpecGraphMode(self):
     cluster = data_service_test_base.TestCluster(
         num_workers=1,
         work_dir=data_service_test_base.NO_WORK_DIR,
         fault_tolerant_mode=False)
     num_elements = 10
     dataset = dataset_ops.Dataset.range(num_elements)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher_address(), dataset)
     with self.assertRaisesRegex(
             ValueError,
             "In graph mode `element_spec` must be provided manually."):
         _ = data_service_ops.from_dataset_id(
             processing_mode=data_service_ops.ShardingPolicy.OFF,
             service=cluster.dispatcher_address(),
             dataset_id=dataset_id)
Example #28
 def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     dataset = dataset_ops.Dataset.from_tensor_slices(
         list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
     with mock.patch.object(compat, "forward_compatible",
                            return_value=True):
         dataset_id = data_service_ops.register_dataset(
             cluster.dispatcher.target,
             dataset=dataset,
             compression=compression)
         dataset = data_service_ops.from_dataset_id(
             processing_mode=data_service_ops.ShardingPolicy.OFF,
             service=cluster.dispatcher.target,
             dataset_id=dataset_id)
         self.assertDatasetProduces(dataset,
                                    list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #29
  def testElementSpecMixedMode(self):
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    @def_function.function
    def get_dataset_id():
      return data_service_ops.register_dataset(cluster.dispatcher_address(), ds)

    dataset_id = get_dataset_id()
    dataset_id_val = tensor_util.constant_value(dataset_id)

    with self.assertRaisesRegex(
        ValueError, "Failed to fetch element spec for dataset id " +
        str(dataset_id_val) + " from tf.data service. If the "
        "dataset was registered in graph mode or inside a "
        "tf.function, the `element_spec` must be specified as "
        "an argument to `from_dataset_id`."):
      ds = data_service_ops.from_dataset_id("parallel_epochs",
                                            cluster.dispatcher_address(),
                                            dataset_id)
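As the error message itself states, the fix for this mixed-mode failure is to pass the element spec explicitly. A sketch of the corrected call (not part of the original test), reusing the names from the example above:

    # Passing element_spec explicitly avoids the fetch that fails above:
    # the dataset was registered inside a tf.function, so the service
    # cannot report its element spec and the caller must supply it.
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id,
                                          element_spec=ds.element_spec)
    self.assertDatasetProduces(ds, list(range(num_elements)))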