def make_client(num_workers, num_ps):
    """Build a ParameterServerClient on top of an in-process gRPC cluster.

    Args:
      num_workers: Number of worker tasks requested from
        `create_in_process_cluster`.
      num_ps: Number of parameter-server tasks requested from
        `create_in_process_cluster`.

    Returns:
      A `parameter_server_client.ParameterServerClient` built from the
      cluster definition returned by `create_in_process_cluster`, extended
      with a single "chief" entry at an unused localhost port.
    """
    rpc = "grpc"
    spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer=rpc)
    # Only an address is reserved for the chief; no server is started for it
    # here (presumably the calling process plays the chief role — confirm).
    chief_port = multi_worker_test_base.pick_unused_port()
    spec["chief"] = ["localhost:%d" % chief_port]
    resolver = SimpleClusterResolver(ClusterSpec(spec), rpc_layer=rpc)
    return parameter_server_client.ParameterServerClient(resolver)
# Example no. 2
# 0
        def proc_func(functions_scheduled_event, test_finished_event):
            """Run one task of the PS-failure error-translation test.

            Builds a `TFConfigClusterResolver` and, for tasks other than the
            chief, starts a gRPC server. It then creates a
            `ParameterServerClient`, places an int32 variable under the
            client's strategy scope and schedules a long-running function
            twice. After signalling `functions_scheduled_event`, it expects
            `join()` and/or `schedule()` (selected by the enclosing
            `test_join` / `test_schedule` flags) to raise
            `ParameterServerFailureError`. It then checks that running the
            function directly on worker tasks 0-2 reports a PS failure.

            Args:
              functions_scheduled_event: Event set once both functions have
                been scheduled.
              test_finished_event: Event set once the post-failure checks
                succeed.

            Raises:
              RuntimeError: If `ParameterServerFailureError` is not raised, or
                if running the function on a worker after the failure does not
                produce a PS failure. An unexpected non-PS exception is chained
                as the cause.
            """
            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                # NOTE(review): execution continues below after this call;
                # confirm whether `utils.start_server` blocks for non-chief
                # tasks.
                utils.start_server(cluster_resolver, "grpc")
            ps_client = parameter_server_client.ParameterServerClient(
                cluster_resolver)
            with ps_client._strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                # A long-running function (100000 increments) so the workers
                # are still busy when the PS goes away.
                for _ in math_ops.range(100000):
                    v.assign_add(1)

            # Keep the two workers occupied.
            ps_client.schedule(worker_fn)
            ps_client.schedule(worker_fn)
            # Now the main process can terminate (presumably it then takes a
            # PS task down — confirm against the caller).
            functions_scheduled_event.set()

            # Verify that join and schedule indeed raise
            # ParameterServerFailureError.
            try:
                if test_join:
                    ps_client.join()
                if test_schedule:
                    # Poll until the closure queue has recorded an error;
                    # presumably this makes the next schedule() observe it.
                    while ps_client.cluster._closure_queue._error is None:
                        time.sleep(1)
                    ps_client.schedule(worker_fn)
            except client.ParameterServerFailureError:
                # The following verifies that after PS fails, continuing to
                # execute functions on workers fails and is reported as a PS
                # failure.
                for worker_id in range(3):
                    with ops.device(
                            "/job:worker/replica:0/task:{}".format(worker_id)):
                        try:
                            # Executing a function after PS fails should
                            # result in a PS failure.
                            worker_fn()
                        except Exception as e:  # pylint: disable=broad-except
                            if not client._is_ps_failure(e):
                                # Keep the unexpected error as the cause
                                # instead of discarding it.
                                raise RuntimeError(
                                    "Executing a function after PS fails, should "
                                    "result in a PS failure.") from e
                            if worker_id < 2:
                                continue
                            logging.info(
                                "_test_translate_ps_failure_error ends properly."
                            )
                            # Now we can safely exit the test.
                            test_finished_event.set()
                            return
                        # worker_fn() completed without error, which is
                        # unexpected once the PS is gone.
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.")

            raise RuntimeError(
                "ParameterServerFailureError supposed to be raised.")