Exemplo n.º 1
0
 def testMasterNumWorkers(self):
     """Checks that the master's worker count tracks newly created workers.

     The count reported by `_num_workers()` is expected to be 0 before any
     worker exists and to grow by one after each `WorkerServer` is
     constructed against the master's address.
     """
     master_server = server_lib.MasterServer(0)
     self.assertEqual(0, master_server._num_workers())
     # Hold a reference to every worker so the objects outlive the assertions.
     started_workers = []
     for expected_count in (1, 2):
         started_workers.append(
             server_lib.WorkerServer(0, master_server._address))
         self.assertEqual(expected_count, master_server._num_workers())
Exemplo n.º 2
0
  def testAddWorkerMidJob(self):
    """Adds a second worker halfway through a job and checks the output.

    Half of a 100-element range dataset is read from a single worker. A
    second worker is then started, the test waits until the master reports
    at least two tasks, and the remaining elements are drained. Every
    element of the range is expected to appear exactly twice overall.
    """
    self._master = server_lib.MasterServer(PROTOCOL)
    prefix_len = len(PROTOCOL + "://")
    master_address = self._master.target[prefix_len:]
    self._worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)
    num_elements = 100
    dataset = _make_distributed_dataset(
        dataset_ops.Dataset.range(num_elements), self._master.target)
    element_iter = iter(dataset)
    # Consume the first half of the dataset before adding the new worker.
    seen = [next(element_iter).numpy() for _ in range(num_elements // 2)]

    self._new_worker = server_lib.WorkerServer(
        PROTOCOL, master_address=master_address)

    # Poll every 10ms until the master reports a task for the new worker.
    poll_interval_secs = 10 / 1000
    while self._master.num_tasks() < 2:
      time.sleep(poll_interval_secs)

    # Drain whatever is left.
    seen.extend(elem.numpy() for elem in element_iter)

    self.assertCountEqual(2 * list(range(num_elements)), seen)
Exemplo n.º 3
0
 def testStopStartMaster(self):
     """Verifies that starting a master after `_stop()` raises RuntimeError."""
     stopped_master = server_lib.MasterServer(0)
     stopped_master._stop()
     expected_message = "Server cannot be started after it has been stopped"
     with self.assertRaisesRegex(RuntimeError, expected_message):
         stopped_master.start()
    def testRestartWorker(self, use_same_port):
        """Checks that reading starts over after the worker is replaced.

        The first half of a 100-element range dataset is read in order from
        one worker. That worker is stopped and a new one is started, after
        which the iterator is expected to yield the whole range again from
        0, apart from at most one leftover element from the old worker.

        Args:
          use_same_port: If true, the replacement worker is created on the
            port parsed from the original worker's `_address`; otherwise it
            is created with `port=0`.
        """
        self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)
        num_elements = 100
        dataset = _make_distributed_dataset(
            dataset_ops.Dataset.range(num_elements), self._master._address)
        element_iter = iter(dataset)
        # Consume the first half, checking the elements arrive in order.
        half = num_elements // 2
        for expected in range(half):
            self.assertEqual(expected, next(element_iter).numpy())

        # Choose the replacement's port while the old address is readable,
        # then swap the workers.
        if use_same_port:
            new_port = int(self._worker._address.split(":")[1])
        else:
            new_port = 0
        self._worker._stop()
        self._new_worker = server_lib.WorkerServer(
            port=new_port,
            master_address=self._master._address,
            protocol=PROTOCOL)

        # The new worker serves the dataset from the beginning.
        for expected in range(num_elements):
            value = next(element_iter).numpy()
            # A value equal to `half` at any other position is treated as an
            # element prefetched from the first worker before it stopped;
            # skip it and read the next one.
            is_stale = value == half and expected != half
            if is_stale:
                value = next(element_iter).numpy()
            self.assertEqual(expected, value)
Exemplo n.º 5
0
    def testRestartWorker(self, use_same_port):
        """Checks that reading starts over after the worker is replaced.

        The first half of a 100-element range dataset is read in order from
        one worker. That worker is stopped and a new one is started, after
        which the iterator is expected to yield the whole range again from
        the new worker, optionally preceded by one leftover element from the
        old worker.

        Args:
          use_same_port: If true, the replacement worker is started on the
            port parsed from the original worker's target; otherwise it is
            started with `port=0`.
        """
        self._master = server_lib.MasterServer(PROTOCOL)
        # Strip the "<PROTOCOL>://" prefix from the target to get the address.
        master_address = self._master.target[len(PROTOCOL + "://"):]
        self._worker = server_lib.WorkerServer(PROTOCOL,
                                               master_address=master_address)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(
            data_service_ops._distribute(self._master.target,
                                         task_refresh_interval_hint_ms=20))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        # Read halfway through the dataset.
        for i in range(num_elements // 2):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        port = 0
        if use_same_port:
            worker_address = self._worker.target[len(PROTOCOL + "://"):]
            port = int(worker_address.split(":")[1])
        self._worker.stop()
        self._new_worker = server_lib.WorkerServer(
            PROTOCOL, master_address=master_address, port=port)

        # There may be one last element prefetched from the first worker before it
        # was stopped. assertIn is used (rather than assertTrue on a boolean
        # expression) so that a failure reports the unexpected value.
        val = next(iterator).numpy()
        self.assertIn(val, (0, num_elements // 2))
        # If the first value was already 0, the restarted sequence continues
        # from 1; otherwise it was the leftover element and we expect 0 next.
        start_val = 1 if val == 0 else 0

        # The dataset starts over now that we read from the new worker.
        for i in range(start_val, num_elements):
            self.assertEqual(i, next(iterator).numpy())
Exemplo n.º 6
0
    def testAddWorkerMidJob(self):
        """Adds a second worker halfway through a job and checks the output.

        Half of a 100-element range dataset is read from a single worker. A
        second worker is then started and, after a short pause, the rest of
        the iterator is drained. The results must be either the range read
        once in order or every element of the range exactly twice.
        """
        self._master = server_lib.MasterServer(PROTOCOL)
        prefix_len = len(PROTOCOL + "://")
        master_address = self._master.target[prefix_len:]
        self._worker = server_lib.WorkerServer(PROTOCOL,
                                               master_address=master_address)
        num_elements = 100
        distribute_fn = data_service_ops._distribute(
            self._master.target, task_refresh_interval_hint_ms=20)
        dataset = dataset_ops.Dataset.range(num_elements).apply(distribute_fn)
        job_token = data_service_ops.create_job(
            dataset, processing_mode="parallel_epochs")
        element_iter = data_service_ops.create_iterator(dataset, job_token)
        # Consume the first half of the dataset before adding the new worker.
        seen = [next(element_iter).numpy() for _ in range(num_elements // 2)]

        self._new_worker = server_lib.WorkerServer(
            PROTOCOL, master_address=master_address)

        # Pause for 50ms so the client has a chance to notice the new task.
        time.sleep(50 / 1000)

        seen.extend(elem.numpy() for elem in element_iter)

        # Reading from the first worker may finish before the client notices
        # the second worker. Accepting that outcome avoids flaky failures.
        if len(seen) != num_elements:
            self.assertCountEqual(2 * list(range(num_elements)), seen)
        else:
            self.assertEqual(list(range(num_elements)), seen)
    def create_cluster(self, num_workers):
        """Creates a cluster of tf.data service servers.

        The master is stored on `self._master` and the workers on
        `self._servers`.

        Args:
          num_workers: The number of workers in the cluster.

        Returns:
          The address of the master.
        """
        self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
        self._servers = []
        # Workers are appended one at a time as they are constructed.
        self._servers.extend(
            server_lib.WorkerServer(port=0,
                                    master_address=self._master._address,
                                    protocol=PROTOCOL)
            for _ in range(num_workers))

        return self._master._address
Exemplo n.º 8
0
  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    The master is stored on `self._master` and the workers on
    `self._servers`.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      A target for connecting to the service, e.g.
      "grpc+local://localhost:2000".
    """
    self._master = server_lib.MasterServer(PROTOCOL)
    # Workers take the master's address without the "<PROTOCOL>://" prefix.
    prefix_len = len(PROTOCOL + "://")
    master_address = self._master.target[prefix_len:]

    self._servers = []
    self._servers.extend(
        server_lib.WorkerServer(PROTOCOL, master_address=master_address)
        for _ in range(num_workers))

    return self._master.target
Exemplo n.º 9
0
 def testStartMaster(self):
     """Starts a master that was constructed with `start=False`."""
     deferred_master = server_lib.MasterServer(0, start=False)
     deferred_master.start()
Exemplo n.º 10
0
 def testJoinWorker(self):
     """Calls `join()` on a worker after stopping it with `_stop()`."""
     master_server = server_lib.MasterServer(0)
     stopped_worker = server_lib.WorkerServer(0, master_server._address)
     stopped_worker._stop()
     stopped_worker.join()
Exemplo n.º 11
0
 def testJoinMaster(self):
     """Calls `join()` on a master after stopping it with `_stop()`."""
     stopped_master = server_lib.MasterServer(0)
     stopped_master._stop()
     stopped_master.join()
Exemplo n.º 12
0
 def testStopMaster(self):
     """Calls `_stop()` twice on one master; passes if neither call raises."""
     master_server = server_lib.MasterServer(0)
     for _ in range(2):
         master_server._stop()
Exemplo n.º 13
0
 def testMultipleStartWorker(self):
     """Calls `start()` on a worker that was constructed with `start=True`."""
     master_server = server_lib.MasterServer(0)
     running_worker = server_lib.WorkerServer(
         0, master_server._address, start=True)
     running_worker.start()
Exemplo n.º 14
0
 def testStartWorker(self):
     """Starts a worker that was constructed with `start=False`."""
     master_server = server_lib.MasterServer(0)
     deferred_worker = server_lib.WorkerServer(
         0, master_server._address, start=False)
     deferred_worker.start()
Exemplo n.º 15
0
 def testMultipleStartMaster(self):
     """Calls `start()` on a master that was constructed with `start=True`."""
     running_master = server_lib.MasterServer(0, start=True)
     running_master.start()
Exemplo n.º 16
0
 def testStartWorker(self):
   """Checks that a new worker's `target` matches `<PROTOCOL>://.*:.*`."""
   master_server = server_lib.MasterServer(PROTOCOL)
   # The worker is given the master's target minus the "<PROTOCOL>://" prefix.
   master_address = master_server.target[len(PROTOCOL + "://"):]
   new_worker = server_lib.WorkerServer(PROTOCOL, master_address)
   target_pattern = PROTOCOL + "://.*:.*"
   self.assertRegex(new_worker.target, target_pattern)
Exemplo n.º 17
0
 def testStartMaster(self):
   """Checks that a new master's `target` matches `<PROTOCOL>://.*:.*`."""
   new_master = server_lib.MasterServer(PROTOCOL)
   target_pattern = PROTOCOL + "://.*:.*"
   self.assertRegex(new_master.target, target_pattern)