Esempio n. 1
0
 def make_distributed_dataset(self,
                              dataset,
                              cluster,
                              processing_mode="parallel_epochs",
                              job_name=None,
                              consumer_index=None,
                              num_consumers=None,
                              max_outstanding_requests=None,
                              data_transfer_protocol=None,
                              compression="AUTO",
                              cross_trainer_cache=None,
                              target_workers="AUTO"):
   """Applies tf.data service distribution to `dataset` via `cluster`.

   All keyword options are forwarded unchanged; the 20ms task refresh
   interval keeps tests responsive to worker changes.
   """
   distribute_options = dict(
       job_name=job_name,
       consumer_index=consumer_index,
       num_consumers=num_consumers,
       max_outstanding_requests=max_outstanding_requests,
       task_refresh_interval_hint_ms=20,
       data_transfer_protocol=data_transfer_protocol,
       compression=compression,
       cross_trainer_cache=cross_trainer_cache,
       target_workers=target_workers)
   # pylint: disable=protected-access
   transform = data_service_ops._distribute(
       processing_mode, cluster.dispatcher_address(), **distribute_options)
   return dataset.apply(transform)
Esempio n. 2
0
    def testRestartWorker(self, use_same_port):
        """Replacing the worker restarts the dataset from element 0.

        Reads half the range, stops the worker, starts a replacement
        (optionally on the same port), and verifies the new worker serves
        the full range again.
        """
        self._master = server_lib.MasterServer(PROTOCOL)
        # Strip the "<protocol>://" prefix to get a bare host:port address.
        master_address = self._master.target[len(PROTOCOL + "://"):]
        self._worker = server_lib.WorkerServer(PROTOCOL,
                                               master_address=master_address)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        # Short refresh interval so the client notices the worker change fast.
        ds = ds.apply(
            data_service_ops._distribute(self._master.target,
                                         task_refresh_interval_hint_ms=20))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        # Read halfway through the dataset.
        for i in range(num_elements // 2):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        port = 0
        if use_same_port:
            worker_address = self._worker.target[len(PROTOCOL + "://"):]
            port = int(worker_address.split(":")[1])
        self._worker.stop()
        self._new_worker = server_lib.WorkerServer(
            PROTOCOL, master_address=master_address, port=port)

        # There may be one last element prefetched from the first worker before it
        # was stopped.
        val = next(iterator).numpy()
        # Either the new worker restarted at 0, or the prefetched element
        # (num_elements // 2) arrived first.
        self.assertTrue(val == 0 or val == num_elements // 2)
        start_val = 1 if val == 0 else 0

        # The dataset starts over now that we read from the new worker.
        for i in range(start_val, num_elements):
            self.assertEqual(i, next(iterator).numpy())
Esempio n. 3
0
def _make_distributed_dataset(dataset, address, job_name=None):
    """Distributes `dataset` over the tf.data service at `address`.

    Uses a 20ms task refresh interval so tests react quickly to task changes.
    """
    service = "{0}://{1}".format(PROTOCOL, address)
    transform = data_service_ops._distribute(
        "parallel_epochs",
        service,
        job_name=job_name,
        task_refresh_interval_hint_ms=20)
    return dataset.apply(transform)
Esempio n. 4
0
    def testAddWorkerMidJob(self):
        """A worker added mid-job contributes elements to the output.

        Reads half the range from one worker, starts a second worker, and
        accepts either one or two full copies of the range — the first
        worker may finish before the client notices the new task.
        """
        self._master = server_lib.MasterServer(PROTOCOL)
        # Strip the "<protocol>://" prefix to get a bare host:port address.
        master_address = self._master.target[len(PROTOCOL + "://"):]
        self._worker = server_lib.WorkerServer(PROTOCOL,
                                               master_address=master_address)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        # Short refresh interval so the client notices the new worker fast.
        ds = ds.apply(
            data_service_ops._distribute(self._master.target,
                                         task_refresh_interval_hint_ms=20))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        results = []
        # Read halfway through the dataset.
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())

        self._new_worker = server_lib.WorkerServer(
            PROTOCOL, master_address=master_address)

        # Give the client time to notice the new task.
        time.sleep(50 / 1000)  # 50ms

        for elem in iterator:
            results.append(elem.numpy())

        # It is possible that reading from the first worker completes before the
        # client notices the second worker. We allow this to avoid flaky failures.
        if len(results) == num_elements:
            self.assertEqual(list(range(num_elements)), results)
        else:
            self.assertCountEqual(2 * list(range(num_elements)), results)
Esempio n. 5
0
    def testGcClient(self):
        """The dispatcher garbage-collects a client that stops contacting it.

        The client's 10s task refresh interval is far longer than the
        dispatcher's 50ms client_timeout_ms, so the dispatcher forgets the
        client between contacts and a later read fails with NotFoundError.
        """
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=50))
        # dispatcher.target is "<protocol>://<host:port>"; keep the address part.
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        # NOTE(review): the sibling test qualifies this as
        # data_service_ops.ShardingPolicy.OFF — confirm ShardingPolicy is
        # imported unqualified in this file.
        dataset = dataset.apply(
            data_service_ops._distribute(processing_mode=ShardingPolicy.OFF,
                                         service=dispatcher.target,
                                         task_refresh_interval_hint_ms=10000))
        get_next = self.getNext(dataset)

        # The client does not heartbeat in 10 seconds. It will be garbage-collected.
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)
    def testKeepClientAliveBeforeReading(self):
        """Regular client contact keeps an idle client alive past the timeout.

        The client's 100ms task refresh interval is well under the 1000ms
        client_timeout_ms, so sleeping 3 seconds before reading must not
        trigger garbage collection and the full range is still produced.
        """
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=1000))
        # dispatcher.target is "<protocol>://<host:port>"; keep the address part.
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=dispatcher.target,
                task_refresh_interval_hint_ms=100))
        get_next = self.getNext(dataset)

        # The client regularly heartbeats in 100 milliseconds. It should not be
        # garbage-collected even if it does not start reading in 3 seconds.
        time.sleep(3)
        self.assertEqual(self.getIteratorOutput(get_next),
                         list(range(num_elements)))
