def fn(functions_scheduled_event, test_finished_event):
    """Check that a PS failure is reported as a PS failure on every worker.

    Builds a `ParameterServerStrategyV2` cluster coordinator, schedules two
    long-running closures, signals `functions_scheduled_event`, and then
    expects `join()` and/or `schedule()` to raise `errors.UnavailableError`.
    Once that happens, `worker_fn` is run directly on each of the three
    worker devices and every run must raise an error that
    `coordinator_lib._is_ps_failure` recognises.

    NOTE: `test_join` and `test_schedule` are not parameters; they are read
    from an enclosing scope that is not part of this function.

    Args:
      functions_scheduled_event: Object exposing `set()`; set once both
        closures have been scheduled.
      test_finished_event: Object exposing `set()`; set after the last
        worker has reported a PS failure.

    Raises:
      RuntimeError: If no `UnavailableError` is raised by the coordinator,
        or if running `worker_fn` on a worker does not end in a PS failure.
    """
    # TODO(b/170664373): This is needed for TF2 parameter server training in
    # OSS. Remove this when resolved.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    resolver = TFConfigClusterResolver()
    is_chief = resolver.task_type == "chief"
    if not is_chief:
        # Non-chief tasks start a gRPC server for this cluster.
        utils.start_server(resolver, "grpc")
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)

    with strategy.scope():
        counter = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
        # Long-running loop meant to keep a worker busy.
        for _ in math_ops.range(100000):
            counter.assign_add(1)

    # Schedule one closure for each of the two workers to keep them occupied.
    for _ in range(2):
        coordinator.schedule(worker_fn)
    # Signal that scheduling is done; per the original note, the main process
    # can terminate from this point on.
    functions_scheduled_event.set()

    num_workers = 3
    # Verify that join and schedule indeed raise UnavailableError.
    try:
        if test_join:
            coordinator.join()
        if test_schedule:
            # Poll until the closure queue has recorded an error, then try to
            # schedule again.
            while True:
                if coordinator.cluster._closure_queue._error is not None:
                    break
                time.sleep(1)
            coordinator.schedule(worker_fn)
    except errors.UnavailableError:
        # After the PS has failed, running functions on the workers must
        # fail as well, and the error must be identified as a PS failure.
        for worker_id in range(num_workers):
            device_name = "/job:worker/replica:0/task:{}".format(worker_id)
            with ops.device(device_name):
                try:
                    worker_fn()
                except Exception as err:  # pylint: disable=broad-except
                    saw_ps_failure = coordinator_lib._is_ps_failure(err)
                else:
                    saw_ps_failure = False

                if not saw_ps_failure:
                    raise RuntimeError(
                        "Executing a function after PS fails, should "
                        "result in a PS failure.")
                if worker_id < num_workers - 1:
                    # More workers left to check.
                    continue

                logging.info(
                    "_test_translate_ps_failure_error ends properly."
                )
                # Every worker reported a PS failure; safe to finish.
                test_finished_event.set()
                return

    raise RuntimeError("UnavailableError supposed to be raised.")
Exemplo n.º 2
0
  def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
    """Running a function on any worker after all PS are killed must fail.

    Kills every "ps" task, waits until the coordinator's closure queue has
    recorded an error, then executes a trivial `tf.function` on each worker
    device. Each execution must raise an error that
    `cluster_coordinator._is_ps_failure` recognises; otherwise an
    `AssertionError` is raised.
    """
    model = self._create_model_and_run_indefinitely()
    for ps_index in range(self.num_ps):
      self._cluster.kill_task("ps", ps_index)

    # Poll until the closure queue has recorded an error.
    while True:
      if self.cluster_coord._cluster.closure_queue._error is not None:
        break
      time.sleep(1)

    @def_function.function
    def trivial_function():
      return model.iterations + 1

    for worker_index in range(self.num_workers):
      device_name = "/job:worker/replica:0/task:{}".format(worker_index)
      try:
        with ops.device(device_name):
          trivial_function()
      except Exception as err:  # pylint: disable=broad-except
        saw_ps_failure = cluster_coordinator._is_ps_failure(err)
      else:
        saw_ps_failure = False

      if not saw_ps_failure:
        raise AssertionError("Executing a function after PS fails, should "
                             "result in a PS failure.")
      if worker_index >= self.num_workers - 1:
        # The last worker also reported a PS failure; the test passes.
        return