Example #1
0
    def testWorkerContinuousFailure(self):
        """Kills and restarts worker 0 twice while infinite closures are running.

        After both failure/recovery cycles, all 10 scheduled training functions
        must still complete.
        """
        model = Model(self.cluster_coord)
        model.do_infinite_step.assign(True)
        model.schedule_training_functions(10)

        # Model does infinite training step, so at this moment, we expect to have 2
        # infinite closures inflight, and 8 closures in the queue.
        # The wait is bounded so that a scheduling problem fails the test with a
        # clear message instead of hanging it forever.
        wait_timeout_secs = 60
        deadline = time.time() + wait_timeout_secs
        while (self.cluster_coord._cluster._closure_queue.
               _inflight_closure_count < 2):
            if time.time() > deadline:
                self.fail(
                    "Timed out after %d secs waiting for 2 in-flight closures; "
                    "found %d." %
                    (wait_timeout_secs, self.cluster_coord._cluster.
                     _closure_queue._inflight_closure_count))
            time.sleep(0.1)
        self.assertFalse(self.cluster_coord.done())

        worker_device = "/job:worker/replica:0/task:0"
        # Two consecutive failure/recovery cycles of the same worker.
        for _ in range(2):
            self._cluster.kill_task("worker", 0)
            time.sleep(2)
            self.assertFalse(context.check_alive(worker_device))
            self._cluster.start_task("worker", 0)
            time.sleep(2)
            self.assertTrue(context.check_alive(worker_device))

        model.join_training_functions()
        self.assertGreaterEqual(model.iterations.numpy(), 10)
