def testElementSpecEagerMode(self):
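        # In eager mode, `from_dataset_id` can infer the element spec from the
        # registered dataset, so it is not passed explicitly here.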
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)

        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        ds = data_service_ops.from_dataset_id("parallel_epochs",
                                              cluster.dispatcher_address(),
                                              dataset_id)
        self.assertDatasetProduces(ds, list(range(num_elements)))
Example 2
 def testSharedJobNameMultiIteration(self):
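   # Datasets created with the same job name share a single tf.data service
   # job, so within each iteration only the dataset read first receives elements.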
   cluster = data_service_test_base.TestCluster(num_workers=1)
   num_elements = 10
   ds1 = self.make_distributed_range_dataset(
       num_elements, cluster, job_name="job_name")
   ds2 = self.make_distributed_range_dataset(
       num_elements, cluster, job_name="job_name")
   # iteration 1
   self.assertDatasetProduces(ds1, list(range(num_elements)))
   self.assertDatasetProduces(ds2, [])
   # iteration 2
   self.assertDatasetProduces(ds2, list(range(num_elements)))
   self.assertDatasetProduces(ds1, [])
Example 3
 def testZipDifferentProcessingModesDatasetsSharedJobName(self):
   cluster = data_service_test_base.TestCluster(num_workers=1)
   num_elements = 100
   ds1 = dataset_ops.Dataset.range(num_elements)
   ds1 = self.make_distributed_dataset(
       ds1, cluster, processing_mode="distributed_epoch", job_name="job_name")
   ds2 = dataset_ops.Dataset.range(num_elements)
   ds2 = self.make_distributed_dataset(
       ds2, cluster, processing_mode="parallel_epochs", job_name="job_name")
   ds = dataset_ops.Dataset.zip((ds1, ds2))
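   # The two datasets use different processing modes, so they cannot share the
   # same job name; reading the zipped dataset should fail.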
   with self.assertRaisesRegex(errors.FailedPreconditionError,
                               "but there is already an existing job"):
     self.getDatasetOutput(ds)
Example 4
    def testConcatenate(self, num_workers):
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        a = dataset_ops.Dataset.range(100)
        b = dataset_ops.Dataset.range(100, 200)
        ds = a.concatenate(b)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")

        assert_items_equal = (num_workers > 1)
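        # With more than one worker the output order is nondeterministic, so
        # only check that the expected items are produced.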
        self.assertDatasetProduces(ds,
                                   list(range(200)),
                                   assert_items_equal=assert_items_equal)
Example 5
 def testExplicitProtocolFromDatasetId(self):
   cluster = data_service_test_base.TestCluster(
       num_workers=1, data_transfer_protocol="grpc")
   range_ds = dataset_ops.Dataset.range(10)
   dataset_id = data_service_ops.register_dataset(cluster.dispatcher.target,
                                                  range_ds)
   ds = data_service_ops.from_dataset_id(
       dataset_id=dataset_id,
       processing_mode="parallel_epochs",
       element_spec=range_ds.element_spec,
       service=cluster.dispatcher.target,
       data_transfer_protocol="grpc")
   self.assertDatasetProduces(ds, list(range(10)))
Example 6
    def testRegisteringDatasetAsTfFunction(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        register_func = def_function.function(self.register_dataset)
        dataset_id = register_func(
            (constant_op.constant("grpc"),
             constant_op.constant(cluster.dispatcher_address())), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))
Example 7
    def testNestedZip(self):
        num_elements = 10
        cluster = data_service_test_base.TestCluster(num_workers=1)
        a = dataset_ops.Dataset.range(num_elements)

        ds = dataset_ops.Dataset.zip((a, a))
        ds = dataset_ops.Dataset.zip((a, a, ds, a))
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")

        b = list(range(10))
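        # The nested tuple structure of the elements should be preserved.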
        self.assertDatasetProduces(ds, list(zip(b, b, zip(b, b), b)))
Example 8
  def testSnapshot(self, already_written):
    num_workers = 3
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    ds = dataset_ops.Dataset.range(100)
    ds = ds.snapshot(self.get_temp_dir())
    if already_written:
      # Materialize the snapshot.
      self.getDatasetOutput(ds)

    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    error_regex = "Splitting is not implemented for snapshot datasets"
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds)
Example 9
 def testGcUnusedJob(self, job_name):
     cluster = data_service_test_base.TestCluster(
         num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
     num_elements = 100
     ds = self.make_distributed_range_dataset(num_elements,
                                              cluster,
                                              job_name=job_name)
     it = iter(ds)
     self.assertEqual(next(it).numpy(), 0)
     self.assertEqual(cluster.workers[0].num_tasks(), 1)
     del it
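     # With the iterator deleted, the job has no remaining clients; wait for it
     # to be garbage-collected and its task removed from the worker.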
     while cluster.workers[0].num_tasks() > 0:
         time.sleep(0.1)
Example 10
  def testImbalancedZipMultiWorker(self):
    smaller_num_elements = 200
    larger_num_elements = 1000
    cluster = data_service_test_base.TestCluster(num_workers=3)
    a = dataset_ops.Dataset.range(smaller_num_elements)
    b = dataset_ops.Dataset.range(larger_num_elements)

    ds = dataset_ops.Dataset.zip((a, b))
    ds = self._make_dynamic_sharding_dataset(ds, cluster)

    # Cannot assert specific elements because the range datasets are split
    # nondeterministically and may not line up.
    self.assertLen(self.getDatasetOutput(ds), smaller_num_elements)
Example 11
  def testChooseFromDatasets(self, num_workers):
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    words = [b"foo", b"bar", b"baz"]
    datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
    choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    ds = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    expected = [words[i] for i in choice_array]

    assert_items_equal = (num_workers > 1)
    self.assertDatasetProduces(
        ds, expected, assert_items_equal=assert_items_equal)
Example 12
  def testImbalancedZip(self):
    smaller_num_elements = 200
    larger_num_elements = 1000

    cluster = data_service_test_base.TestCluster(num_workers=1)
    a = dataset_ops.Dataset.range(smaller_num_elements)
    b = dataset_ops.Dataset.range(larger_num_elements)

    ds = dataset_ops.Dataset.zip((a, b))
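    # `zip` stops at the end of the shorter dataset, so only
    # `smaller_num_elements` pairs are produced.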
    ds = self._make_dynamic_sharding_dataset(ds, cluster)

    self.assertDatasetProduces(
        ds, list(zip(range(smaller_num_elements), range(smaller_num_elements))))
Example 13
    def testFlatMapWithRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = dataset_ops.Dataset.range(5)

        def flat_map_fn(_):
            return dataset_ops.Dataset.from_tensor_slices(["a", "b",
                                                           "c"]).repeat(10)

        ds = ds.flat_map(flat_map_fn)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertDatasetProduces(ds, [b"a", b"b", b"c"] * 50,
                                   assert_items_equal=True)
Example 14
  def testConsumerRestart(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_consumers = 3
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    get_next = self.getNext(ds, requires_initialization=False)
    _ = [self.evaluate(get_next()) for _ in range(20)]

    ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "current round has already reached"):
      get_next_ds2 = self.getNext(ds2, requires_initialization=False)
      _ = [self.evaluate(get_next_ds2()) for _ in range(20)]
    cluster.stop_workers()
Example 15
    def testDispatcherRestartDuringReading(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
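        # Restart the dispatcher partway through reading; the remaining
        # elements should still be produced.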
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)
Example 16
 def testDispatcherStop(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     num_elements = 100
     ds = self.make_distributed_range_dataset(num_elements, cluster)
     iterator = iter(ds)
     results = []
     results.append(next(iterator).numpy())
     cluster.stop_dispatcher()
     # After the dispatcher dies, the worker should continue providing the rest
     # of the dataset's elements.
     for _ in range(num_elements - 1):
         results.append(next(iterator).numpy())
     self.assertEqual(results, list(range(num_elements)))
Example 17
 def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
     cluster = data_service_test_base.TestCluster(
         num_workers=1, data_transfer_protocol="grpc")
     dataset = dataset_ops.Dataset.from_tensor_slices(
         list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher.target,
         dataset=dataset,
         compression=compression)
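     # Neither `element_spec` nor `compression` is passed here; both are
     # determined from the registered dataset.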
     dataset = data_service_ops.from_dataset_id(
         processing_mode=data_service_ops.ShardingPolicy.OFF,
         service=cluster.dispatcher.target,
         dataset_id=dataset_id)
     self.assertDatasetProduces(dataset, list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
Example 18
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
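        # Providing an element spec that does not match the registered dataset
        # should fail once elements are read.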
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())
Example 19
  def testVariables(self, use_resource):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    if not use_resource:
      with variable_scope.variable_scope("foo", use_resource=False):
        v = variables.VariableV1(10, dtype=dtypes.int64)
    else:
      v = variables.Variable(10, dtype=dtypes.int64)

    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(lambda x: x + v)
    ds = self.make_distributed_dataset(ds, cluster)
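    # The map function captures `v`, so the variable must be initialized before
    # the distributed dataset is read.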
    self.evaluate(v.initializer)
    self.assertDatasetProduces(
        ds, list(range(10, 13)), requires_initialization=True)
Example 20
 def testElementSpecGraphMode(self):
     cluster = data_service_test_base.TestCluster(num_workers=1,
                                                  work_dir=NO_WORK_DIR,
                                                  fault_tolerant_mode=False)
     num_elements = 10
     ds = dataset_ops.Dataset.range(num_elements)
     dataset_id = data_service_ops.register_dataset(
         cluster.dispatcher_address(), ds)
     with self.assertRaisesRegex(
             ValueError,
             "In graph mode element_spec must be provided manually."):
         ds = data_service_ops.from_dataset_id("parallel_epochs",
                                               cluster.dispatcher_address(),
                                               dataset_id)
Example 21
 def testZipDifferentProcessingModesDatasets(self):
   cluster = data_service_test_base.TestCluster(num_workers=1)
   num_elements = 100
   ds1 = dataset_ops.Dataset.range(num_elements)
   ds1 = self.make_distributed_dataset(
       ds1, cluster, processing_mode="distributed_epoch")
   ds2 = dataset_ops.Dataset.range(num_elements)
   ds2 = self.make_distributed_dataset(
       ds2, cluster, processing_mode="parallel_epochs")
   ds = dataset_ops.Dataset.zip((ds1, ds2))
   self.assertDatasetProduces(
       ds,
       list(zip(range(num_elements), range(num_elements))),
       assert_items_equal=True)
Example 22
  def testFromDatasetIdMultipleComponents(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)

    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
    dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
    from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                              dataset_id, ds.element_spec)
    output = self.getDatasetOutput(from_dataset_id_ds)
    for i in range(num_elements):
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])
Example 23
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = data_service_test_base.TestCluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds)
        self.assertEqual(0, self.evaluate(get_next()))
Example 24
  def testFromDatasetIdDoesntRequireElementSpec(self):
    cluster = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=NO_WORK_DIR,
        fault_tolerant_mode=False,
        data_transfer_protocol="grpc")
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)

    dataset_id = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                   ds)
    ds = data_service_ops.from_dataset_id("parallel_epochs",
                                          cluster.dispatcher_address(),
                                          dataset_id)
    self.assertDatasetProduces(ds, list(range(num_elements)))
Example 25
 def testGcAndRecreate(self):
   cluster = data_service_test_base.TestCluster(
       num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
   num_elements = 1000
   # Repeatedly create and garbage-collect the same job.
   for _ in range(3):
     ds = self.make_distributed_range_dataset(
         num_elements, cluster, job_name="test")
     it = iter(ds)
     for _ in range(50):
       next(it)
     del it
     # Wait for the task to be garbage-collected on all workers.
     while cluster.num_tasks_on_workers() > 0:
       time.sleep(0.1)
Example 26
  def testResourceOnWrongDevice(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with ops.device(self._devices[0]):
      initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11])
      table = lookup_ops.StaticHashTable(initializer, -1)
      self.evaluate(lookup_ops.tables_initializer())

    with ops.device(self._devices[1]):
      ds = dataset_ops.Dataset.range(3)
      ds = ds.map(table.lookup)
      with self.assertRaisesRegex(
          errors.FailedPreconditionError,
          "Serialization error while trying to register a dataset"):
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds, requires_initialization=True)
Example 27
 def testForeverRepeat(self):
     cluster = data_service_test_base.TestCluster(num_workers=2)
     num_elements = 20
     elements_to_read = 1000
     ds = dataset_ops.Dataset.range(num_elements).repeat()
     ds = self._make_dynamic_sharding_dataset(ds, cluster)
     get_next = self.getNext(ds)
     results = {}
     for _ in range(elements_to_read):
         val = self.evaluate(get_next())
         if val not in results:
             results[val] = 0
         results[val] += 1
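     # Each value should be read roughly elements_to_read / num_elements times;
     # require at least half of that for every value.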
     for i in range(num_elements):
         self.assertGreater(results[i], elements_to_read / num_elements / 2)
Example 28
    def testFiniteV1(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test",
                                           consumer_index=0,
                                           num_consumers=1)

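        # Coordinated (round-robin) reads require a repeated dataset, so reading
        # a finite dataset should fail with an end-of-sequence error.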
        with self.assertRaisesRegex(
                errors.FailedPreconditionError,
                "Encountered end of sequence on a "
                "round-robin read iterator"):
            self.getDatasetOutput(ds)
Example 29
  def testGroupByWindow(self):
    # Verify that split providers are not propagated into the iterators created
    # for the datasets returned by the reduce_fn in group_by_window.
    cluster = data_service_test_base.TestCluster(num_workers=2)
    elements = [1, 5, 0]
    ds = dataset_ops.Dataset.from_tensor_slices(elements)

    def reduce_fn(_, window):
      return dataset_ops.Dataset.zip((window, dataset_ops.Dataset.range(100)))

    ds = ds.group_by_window(lambda x: 0, reduce_fn, window_size=3)
    ds = self._make_dynamic_sharding_dataset(ds, cluster)
    # This will fail if the tensor_slices split provider is propagated into the
    # `reduce_fn`, since the `zip` requires either 0 or 2 split providers.
    self.getDatasetOutput(ds)
Example 30
 def _testCompressionMismatch(self, dataset):
     cluster = data_service_test_base.TestCluster(
         num_workers=1, data_transfer_protocol="grpc")
     with mock.patch.object(compat,
                            "forward_compatible",
                            return_value=False):
         dataset_id = data_service_ops._register_dataset(
             cluster.dispatcher.target, dataset=dataset, compression=None)
         # `compression` is "AUTO" by default.
         dataset = data_service_ops._from_dataset_id(
             processing_mode=data_service_ops.ShardingPolicy.OFF,
             service=cluster.dispatcher.target,
             dataset_id=dataset_id,
             element_spec=dataset.element_spec)
         with self.assertRaises(errors.InvalidArgumentError):
             self.getDatasetOutput(dataset)