# Imports reconstructed for this excerpt; module paths follow TensorFlow's
# internal distribute/client test suite (TF 2.3/2.4 era) and may differ in
# other versions.
import time

from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute.client import utils
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.server_lib import ClusterSpec


def make_client(num_workers, num_ps):
  """Creates a ParameterServerClient backed by an in-process test cluster."""
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  # The in-process cluster has no chief task; add one on an unused local port.
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
  return parameter_server_client.ParameterServerClient(cluster_resolver)
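# Usage sketch (an illustrative assumption, not part of the original file):
# a minimal example of driving the client returned by `make_client`, using
# only APIs that already appear in this excerpt -- creating a variable under
# the client's strategy scope, scheduling a tf.function onto the workers, and
# joining to wait for completion.
def _example_make_client_usage():
  ps_client = make_client(num_workers=2, num_ps=1)

  with ps_client._strategy.scope():
    counter = variables.Variable(initial_value=0, dtype=dtypes.int32)

  @def_function.function
  def incr():
    counter.assign_add(1)

  ps_client.schedule(incr)
  ps_client.join()  # Blocks until all scheduled functions have executed.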
def proc_func(functions_scheduled_event, test_finished_event):
  # Note: `test_join` and `test_schedule` are free variables here; in the
  # original harness this function is nested inside the test method that
  # binds them.
  cluster_resolver = TFConfigClusterResolver()
  if cluster_resolver.task_type != "chief":
    utils.start_server(cluster_resolver, "grpc")
  ps_client = parameter_server_client.ParameterServerClient(cluster_resolver)

  with ps_client._strategy.scope():
    v = variables.Variable(initial_value=0, dtype=dtypes.int32)

  @def_function.function
  def worker_fn():
    # An ever-running function.
    for _ in math_ops.range(100000):
      v.assign_add(1)

  # Keep the two workers occupied.
  ps_client.schedule(worker_fn)
  ps_client.schedule(worker_fn)
  # Now the main process can terminate.
  functions_scheduled_event.set()

  # Verify that join and schedule indeed raise ParameterServerFailureError.
  try:
    if test_join:
      ps_client.join()
    if test_schedule:
      while ps_client.cluster._closure_queue._error is None:
        time.sleep(1)
      ps_client.schedule(worker_fn)
  except client.ParameterServerFailureError:
    # The following verifies that after the PS fails, continuing to execute
    # functions on the workers fails and indicates that it is a PS failure.
    for worker_id in range(3):
      with ops.device("/job:worker/replica:0/task:{}".format(worker_id)):
        try:
          # Executing a function after the PS fails should result in a PS
          # failure.
          worker_fn()
        except Exception as e:  # pylint: disable=broad-except
          if client._is_ps_failure(e):
            if worker_id < 2:
              continue
            logging.info("_test_translate_ps_failure_error ends properly.")
            # Now we can safely exit the test.
            test_finished_event.set()
            return
        raise RuntimeError("Executing a function after PS fails, should "
                           "result in a PS failure.")

  raise RuntimeError("ParameterServerFailureError supposed to be raised.")
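# Driver sketch (an assumption: a reconstruction of the enclosing harness,
# which is not part of this excerpt). It shows how `proc_func` would
# plausibly be launched across a chief/3-worker/1-PS cluster with
# TensorFlow's multi_process_runner test utility, the PS task then killed,
# and the two events used to synchronize with the subprocesses. The exact
# MultiProcessRunner arguments are assumptions, and in the original harness
# `proc_func` is presumably defined inside this function so that
# `test_join`/`test_schedule` resolve by closure.
from tensorflow.python.distribute import multi_process_runner


def _run_translate_ps_failure_test(test_schedule=False, test_join=False):
  manager = multi_process_runner.manager()
  functions_scheduled_event = manager.Event()
  test_finished_event = manager.Event()
  mpr = multi_process_runner.MultiProcessRunner(
      proc_func,
      multi_worker_test_base.create_cluster_spec(
          has_chief=True, num_workers=3, num_ps=1),
      args=(functions_scheduled_event, test_finished_event),
      rpc_layer="grpc")
  mpr.start()
  # Wait until both worker_fn's have been scheduled, then kill the PS task.
  functions_scheduled_event.wait()
  mpr.terminate("ps", 0)
  # proc_func sets this event once it has verified that the PS failure is
  # surfaced as ParameterServerFailureError on every worker.
  test_finished_event.wait()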