Example #2
0
  def _restart(self, downtime_secs, job):
    """Takes task 0 of `job` down for a while, then brings it back up.

    The task is asserted to be dead after the downtime. After it is restarted,
    this helper blocks, polling once per second, until it reports alive again.

    Args:
      downtime_secs: seconds to wait between killing and restarting the task.
      job: name of the job whose task 0 is restarted.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    task_device = "/job:%s/replica:0/task:0" % job
    self.assertFalse(context.check_alive(task_device))
    self._cluster.start_task(job, 0)
    # Poll until the restarted task answers the liveness check.
    while True:
      if context.check_alive(task_device):
        break
      time.sleep(1)
Example #3
0
  def testTwoWorkersPreempted(self):
    """Preempts workers 0 and 1 together; training must still complete."""
    model = Model(self.cluster_coord)
    model.schedule_training_functions(10)

    time.sleep(1)
    self.assertFalse(self.cluster_coord.done())

    preempted_tasks = (0, 1)
    for task_index in preempted_tasks:
      self._cluster.kill_task("worker", task_index)
    time.sleep(2)
    for task_index in preempted_tasks:
      self.assertFalse(
          context.check_alive("/job:worker/replica:0/task:%d" % task_index))

    for task_index in preempted_tasks:
      self._cluster.start_task("worker", task_index)
    time.sleep(2)
    for task_index in preempted_tasks:
      self.assertTrue(
          context.check_alive("/job:worker/replica:0/task:%d" % task_index))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
  def testWorkerContinuousFailure(self):
    """Kills and restarts worker 0 twice in a row; training must complete."""
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())

    worker_device = "/job:worker/replica:0/task:0"
    failure_cycles = 2
    for _ in range(failure_cycles):
      self._cluster.kill_task("worker", 0)
      time.sleep(2)
      self.assertFalse(context.check_alive(worker_device))
      self._cluster.start_task("worker", 0)
      time.sleep(2)
      self.assertTrue(context.check_alive(worker_device))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
  def testTwoWorkersPreempted(self):
    """Preempts workers 0 and 1 together; training must still complete.

    Skipped when the cluster has fewer than two workers.
    """
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())

    preempted = (0, 1)
    for idx in preempted:
      self._cluster.kill_task("worker", idx)
    time.sleep(2)
    for idx in preempted:
      self.assertFalse(
          context.check_alive("/job:worker/replica:0/task:%d" % idx))

    for idx in preempted:
      self._cluster.start_task("worker", idx)
    time.sleep(2)
    for idx in preempted:
      self.assertTrue(
          context.check_alive("/job:worker/replica:0/task:%d" % idx))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
Example #6
0
    def testKillAndStartTask(self):
        """Checks the kill/start state rules of the test cluster on worker 0."""
        worker0 = "/job:worker/replica:0/task:0"
        self.assertTrue(context.check_alive(worker0))

        # Starting a task that has not been killed must be rejected.
        with self.assertRaises(ValueError):
            self._cluster.start_task("worker", 0)

        self._cluster.kill_task("worker", 0)
        self.assertFalse(context.check_alive(worker0))

        # Killing a task a second time must be rejected too.
        with self.assertRaises(ValueError):
            self._cluster.kill_task("worker", 0)

        self._cluster.start_task("worker", 0)

        # Refresh the server def so check_alive sees the restarted task; without
        # it the next check_alive returns False (sleeping 2 seconds also works).
        eager_ctx = context.context()
        eager_ctx.update_server_def(context.get_server_def())

        self.assertTrue(context.check_alive(worker0))
  def testAsyncWaitIsNoOp(self):
    """context.async_wait() must not raise while a remote worker is down.

    Skipped when the cluster has fewer than two workers.
    """
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    dead_worker = "/job:worker/replica:0/task:0"
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive(dead_worker))
    # Must complete without raising even though a remote worker has failed.
    context.async_wait()

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)
Example #8
0
  def testCheckAlive(self):
    """check_alive needs an initialized context and a known task target."""
    # Any query before the context is initialized is rejected.
    with self.assertRaisesRegex(ValueError, "Context is not initialized."):
      context.check_alive("/job:remote_device/task:0")
    context.context().ensure_initialized()

    for task_id in (0, 1):
      self.assertTrue(
          context.check_alive("/job:remote_device/replica:0/task:%d" % task_id))

    # Task 10 has no worker interface, so the lookup fails.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Unable to find worker interface"):
      context.check_alive("/job:remote_device/replica:0/task:10")
    def testFetchFromPSAfterWorkerFailure(self):
        """Repeatedly reads a variable while worker 0 is killed and restarted.

        The test has no explicit assertions. It fails only if a read, a join,
        scheduling, or the background kill/restart itself raises.
        """
        # Test for flaky failures when reading from a parameter server while a
        # worker is recovering.
        # Place some variables on PSes using distribute_datasets_from_function,
        # kill a worker, and continuously poll one of those variables.

        model = Model(self.cluster_coord)

        # An exception raised in a thread target does not fail the test by
        # itself. Collect such exceptions so they can be re-raised on the test
        # thread, e.g. a failed start_task that would leave worker 0 dead.
        kill_thread_errors = []

        # kill the worker after a delay to make sure variable reading runs while
        # worker is up, while it's down, and while it restarts
        def kill_after_delay():
            try:
                time.sleep(3)
                logging.info("Killing worker 0")
                self._cluster.kill_task("worker", 0)
                time.sleep(1)
                logging.info("Restarting worker 0")
                self._cluster.start_task("worker", 0)
            except Exception as e:  # pylint: disable=broad-except
                kill_thread_errors.append(e)

        kill_thread = threading.Thread(target=kill_after_delay)
        kill_thread.start()

        try:
            model.do_infinite_step.assign(True)
            model.schedule_training_functions(1)

            num_reads = 0
            num_reads_after_restart = 0
            read_interval_secs = 0.1
            worker_has_stopped = False
            # limit runtime of the test: stop after doing a few reads after
            # worker is back up, or after a fixed maximum number of reads
            while num_reads_after_restart <= 5 and num_reads < 200:
                worker_up = context.check_alive("/job:worker/replica:0/task:0")
                if not worker_up:
                    worker_has_stopped = True
                if worker_up and worker_has_stopped:
                    num_reads_after_restart += 1

                model.join_training_functions()
                start = time.time()
                while time.time() < start + read_interval_secs:
                    model.iterations.read_value()

                num_reads += 1
                # run another epoch
                model.do_infinite_step.assign(True)
                model.schedule_training_functions(1)
        finally:
            # Always wait for the kill/restart thread. Otherwise it can outlive
            # the test and kill or restart worker 0 during teardown or while
            # another test is running.
            kill_thread.join()

        if kill_thread_errors:
            raise kill_thread_errors[0]
Example #10
0
  def testCheckAlive(self):
    """check_alive needs an initialized context and a known task target.

    Uses assertRaisesRegex. The deprecated assertRaisesRegexp alias was
    removed from unittest in Python 3.12.
    """
    # Any query before the context is initialized is rejected.
    with self.assertRaisesRegex(ValueError, "Context is not initialized."):
      context.check_alive("/job:remote_device/task:0")
    context.context().ensure_initialized()

    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))

    # Task 10 has no client, so the lookup fails.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Client for target /job:remote_device/replica:0/task:10 not found."):
      context.check_alive("/job:remote_device/replica:0/task:10")
Example #11
0
 def testStop(self):
     """After stopping the cluster, no task passes the liveness check."""
     self._cluster.stop()
     stopped_devices = (
         "/job:worker/replica:0/task:0",
         "/job:worker/replica:0/task:1",
         "/job:ps/replica:0/task:0",
         "/job:chief/replica:0/task:0",
     )
     for device in stopped_devices:
         self.assertFalse(context.check_alive(device))
Example #12
0
 def testClusterIsAlive(self):
     """Every worker, ps and chief task passes the liveness check."""
     expected_alive = (
         "/job:worker/replica:0/task:0",
         "/job:worker/replica:0/task:1",
         "/job:ps/replica:0/task:0",
         "/job:chief/replica:0/task:0",
     )
     for device in expected_alive:
         self.assertTrue(context.check_alive(device))