def testStartServersLate(self):
    """The data service client retries instead of failing when the dataset
    is created before the dispatcher and worker are started."""
    try:
        import portpicker  # pylint: disable=g-import-not-at-top
        dispatcher_port = portpicker.pick_unused_port()
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit and mask real interpreter signals.
        raise self.skipTest(
            "Flakes in portpicker library do not represent "
            "TensorFlow errors.")
    # start=False: the cluster is created but its servers are not yet running.
    cluster = data_service_test_base.TestCluster(
        num_workers=1, dispatcher_port=dispatcher_port, start=False)

    def start_servers():
        # Delay so the client's first requests happen before servers exist,
        # forcing it down the retry path.
        time.sleep(0.5)
        cluster.start_dispatcher()
        cluster.start_workers()

    start_servers_thread = threading.Thread(target=start_servers, daemon=True)
    start_servers_thread.start()

    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)
    start_servers_thread.join()
    def testRoundRobinRestartWorker(self):
        """Round-robin reads keep working across a worker stop and restart."""
        num_workers = 3
        # Set a shutdown quiet period to prevent workers from shutting down partway
        # through a round.
        cluster = data_service_test_base.TestCluster(
            num_workers, worker_shutdown_quiet_period_ms=2000)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_consumers = 5
        ds = self.make_round_robin_dataset(cluster, num_consumers)

        get_next = self.getNext(ds, requires_initialization=True)
        results = []

        self.read(get_next, results, 20)
        cluster.workers[1].stop()
        # Check that we can continue to read even with a worker stopped.
        self.read(get_next, results, 20)
        cluster.workers[1].restart()
        # Read until we get results from the restarted worker, then read some more.
        # NOTE(review): assumes the restarted worker's stream begins at 0 —
        # confirm against make_round_robin_dataset's element scheme.
        while results[-1] != 0:
            results.append(self.evaluate(get_next()))
        self.read(get_next, results, 20)

        self.checkRoundRobinGroups(results, num_consumers)
