Example no. 1
0
  def testRestartWorker(self, use_same_port):
    """Replaces the worker mid-read and checks the dataset restarts."""
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    self._worker = server_lib.WorkerServer(
        port=0, master_address=self._master._address, protocol=PROTOCOL)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = _make_distributed_dataset(ds, self._master._address)
    iterator = iter(ds)
    midpoint = num_elements // 2
    # Consume the first half of the dataset from the original worker.
    for expected in range(midpoint):
      self.assertEqual(expected, next(iterator).numpy())

    # Tear down the original worker; optionally reuse its port.
    port = int(self._worker._address.split(":")[1]) if use_same_port else 0
    self._worker._stop()
    self._new_worker = server_lib.WorkerServer(
        port=port, master_address=self._master._address, protocol=PROTOCOL)

    # Reading from the new worker restarts the dataset from the beginning.
    for expected in range(num_elements):
      val = next(iterator).numpy()
      if val == midpoint and expected != midpoint:
        # There may have been one last element prefetched from the first
        # worker before it was stopped.
        val = next(iterator).numpy()
      self.assertEqual(expected, val)
Example no. 2
0
    def testRestartWorker(self, use_same_port):
        """Checks that reading resumes from zero after the worker restarts."""
        self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = _make_distributed_dataset(ds, self._master._address)
        iterator = iter(ds)
        midpoint = num_elements // 2
        # Consume the first half of the dataset from the original worker.
        for expected in range(midpoint):
            self.assertEqual(expected, next(iterator).numpy())

        # Replace the worker, optionally binding the replacement to the
        # original worker's port.
        port = int(self._worker._address.split(":")[1]) if use_same_port else 0
        self._worker._stop()
        self._new_worker = server_lib.WorkerServer(
            port=port, master_address=self._master._address, protocol=PROTOCOL)

        # Skip any elements prefetched from the first worker before it was
        # stopped; the restarted stream begins again at 0.
        while next(iterator).numpy() != 0:
            pass

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for expected in range(1, num_elements // 2):
            self.assertEqual(expected, next(iterator).numpy())
Example no. 3
0
 def testStopStartMaster(self):
     """Starting a master after it has been stopped should raise."""
     stopped_master = server_lib.MasterServer(0)
     stopped_master._stop()
     with self.assertRaisesRegex(
             RuntimeError,
             "Server cannot be started after it has been stopped"):
         stopped_master.start()
Example no. 4
0
 def testMasterNumWorkers(self):
     """The master's reported worker count tracks worker registrations."""
     master = server_lib.MasterServer(0)
     self.assertEqual(0, master._num_workers())
     # Keep references so the worker servers stay alive for the assertions.
     first_worker = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
     self.assertEqual(1, master._num_workers())
     second_worker = server_lib.WorkerServer(0, master._address)  # pylint: disable=unused-variable
     self.assertEqual(2, master._num_workers())
Example no. 5
0
  def create_cluster(self, num_workers):
    """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      The address of the master.
    """
    self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
    # Hold references to the workers so they are not garbage collected.
    self._servers = [
        server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)
        for _ in range(num_workers)
    ]
    return self._master._address
Example no. 6
0
    def testAddWorkerMidJob(self):
        """A worker added mid-job should also feed the running iterator."""
        self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
        self._worker = server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = _make_distributed_dataset(ds, self._master._address)
        iterator = iter(ds)
        # Consume the first half of the dataset from the single worker.
        results = [next(iterator).numpy() for _ in range(num_elements // 2)]

        self._new_worker = server_lib.WorkerServer(
            port=0, master_address=self._master._address, protocol=PROTOCOL)

        # Wait for the new worker to register with the master.
        while self._master._num_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        # Drain the remainder of the stream from both workers.
        results.extend(elem.numpy() for elem in iterator)

        self.assertCountEqual(2 * list(range(num_elements)), results)
Example no. 7
0
 def testJoinWorker(self):
     """Calling join() on a stopped worker should not raise."""
     master = server_lib.MasterServer(0)
     stopped_worker = server_lib.WorkerServer(0, master._address)
     stopped_worker._stop()
     stopped_worker.join()
Example no. 8
0
 def testJoinMaster(self):
     """Calling join() on a stopped master should not raise."""
     stopped_master = server_lib.MasterServer(0)
     stopped_master._stop()
     stopped_master.join()
Example no. 9
0
 def testStopMaster(self):
     """Stopping a master twice should not raise."""
     server = server_lib.MasterServer(0)
     server._stop()
     server._stop()
Example no. 10
0
 def testMultipleStartWorker(self):
     """Calling start() on an already-started worker should not raise."""
     master = server_lib.MasterServer(0)
     started_worker = server_lib.WorkerServer(0, master._address, start=True)
     started_worker.start()
Example no. 11
0
 def testStartWorker(self):
     """A worker created with start=False can be started explicitly."""
     master = server_lib.MasterServer(0)
     unstarted_worker = server_lib.WorkerServer(0, master._address, start=False)
     unstarted_worker.start()
Example no. 12
0
 def testMultipleStartMaster(self):
     """Calling start() on an already-started master should not raise."""
     started_master = server_lib.MasterServer(0, start=True)
     started_master.start()
Example no. 13
0
 def testStartMaster(self):
     """A master created with start=False can be started explicitly."""
     unstarted_master = server_lib.MasterServer(0, start=False)
     unstarted_master.start()