def testWorkerContinuousFailure(self):
  """Trains an infinite step while worker 0 is killed and restarted twice."""
  model = Model(self.cluster_coord)
  model.do_infinite_step.assign(True)
  model.schedule_training_functions(10)

  # Model does infinite training step, so at this moment, we expect to have 2
  # infinite closures inflight, and 8 closures in the queue.
  while (self.cluster_coord._cluster._closure_queue.
         _inflight_closure_count < 2):
    time.sleep(0.1)
  self.assertFalse(self.cluster_coord.done())

  # Two back-to-back kill/restart cycles on worker 0, checking liveness after
  # each transition.
  worker_0 = "/job:worker/replica:0/task:0"
  for _ in range(2):
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive(worker_0))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive(worker_0))

  model.join_training_functions()
  self.assertGreaterEqual(model.iterations.numpy(), 10)
def _restart(self, downtime_secs, job):
  """Kills `job` (index: 0) and restarts it after `downtime_secs`.

  Args:
    downtime_secs: secs before restarting the job.
    job: a string specifying the job to restart.
  """
  task_addr = "/job:%s/replica:0/task:0" % job
  self._cluster.kill_task(job, 0)
  time.sleep(downtime_secs)
  self.assertFalse(context.check_alive(task_addr))
  self._cluster.start_task(job, 0)
  # Poll until the restarted task reports alive again.
  while not context.check_alive(task_addr):
    time.sleep(1)
def testTwoWorkersPreempted(self):
  """Training runs to completion despite simultaneous preemption of two workers."""
  model = Model(self.cluster_coord)
  model.schedule_training_functions(10)
  # Give the coordinator a moment to dispatch closures before preempting.
  time.sleep(1)
  self.assertFalse(self.cluster_coord.done())

  targets = ("/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1")

  # Kill both workers, then confirm both are down.
  for task_id in (0, 1):
    self._cluster.kill_task("worker", task_id)
  time.sleep(2)
  for target in targets:
    self.assertFalse(context.check_alive(target))

  # Restart both workers, then confirm both are back up.
  for task_id in (0, 1):
    self._cluster.start_task("worker", task_id)
  time.sleep(2)
  for target in targets:
    self.assertTrue(context.check_alive(target))

  model.join_training_functions()
  self.assertGreaterEqual(model.iterations.numpy(), 10)
def testWorkerContinuousFailure(self):
  """Worker 0 is killed and restarted twice while training runs indefinitely."""
  model = self._create_model_and_run_indefinitely()
  self.assertFalse(self.cluster_coord.done())

  worker_0 = "/job:worker/replica:0/task:0"
  # Two consecutive kill/restart cycles; liveness is verified after each
  # transition.
  for _ in range(2):
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive(worker_0))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive(worker_0))

  model.join_training_functions()
  self.assertGreaterEqual(model.iterations.numpy(), 10)
def testTwoWorkersPreempted(self):
  """Indefinite training survives simultaneous preemption of workers 0 and 1."""
  if self.num_workers < 2:
    self.skipTest("Worker number is less than 2.")
  model = self._create_model_and_run_indefinitely()
  self.assertFalse(self.cluster_coord.done())

  targets = ("/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1")

  # Kill both workers at once, then verify both are dead.
  for task_id in (0, 1):
    self._cluster.kill_task("worker", task_id)
  time.sleep(2)
  for target in targets:
    self.assertFalse(context.check_alive(target))

  # Restart both, then verify both are alive again.
  for task_id in (0, 1):
    self._cluster.start_task("worker", task_id)
  time.sleep(2)
  for target in targets:
    self.assertTrue(context.check_alive(target))

  model.join_training_functions()
  self.assertGreaterEqual(model.iterations.numpy(), 10)
def testKillAndStartTask(self):
  """Checks the kill/start ordering constraints for a single worker task."""
  worker_0 = "/job:worker/replica:0/task:0"
  self.assertTrue(context.check_alive(worker_0))

  # It is not allowed to start a task before killing it.
  with self.assertRaises(ValueError):
    self._cluster.start_task("worker", 0)

  self._cluster.kill_task("worker", 0)
  self.assertFalse(context.check_alive(worker_0))

  # The task is already killed.
  with self.assertRaises(ValueError):
    self._cluster.kill_task("worker", 0)

  self._cluster.start_task("worker", 0)

  # Without a call to update_server_def, the next check_alive will return
  # False. Alternatively sleeping for 2 seconds here also works.
  context.context().update_server_def(context.get_server_def())
  self.assertTrue(context.check_alive(worker_0))