Esempio n. 7
0
def _make_distributed_dataset(dataset, service, job_name=None):
    """Distributes `dataset` over the tf.data service at `service`.

    Uses a 20ms task refresh interval so tests react quickly to task changes.
    """
    distributed = data_service_ops._distribute(
        "parallel_epochs",
        service,
        job_name=job_name,
        task_refresh_interval_hint_ms=20)
    return dataset.apply(distributed)
def _make_distributed_dataset(dataset,
                              dispatcher,
                              job_name=None,
                              max_outstanding_requests=None):
    """Distributes `dataset` through `dispatcher` with a 20ms task refresh."""
    # _distribute is test-only protected access.
    transform = data_service_ops._distribute(
        "parallel_epochs",
        dispatcher.target,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=20)
    return dataset.apply(transform)
Esempio n. 9
0
 def testMaxOutstandingRequests(self):
     """Reads a 3-worker cluster with at most one outstanding request."""
     num_elements = 10
     num_workers = 3
     address = self.create_cluster(num_workers)
     dataset = dataset_ops.Dataset.range(num_elements)
     dataset = dataset.apply(
         data_service_ops._distribute("parallel_epochs",
                                      "{0}://{1}".format(PROTOCOL, address),
                                      max_outstanding_requests=1,
                                      task_refresh_interval_hint_ms=20))
     # Each worker produces the full range, so expect num_workers copies.
     expected = num_workers * list(range(num_elements))
     self.assertCountEqual(expected, self.getDatasetOutput(dataset))
Esempio n. 10
0
 def make_distributed_dataset(self,
                              dataset,
                              cluster,
                              job_name=None,
                              max_outstanding_requests=None):
     """Distributes `dataset` through `cluster` in parallel_epochs mode."""
     # pylint: disable=protected-access
     transform = data_service_ops._distribute(
         "parallel_epochs",
         cluster.target,
         job_name=job_name,
         max_outstanding_requests=max_outstanding_requests,
         # Short refresh interval keeps tests responsive to task changes.
         task_refresh_interval_hint_ms=20)
     return dataset.apply(transform)
Esempio n. 11
0
 def make_distributed_dataset(self,
                              dataset,
                              cluster,
                              processing_mode="parallel_epochs",
                              job_name=None,
                              consumer_index=None,
                              num_consumers=None,
                              max_outstanding_requests=None):
     """Applies tf.data service distribution to `dataset` via `cluster`."""
     distribute_options = dict(
         job_name=job_name,
         consumer_index=consumer_index,
         num_consumers=num_consumers,
         max_outstanding_requests=max_outstanding_requests,
         # Short refresh interval keeps tests responsive to task changes.
         task_refresh_interval_hint_ms=20)
     # pylint: disable=protected-access
     return dataset.apply(
         data_service_ops._distribute(processing_mode, cluster.target,
                                      **distribute_options))