Example #1
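Creating a job over a dataset that zips together two distributed datasets fails, because the pipeline then contains more than one distribute() call.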
def testMultipleDistributeCalls(self):
    service = self.create_cluster(1)
    ds1 = dataset_ops.Dataset.range(1)
    ds1 = ds1.apply(data_service_ops.distribute(service))
    ds2 = dataset_ops.Dataset.range(1)
    ds2 = ds2.apply(data_service_ops.distribute(service))
    ds = dataset_ops.Dataset.zip((ds1, ds2))
    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "Datasets containing multiple calls to .distribute(...) "
            "are not supported"):
        data_service_ops.create_job(ds, processing_mode="parallel_epochs")
Example #2
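Restarting a worker mid-job: after reading half the elements, the worker is stopped and replaced (optionally on the same port). Processing starts over on the new worker, allowing for one element prefetched from the old one.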
def testRestartWorker(self, use_same_port):
    self._master = server_lib.MasterServer(PROTOCOL)
    master_address = self._master.target[len(PROTOCOL + "://"):]
    self._worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(
        data_service_ops._distribute(
            self._master.target, task_refresh_interval_hint_ms=20))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    # Read halfway through the dataset.
    for i in range(num_elements // 2):
        self.assertEqual(i, next(iterator).numpy())

    # Stop the original worker and start a new one.
    port = 0
    if use_same_port:
        worker_address = self._worker.target[len(PROTOCOL + "://"):]
        port = int(worker_address.split(":")[1])
    self._worker.stop()
    self._new_worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address, port=port)

    # There may be one last element prefetched from the first worker
    # before it was stopped.
    val = next(iterator).numpy()
    self.assertTrue(val == 0 or val == num_elements // 2)
    start_val = 1 if val == 0 else 0

    # The dataset starts over now that we read from the new worker.
    for i in range(start_val, num_elements):
        self.assertEqual(i, next(iterator).numpy())
Example #3
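Adding a worker mid-job: a second worker joins after half the elements have been read. Depending on timing, the client sees one or two full passes over the data, and the test accepts both to avoid flakiness.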
def testAddWorkerMidJob(self):
    self._master = server_lib.MasterServer(PROTOCOL)
    master_address = self._master.target[len(PROTOCOL + "://"):]
    self._worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(
        data_service_ops._distribute(
            self._master.target, task_refresh_interval_hint_ms=20))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    results = []
    # Read halfway through the dataset.
    for _ in range(num_elements // 2):
        results.append(next(iterator).numpy())

    self._new_worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)

    # Give the client time to notice the new task.
    time.sleep(50 / 1000)  # 50ms

    for elem in iterator:
        results.append(elem.numpy())

    # It is possible that reading from the first worker completes before
    # the client notices the second worker. We allow this to avoid flaky
    # failures.
    if len(results) == num_elements:
        self.assertEqual(list(range(num_elements)), results)
    else:
        self.assertCountEqual(2 * list(range(num_elements)), results)
Example #4
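Reading ten epochs from the same distributed dataset, creating a fresh job and iterator for each pass.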
def testMultipleEpochs(self):
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(3)
    ds = ds.apply(data_service_ops.distribute(service))
    for _ in range(10):
        token = data_service_ops.create_job(
            ds, processing_mode="parallel_epochs")
        it = data_service_ops.create_iterator(ds, token)
        self.assertEqual(list(range(3)), [t.numpy() for t in it])
Example #5
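The basic case: a range dataset distributed through a single-worker cluster yields its elements back in order.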
def testDistributeBasic(self):
    num_elements = 10
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    results = [t.numpy() for t in it]
    self.assertEqual(list(range(num_elements)), results)
Example #6
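Multi-worker reads in parallel_epochs mode: every worker processes the whole dataset, so each element appears once per worker.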
def testMultiWorker(self):
    num_workers = 3
    num_elements = 10
    service = self.create_cluster(num_workers)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    results = [elem.numpy() for elem in iterator]
    # In parallel_epochs mode, each worker processes the full dataset, so
    # every element should appear once per worker.
    self.assertCountEqual(num_workers * list(range(num_elements)), results)
Example #7
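A helper that distributes a stateful pipeline (a random-valued map) under a caller-supplied external state policy and reads one element; presumably it is invoked once per ExternalStatePolicy value by the surrounding tests.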
def run_stateful(self, external_state_policy):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    ds = ds.with_options(options)

    service = self.create_cluster(3)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    iterator = data_service_ops.create_iterator(ds, token)
    next(iterator)
Example #8
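A helper that collects the distributed elements into a TensorArray, apparently written to run inside a tf.function (hence the graph-friendly accumulation); num_elements, service, and num_workers are free variables from the enclosing scope.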
def f():
    # num_elements, service, and num_workers are captured from the
    # enclosing scope.
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    it = data_service_ops.create_iterator(ds, token)
    result = tensor_array_ops.TensorArray(
        dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
    i = 0
    for elem in it:
        result = result.write(i, elem)
        i += 1
    return result.stack()
Example #9
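Concurrent epochs: three independent jobs share one cluster, and interleaved reads show each job receiving its own complete copy of the dataset.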
def testConcurrentEpoch(self):
    num_elements = 10
    num_datasets = 3
    service = self.create_cluster(1)
    iterators = []
    results = []
    for _ in range(num_datasets):
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(
            ds, processing_mode="parallel_epochs")
        it = data_service_ops.create_iterator(ds, token)
        iterators.append(it)
        results.append([])

    # Interleave reads across the jobs; each job produces its own
    # complete epoch.
    for _ in range(num_elements):
        for dataset_ind in range(num_datasets):
            result = next(iterators[dataset_ind]).numpy()
            results[dataset_ind].append(result)
    for result in results:
        self.assertEqual(list(range(num_elements)), result)
Example #10
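A shared job: three iterators are created from one token, so together they consume each element exactly once.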
def testSharedEpoch(self):
    num_elements = 10
    num_iterators = 3
    service = self.create_cluster(1)
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(data_service_ops.distribute(service))
    result = []
    iterators = []
    token = data_service_ops.create_job(
        ds, processing_mode="parallel_epochs")
    for _ in range(num_iterators):
        iterators.append(data_service_ops.create_iterator(ds, token))

    # Alternate reading between the iterators.
    for _ in range(2):
        for it in iterators:
            result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
        for elem in it:
            result.append(elem.numpy())

    self.assertCountEqual(list(range(num_elements)), result)
Example #11
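create_job rejects a dataset that contains no distribute() transformation.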
def testNoDistributeCalls(self):
    ds = dataset_ops.Dataset.range(1)
    with self.assertRaisesWithLiteralMatch(
            ValueError,
            "Dataset does not contain any distribute() transformations"):
        data_service_ops.create_job(ds, processing_mode="parallel_epochs")