def testFromDatasetIdMultipleComponents(self):
    cluster = self.create_cluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = data_service_ops.register_dataset(cluster.target, ds)
    from_dataset_id_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])
Example #2
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
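
For contrast, Dataset.range produces scalar int64 elements, so the matching spec is
TensorSpec(shape=(), dtype=dtypes.int64). A minimal sketch of the passing case,
reusing the names from the test above:

        correct_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int64)
        ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            correct_spec)
        # Reading succeeds because the spec matches the registered dataset.
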
Example #3
    def testFromDatasetIdMultipleComponents(self):
        dispatcher, workers = self.start_cluster(1)  # keep workers referenced so they are not garbage-collected; pylint: disable=unused-variable

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(dispatcher.target, ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", dispatcher.target, dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])
Example #4
 def testElementSpecGraphMode(self):
     cluster = data_service_test_base.TestCluster(num_workers=1,
                                                  work_dir=NO_WORK_DIR,
                                                  fault_tolerant_mode=False)
     num_elements = 10
     ds = dataset_ops.Dataset.range(num_elements)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher_address(), ds)
     with self.assertRaisesRegex(
             ValueError,
             "In graph mode element_spec must be provided manually."):
         ds = data_service_ops.from_dataset_id("parallel_epochs",
                                               cluster.dispatcher_address(),
                                               dataset_id)
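
The graph-mode fix is to pass element_spec explicitly, either captured from the
original dataset before registration or written out by hand. A minimal sketch
reusing the names from the test above:

     ds = data_service_ops.from_dataset_id(
         "parallel_epochs",
         cluster.dispatcher_address(),
         dataset_id,
         element_spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.int64))
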
Example #5
 def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
     cluster = data_service_test_base.TestCluster(
         num_workers=1, data_transfer_protocol="grpc")
     dataset = dataset_ops.Dataset.from_tensor_slices(
         list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target,
         dataset=dataset,
         compression=compression)
     dataset = data_service_ops.from_dataset_id(
         processing_mode=data_service_ops.ShardingPolicy.OFF,
         service=cluster.dispatcher.target,
         dataset_id=dataset_id)
     self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #6
  def testFromDatasetIdDoesntRequireElementSpec(self):
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=NO_WORK_DIR,
        fault_tolerant_mode=False,
        data_transfer_protocol="grpc")
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                   ds)
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
    self.assertDatasetProduces(ds, list(range(num_elements)))
Example #7
 def testElementSpecGraphMode(self):
     cluster = data_service_test_base.TestCluster(
         num_workers=1,
         work_dir=data_service_test_base.NO_WORK_DIR,
         fault_tolerant_mode=False)
     num_elements = 10
     dataset = dataset_ops.Dataset.range(num_elements)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher_address(), dataset)
     with self.assertRaisesRegex(
             ValueError,
             "In graph mode `element_spec` must be provided manually."):
         _ = data_service_ops.from_dataset_id(
             processing_mode=data_service_ops.ShardingPolicy.OFF,
             service=cluster.dispatcher_address(),
             dataset_id=dataset_id)
Example #8
    def testFromDatasetIdDoesntRequireElementSpec(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False,
            data_transfer_protocol="grpc")
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)

        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), dataset)
        dataset = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher_address(),
            dataset_id=dataset_id)
        self.assertDatasetProduces(dataset, list(range(num_elements)))
Example #9
 def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     dataset = dataset_ops.Dataset.from_tensor_slices(
         list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
     with mock.patch.object(compat, "forward_compatible",
                            return_value=True):
         dataset_id = data_service_ops.register_dataset(
             cluster.dispatcher.target,
             dataset=dataset,
             compression=compression)
         dataset = data_service_ops.from_dataset_id(
             processing_mode=ShardingPolicy.OFF,
             service=cluster.dispatcher.target,
             dataset_id=dataset_id)
         self.assertDatasetProduces(dataset,
                                    list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #10
    def testReadDatasetOnDifferentDevices(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        with ops.device(self._devices[0]):
            dataset = dataset_ops.Dataset.range(num_elements)
            element_spec = dataset.element_spec
            dataset_id = data_service_ops.register_dataset(
                cluster.dispatcher_address(), dataset)
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher_address(),
                dataset_id=dataset_id,
                element_spec=element_spec)
            self.assertDatasetProduces(dataset, list(range(num_elements)))

        with ops.device(self._devices[1]):
            dataset = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=cluster.dispatcher_address(),
                dataset_id=dataset_id,
                element_spec=element_spec)
            self.assertDatasetProduces(dataset, list(range(num_elements)))
Example #11
    def testFromDatasetIdOmitsCompression(self, compression):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        dataset = dataset_ops.Dataset.from_tensor_slices(
            list("abcdefghijklmnopqrstuvwxyz"))

        def to_upper(x):
            return script_ops.numpy_function(
                func=lambda x: x.decode("utf-8").upper(),
                inp=[x],
                Tout=dtypes.string)

        dataset = dataset.map(to_upper,
                              num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target,
            dataset=dataset,
            compression=compression)
        dataset = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher.target,
            dataset_id=dataset_id,
            element_spec=dataset.element_spec)
        self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example #12
 def get_dataset_id():
     return data_service_ops.register_dataset(
         cluster.dispatcher_address(), dataset)
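
The fragment above is a nested helper that closes over cluster and dataset from its
enclosing test. A plausible continuation (assumed here, not taken from the source)
calls the helper and feeds the returned id back into from_dataset_id, mirroring the
earlier examples:

 dataset_id = get_dataset_id()  # hypothetical continuation of the enclosing test
 ds = data_service_ops.from_dataset_id(
     processing_mode=data_service_ops.ShardingPolicy.OFF,
     service=cluster.dispatcher_address(),
     dataset_id=dataset_id,
     element_spec=dataset.element_spec)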