# Example 3
 def testDistributeDistributedEpochTensorSlices(self):
   """Distributed-epoch processing of a tensor-slices dataset yields every slice."""
   expected = [5, 1, 2, 4]
   two_worker_cluster = data_service_test_base.TestCluster(num_workers=2)
   sliced = dataset_ops.Dataset.from_tensor_slices(expected)
   distributed = self.make_distributed_dataset(
       sliced, two_worker_cluster, processing_mode="distributed_epoch")
   # Order across workers is unspecified, so compare as multisets.
   self.assertDatasetProduces(distributed, expected, assert_items_equal=True)
    def testSharedJobName(self):
        """Two datasets sharing a job name jointly produce each element once."""
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 1000

        def build_shuffled():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(
            build_shuffled(), cluster, job_name="job_name")
        ds2 = self.make_distributed_dataset(
            build_shuffled(), cluster, job_name="job_name")
        it1 = iter(ds1)
        it2 = iter(ds2)
        results = []
        # Interleave reads from both consumers of the shared job...
        for _ in range(num_elements // 5):
            results.append(next(it1).numpy())
            results.append(next(it2).numpy())
        # ...then drain whatever remains from each iterator.
        results.extend(elem.numpy() for elem in it1)
        results.extend(elem.numpy() for elem in it2)
        self.assertCountEqual(list(range(num_elements)), results)
 def testSharedJobNameRepeat(self):
     """Repeated datasets sharing a job name split one combined stream.

     Across both iterators, each element of each repetition appears exactly
     once, as checked by the final count assertion.
     """
     cluster = data_service_test_base.TestCluster(num_workers=1)
     num_elements = 100
     num_repetitions = 3
     ds1 = self.make_distributed_range_dataset(num_elements,
                                               cluster,
                                               job_name="job_name")
     ds1 = ds1.repeat(num_repetitions)
     ds2 = self.make_distributed_range_dataset(num_elements,
                                               cluster,
                                               job_name="job_name")
     ds2 = ds2.repeat(num_repetitions)
     results = []
     iter1 = iter(ds1)
     iter2 = iter(ds2)
     # Read a fifth of the total from each iterator...
     for _ in range((num_elements * num_repetitions) // 5):
         results.append(next(iter1).numpy())
     for _ in range((num_elements * num_repetitions) // 5):
         results.append(next(iter2).numpy())
     # ...then drain the remainder from both.
     for elem in iter1:
         results.append(elem.numpy())
     for elem in iter2:
         results.append(elem.numpy())
     self.assertCountEqual(num_repetitions * list(range(num_elements)),
                           results)
# Example 6
 def testRegisterDifferentDatasets(self):
   """Registering distinct datasets must hand back distinct dataset ids."""
   cluster = data_service_test_base.TestCluster(num_workers=1)
   small = dataset_ops.Dataset.range(10)
   large = dataset_ops.Dataset.range(20)
   address = cluster.dispatcher_address()
   small_id = data_service_ops.register_dataset(address, small)
   large_id = data_service_ops.register_dataset(address, large)
   self.assertNotEqual(self.evaluate(small_id), self.evaluate(large_id))
    def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
        """After a worker restart, iteration resumes with the new worker's data."""
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        cluster.workers[0].restart(use_same_port=use_same_port)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)
 def testDispatcherRestartBetweenIterations(self):
     """The same dataset is readable before and after a dispatcher restart."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     num_elements = 100
     # Use num_elements rather than repeating the literal 100, so the dataset
     # size and the expected output cannot drift apart.
     ds = self.make_distributed_range_dataset(num_elements, cluster)
     self.assertDatasetProduces(ds, list(range(num_elements)))
     cluster.restart_dispatcher()
     self.assertDatasetProduces(ds, list(range(num_elements)))
 def testDistributeCompression(self, compression):
     """A range dataset round-trips through the service under each compression."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     element_count = 10
     ds = self.make_distributed_range_dataset(
         element_count, cluster, compression=compression)
     self.assertDatasetProduces(ds, list(range(element_count)))
# Example 10
 def testRegisterDifferentDatasets(self):
     """Distinct datasets registered with the service receive distinct ids."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     first = dataset_ops.Dataset.range(10)
     second = dataset_ops.Dataset.range(20)
     first_id = data_service_ops.register_dataset(cluster.target, first)
     second_id = data_service_ops.register_dataset(cluster.target, second)
     self.assertNotEqual(first_id.numpy(), second_id.numpy())
    def testApplyDeterminismOption(self):
        """experimental_deterministic=False is honored through the service."""
        elements = list(range(10))
        cluster = data_service_test_base.TestCluster(num_workers=1)

        def dataset_fn(delay_ms):
            # Sleep only for element 0 so it finishes late; with determinism
            # disabled, other elements may overtake it in the output order.
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)
# Example 12
 def testMultiWorker(self):
   """Every worker produces a full copy of the dataset."""
   worker_count = 3
   cluster = data_service_test_base.TestCluster(num_workers=worker_count)
   element_count = 10
   ds = self.make_distributed_range_dataset(element_count, cluster)
   # Each worker emits the whole range; interleaving order is unspecified.
   self.assertDatasetProduces(
       ds, worker_count * list(range(element_count)), assert_items_equal=True)
    def testRoundRobinAddWorkers(self, workers_to_add):
        """Round-robin reads keep working while new workers join mid-stream."""
        starting_workers = 3
        cluster = data_service_test_base.TestCluster(
            num_workers=starting_workers)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_consumers = 7
        ds = self.make_round_robin_dataset(cluster, num_consumers)

        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        # NOTE(review): zeros are used to count workers that have begun
        # producing — presumably each worker's stream starts at 0; confirm
        # against make_round_robin_dataset.
        zeros_seen = 0
        for _ in range(25):
            results.append(self.evaluate(get_next()))
            if results[-1] == 0:
                zeros_seen += 1
        for _ in range(workers_to_add):
            cluster.add_worker()
        # Read until all new workers have joined.
        while zeros_seen < starting_workers + workers_to_add:
            results.append(self.evaluate(get_next()))
            if results[-1] == 0:
                zeros_seen += 1
        # Read some more.
        for _ in range(25):
            results.append(self.evaluate(get_next()))

        self.checkRoundRobinGroups(results, num_consumers)
    def testRoundRobinMultiStartStop(self):
        """Round-robin reads survive repeated worker and dispatcher restarts."""
        num_workers = 3
        # Set a shutdown quiet period to prevent workers from shutting down partway
        # through a round.
        cluster = data_service_test_base.TestCluster(
            num_workers, worker_shutdown_quiet_period_ms=2000)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_consumers = 5
        ds = self.make_round_robin_dataset(cluster, num_consumers)

        get_next = self.getNext(ds, requires_initialization=True)
        results = []

        # Stop and restart each worker in turn, reading between transitions.
        self.read(get_next, results, 20)
        for i in range(num_workers):
            cluster.workers[i].stop()
            self.read(get_next, results, 20)
            cluster.workers[i].restart()
            self.read(get_next, results, 20)

        # Add a worker, restart the dispatcher, then stop all original
        # workers; reads must still complete via the newly-added worker.
        cluster.add_worker()
        cluster.restart_dispatcher()
        for i in range(num_workers):
            cluster.workers[i].stop()
        self.read(get_next, results, 20)

        self.checkRoundRobinGroups(results, num_consumers)
# Example 15
 def testMultiWorker(self):
     """Each worker independently produces the whole range of elements."""
     worker_count = 3
     cluster = data_service_test_base.TestCluster(num_workers=worker_count)
     element_count = 10
     ds = self.make_distributed_range_dataset(element_count, cluster)
     produced = [value.numpy() for value in ds]
     self.assertCountEqual(worker_count * list(range(element_count)), produced)
# Example 16
 def testMultipleEpochs(self):
     """The same distributed dataset can be iterated many times over."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     element_count = 3
     ds = self.make_distributed_range_dataset(element_count, cluster)
     expected = list(range(element_count))
     for _ in range(10):
         self.assertEqual(expected, [value.numpy() for value in ds])
# Example 17
 def testDistributeBasic(self, work_dir, fault_tolerant_mode):
     """A simple range dataset survives a round trip through the service."""
     cluster = data_service_test_base.TestCluster(
         num_workers=1,
         work_dir=work_dir,
         fault_tolerant_mode=fault_tolerant_mode)
     element_count = 10
     distributed = self.make_distributed_range_dataset(element_count, cluster)
     self.assertDatasetProduces(distributed, list(range(element_count)))
 def testDistributeExplicitProtocol(self):
     """The service address may spell out its grpc:// protocol prefix."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     service_address = "grpc://" + cluster.dispatcher_address()
     ds = dataset_ops.Dataset.range(10).apply(
         data_service_ops.distribute(
             processing_mode="parallel_epochs", service=service_address))
     self.assertDatasetProduces(ds, list(range(10)))
# Example 19
 def testDistributeCompression(self, compression):
     """Elements come back intact under each compression scheme."""
     cluster = data_service_test_base.TestCluster(num_workers=1)
     element_count = 10
     ds = self.make_distributed_range_dataset(
         element_count, cluster, compression=compression)
     produced = [value.numpy() for value in ds]
     self.assertEqual(list(range(element_count)), produced)
# Example 20
 def testDistributeDistributedEpoch(self):
   """Distributed-epoch mode partitions one epoch across two workers."""
   cluster = data_service_test_base.TestCluster(num_workers=2)
   element_count = 100
   ds = self.make_distributed_dataset(
       dataset_ops.Dataset.range(element_count), cluster,
       processing_mode="distributed_epoch")
   # The epoch is split across workers, so output order is unspecified.
   self.assertDatasetProduces(
       ds, list(range(element_count)), assert_items_equal=True)
# Example 21
 def testDistributeLargeGraph(self):
   """A dataset graph above grpc's default message cap still transfers."""
   cluster = data_service_test_base.TestCluster(
       num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
   # 2 * 1000 * 1000 float32s (~8MB) exceeds the default OSS grpc message
   # size limit of 4MB.
   big_tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
   ds = self.make_distributed_dataset(
       dataset_ops.Dataset.from_tensors(big_tensor), cluster)
   self.assertDatasetProduces(ds, [big_tensor])
# Example 22
  def testDistributedEpochOnDistributedDataset(self):
    """Distributed-epoch over an already-distributed dataset is unsupported."""
    inner_cluster = data_service_test_base.TestCluster(num_workers=1)
    outer_cluster = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    numbers = list(range(num_sizes)) * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(numbers)
    ds = self.make_distributed_dataset(
        ds, inner_cluster, processing_mode="parallel_epochs")
    ds = ds.map(lambda x: x + 1)
    ds = self.make_distributed_dataset(
        ds, outer_cluster, processing_mode="distributed_epoch")

    # No split provider exists for a DataServiceDataset, so iterating fails.
    error_regex = ("Cannot create a split provider for dataset "
                   "of type DataServiceDataset")
    with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
      self.getDatasetOutput(ds)
# Example 23
 def testMaxOutstandingRequests(self):
   """Reading completes even when limited to one outstanding request."""
   worker_count = 3
   cluster = data_service_test_base.TestCluster(num_workers=worker_count)
   element_count = 10
   ds = self.make_distributed_range_dataset(
       element_count, cluster, max_outstanding_requests=1)
   self.assertDatasetProduces(
       ds, worker_count * list(range(element_count)), assert_items_equal=True)
# Example 24
 def testDistributeDistributedEpochFlatMap(self):
   """flat_map datasets are splittable under distributed_epoch mode."""
   cluster = data_service_test_base.TestCluster(num_workers=2)
   elements = [1, 5, 0]
   ds = dataset_ops.Dataset.from_tensor_slices(elements).flat_map(
       lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
   ds = self.make_distributed_dataset(
       ds, cluster, processing_mode="distributed_epoch")
   self.assertDatasetProduces(ds, elements, assert_items_equal=True)
# Example 25
 def testRoundRobin(self, num_workers, num_consumers):
     """Round-robin reads hand elements to consumers in strict rotation."""
     cluster = data_service_test_base.TestCluster(num_workers=num_workers)
     # Round robin reads can cause slow cluster shutdown; register for
     # global cleanup instead.
     data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
     ds = self.make_round_robin_dataset(cluster, num_consumers).take(100)
     produced = self.getDatasetOutput(ds)
     self.checkRoundRobinGroups(produced, num_consumers)
# Example 26
 def testRepeatedDataset(self):
   """repeat() on top of a distributed dataset replays every epoch."""
   cluster = data_service_test_base.TestCluster(num_workers=1)
   element_count = 10
   repetitions = 5
   ds = self.make_distributed_range_dataset(element_count, cluster)
   ds = ds.repeat(repetitions)
   self.assertDatasetProduces(
       ds, expected_output=repetitions * list(range(element_count)))
# Example 27
 def testDifferentJobNames(self):
   """Datasets under different job names read independent full copies."""
   cluster = data_service_test_base.TestCluster(num_workers=1)
   element_count = 10
   first = self.make_distributed_range_dataset(
       element_count, cluster, job_name="job_name1")
   second = self.make_distributed_range_dataset(
       element_count, cluster, job_name="job_name2")
   expected = list(range(element_count))
   self.assertDatasetProduces(first, expected)
   self.assertDatasetProduces(second, expected)
# Example 28
 def testMaxOutstandingRequests(self):
     """A single outstanding request still drains all workers' elements."""
     worker_count = 3
     cluster = data_service_test_base.TestCluster(num_workers=worker_count)
     element_count = 10
     ds = self.make_distributed_range_dataset(
         element_count, cluster, max_outstanding_requests=1)
     self.assertCountEqual(
         worker_count * list(range(element_count)), self.getDatasetOutput(ds))
# Example 29
    def testFromDatasetIdNotRegistered(self):
        """Reading from a never-registered dataset id raises NotFoundError."""
        cluster = data_service_test_base.TestCluster(num_workers=1)

        bogus_dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, bogus_dataset_id, element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(ds)())
# Example 30
 def testDistributeBasic(self, work_dir, fault_tolerant_mode):
     """Basic round trip: the service reproduces a range dataset in order."""
     cluster = data_service_test_base.TestCluster(
         num_workers=1,
         work_dir=work_dir,
         fault_tolerant_mode=fault_tolerant_mode)
     element_count = 10
     ds = self.make_distributed_range_dataset(element_count, cluster)
     produced = [value.numpy() for value in ds]
     self.assertEqual(list(range(element_count)), produced)