def fn(functions_scheduled_event, test_finished_event):
    """Checks that PS failures surface as UnavailableError / PS-failure errors.

    Non-chief tasks start a gRPC server via `utils.start_server` (presumably
    blocking while serving -- TODO confirm). The coordinating task builds a
    `ParameterServerStrategyV2` and `ClusterCoordinator` and schedules two
    long-running functions. It then sets `functions_scheduled_event` so the
    caller can proceed (presumably to take the PS down -- TODO confirm).

    Next it expects `join()` (when `test_join`) and/or `schedule()` (when
    `test_schedule`) to raise `errors.UnavailableError`. `test_join` and
    `test_schedule` are read from the enclosing scope. Finally it runs
    `worker_fn` directly on every worker and requires each attempt to fail
    with an error that `coordinator_lib._is_ps_failure` recognizes.

    Args:
      functions_scheduled_event: event set once both functions are scheduled.
      test_finished_event: event set only when every check above passed.

    Raises:
      RuntimeError: if `UnavailableError` was never raised, or if executing a
        function on a worker after the PS failed did not yield a PS failure.
        In the latter case, an unexpected underlying error is chained as the
        cause.
    """
    # TODO(b/170664373): This is needed for TF2 parameter server training in
    # OSS. Remove this when resolved.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    cluster_resolver = TFConfigClusterResolver()
    if cluster_resolver.task_type != "chief":
        utils.start_server(cluster_resolver, "grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)
    with strategy.scope():
        v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
        # An ever-running function.
        for _ in math_ops.range(100000):
            v.assign_add(1)

    # Keep the two workers occupied.
    ps_coordinator.schedule(worker_fn)
    ps_coordinator.schedule(worker_fn)
    # Now the main process can terminate.
    functions_scheduled_event.set()

    # Verify that join and schedule indeed raise UnavailableError.
    try:
        if test_join:
            ps_coordinator.join()
        if test_schedule:
            # Wait until the coordinator's closure queue records an error
            # before scheduling again.
            while ps_coordinator.cluster._closure_queue._error is None:
                time.sleep(1)
            ps_coordinator.schedule(worker_fn)
    except errors.UnavailableError:
        # The following verifies that after PS fails, continuing to execute
        # functions on workers should fail and indicate it's a PS failure.
        # Derive the worker count from the cluster itself, so the loop bound
        # and the "last worker" check cannot drift from the real cluster.
        num_workers = cluster_resolver.cluster_spec().num_tasks("worker")
        for worker_id in range(num_workers):
            with ops.device(
                    "/job:worker/replica:0/task:{}".format(worker_id)):
                try:
                    # Executing a function after PS fails should result in a
                    # PS failure.
                    worker_fn()
                except Exception as e:  # pylint: disable=broad-except
                    if not coordinator_lib._is_ps_failure(e):
                        # Keep the unexpected error as the cause for
                        # debugging.
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.") from e
                    if worker_id < num_workers - 1:
                        continue
                    logging.info(
                        "_test_translate_ps_failure_error ends properly.")
                    # Now we can safely exit the test.
                    test_finished_event.set()
                    return
            # Reached only when worker_fn() unexpectedly succeeded.
            raise RuntimeError(
                "Executing a function after PS fails, should "
                "result in a PS failure.")

    raise RuntimeError("UnavailableError supposed to be raised.")
def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
    """Running a function on any worker after all PS tasks die is a PS failure.

    Kills every PS task, then waits until the coordinator's closure queue
    records an error. It then executes a trivial `tf.function` that reads
    model state on each worker in turn. Every attempt must raise an error that
    `cluster_coordinator._is_ps_failure` recognizes.

    Raises:
      AssertionError: if a worker executes the function successfully, or
        raises an error that is not a PS failure. In the latter case, the
        original error is chained as the cause.
    """
    model = self._create_model_and_run_indefinitely()
    for ps_index in range(self.num_ps):
        self._cluster.kill_task("ps", ps_index)

    # Poll until the coordinator's closure queue has recorded an error.
    while self.cluster_coord._cluster.closure_queue._error is None:
        time.sleep(1)

    @def_function.function
    def trivial_function():
        return model.iterations + 1

    for worker_index in range(self.num_workers):
        try:
            with ops.device(
                    "/job:worker/replica:0/task:{}".format(worker_index)):
                trivial_function()
        except Exception as e:  # pylint: disable=broad-except
            if not cluster_coordinator._is_ps_failure(e):
                # Surface the unexpected error as the cause instead of
                # discarding it.
                raise AssertionError(
                    "Executing a function after PS fails, should "
                    "result in a PS failure.") from e
            if worker_index < self.num_workers - 1:
                continue
            return
        # Reached only when trivial_function() unexpectedly succeeded.
        raise AssertionError("Executing a function after PS fails, should "
                             "result in a PS failure.")