def testAsyncWaitIsNoOp(self):
  """`context.async_wait` must not raise even when a remote worker is down."""
  if self.num_workers < 2:
    self.skipTest("Worker number is less than 2.")

  model = self._create_model_and_run_indefinitely()
  self.assertFalse(self.cluster_coord.done())

  # Take worker 0 down and confirm it is dead.
  self._cluster.kill_task("worker", 0)
  time.sleep(2)
  self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))

  # Should pass without exception even with failed remote workers
  context.async_wait()

  model.join_training_functions()
  self.assertGreaterEqual(model.iterations.numpy(), 10)
def testCheckAlive(self):
  """check_alive: rejects uninitialized context, reports live/unknown tasks."""
  # Before the context is initialized, check_alive must reject the call.
  with self.assertRaisesRegex(ValueError, "Context is not initialized."):
    context.check_alive("/job:remote_device/task:0")

  context.context().ensure_initialized()
  # Both configured tasks should report alive.
  for target in ("/job:remote_device/replica:0/task:0",
                 "/job:remote_device/replica:0/task:1"):
    self.assertTrue(context.check_alive(target))

  # A task index not present in the cluster raises InvalidArgumentError.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Unable to find worker interface"):
    context.check_alive("/job:remote_device/replica:0/task:10")
def testFetchFromPSAfterWorkerFailure(self):
  # Test for flaky failures when reading from a parameter server while a
  # worker is recovering.
  # Place some variables on PSes using distribute_datasets_from_function,
  # kill a worker, and continuously poll one of those variables.
  model = Model(self.cluster_coord)

  # kill the worker after a delay to make sure variable reading runs while
  # worker is up, while it's down, and while it restarts
  def kill_after_delay():
    time.sleep(3)
    logging.info("Killing worker 0")
    self._cluster.kill_task("worker", 0)
    time.sleep(1)
    logging.info("Restarting worker 0")
    self._cluster.start_task("worker", 0)

  # NOTE(review): this background thread is never joined; presumably it has
  # finished by the time the polling loop below completes — confirm.
  kill_thread = threading.Thread(target=kill_after_delay)
  kill_thread.start()

  # Start the first "epoch" of infinite training steps.
  model.do_infinite_step.assign(True)
  model.schedule_training_functions(1)

  num_reads = 0
  num_reads_after_restart = 0
  read_interval_secs = 0.1
  worker_has_stopped = False
  # limit runtime of the test: stop after doing a few reads after worker
  # is back up, or after a fixed maximum number of reads
  while num_reads_after_restart <= 5 and num_reads < 200:
    worker_up = context.check_alive("/job:worker/replica:0/task:0")
    if not worker_up:
      # Remember that the worker went down at least once so that subsequent
      # "alive" observations are counted as post-restart reads.
      worker_has_stopped = True
    if worker_up and worker_has_stopped:
      num_reads_after_restart += 1
    # Finish the current epoch, then poll the PS-hosted variable in a tight
    # loop for `read_interval_secs` before scheduling the next epoch.
    model.join_training_functions()
    start = time.time()
    while time.time() < start + read_interval_secs:
      model.iterations.read_value()
      num_reads += 1
    # run another epoch
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(1)
def testCheckAlive(self):
  """check_alive: rejects uninitialized context, reports live/unknown tasks.

  Verifies that `context.check_alive` raises before the context is
  initialized, returns True for the two configured remote tasks afterwards,
  and raises `InvalidArgumentError` for a task index not in the cluster.
  """
  # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
  # use `assertRaisesRegex` instead.
  with self.assertRaisesRegex(ValueError, "Context is not initialized."):
    context.check_alive("/job:remote_device/task:0")

  context.context().ensure_initialized()
  self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
  self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))

  # A task index not present in the cluster raises InvalidArgumentError.
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "Client for target /job:remote_device/replica:0/task:10 not found."):
    context.check_alive("/job:remote_device/replica:0/task:10")
def testStop(self):
  """After the cluster is stopped, every task must report dead."""
  self._cluster.stop()
  for target in ("/job:worker/replica:0/task:0",
                 "/job:worker/replica:0/task:1",
                 "/job:ps/replica:0/task:0",
                 "/job:chief/replica:0/task:0"):
    self.assertFalse(context.check_alive(target))
def testClusterIsAlive(self):
  """Every task in the freshly started cluster must report alive."""
  for target in ("/job:worker/replica:0/task:0",
                 "/job:worker/replica:0/task:1",
                 "/job:ps/replica:0/task:0",
                 "/job:chief/replica:0/task:0"):
    self.assertTrue(context.check_alive(target))