Example #1
 def test_child_failed_error(self):
     pf0 = self.failure_with_error_file(exception=SentinelError("rank 0"))
     pf1 = self.failure_with_error_file(exception=SentinelError("rank 1"))
     pf2 = self.failure_without_error_file(exitcode=138)
     ex = ChildFailedError("trainer.par", {0: pf0, 1: pf1, 2: pf2})
     self.assertEqual(pf0, ex.get_first_failure()[1])
     # the print below is intentional and should output something like this:
     """
     *********************************************
           trainer.par FAILED
     =============================================
     Root Cause:
     [0]:
       time: 2020-11-25_21:22:31
       rank: 0 (local_rank: 0)
       exitcode: 1 (pid: 997)
       error_file: /tmp/ApiTesttbb37ier/error.json
       msg: "SentinelError: rank 0"
     =============================================
     Other Failures:
     [1]:
       time: 2020-11-25_21:22:31
       rank: 1 (local_rank: 0)
       exitcode: 1 (pid: 997)
       error_file: /tmp/ApiTesttbb37ier/error.json
       msg: "SentinelError: rank 1"
     [2]:
       time: 2020-11-25_21:22:31
       rank: 2 (local_rank: 0)
       exitcode: 138 (pid: 997)
       error_file: <N/A>
       msg: "Process failed with exitcode 138"
     *********************************************
     """
     print(ex)
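
A minimal sketch (hedged) of consuming the error object the way this test does: ``get_first_failure()`` returns a ``(rank, ProcessFailure)`` tuple for the earliest failure by timestamp, and ``str(ex)`` renders the report shown in the docstring above. ``SentinelError`` and the ``failure_with_error_file`` helpers are test fixtures carried over from the example, not library API.

from torch.distributed.elastic.multiprocessing.errors import ChildFailedError

try:
    # pf0/pf1 are ProcessFailure objects built by the test helpers above
    raise ChildFailedError("trainer.par", {0: pf0, 1: pf1})
except ChildFailedError as ex:
    rank, failure = ex.get_first_failure()  # earliest failure wins
    print(f"root cause on rank {rank}: {failure.message}")
    print(ex)  # the full formatted report shown in the docstring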
Example #2
def raise_child_failure_error_fn(name, child_error_file=""):
    if child_error_file:
        _write_error(SentinelError("foobar"), child_error_file)
    pf = ProcessFailure(local_rank=0,
                        pid=997,
                        exitcode=1,
                        error_file=child_error_file)
    raise ChildFailedError(name, {0: pf})
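
The ``_write_error`` helper is not shown in this example; below is a hedged sketch of what such a helper might do, assuming the JSON layout that torchelastic error files use (the exact schema is an assumption, mirrored from what ``ErrorHandler.record_exception`` produces in Example #3):

import json
import time
import traceback

def _write_error(e: Exception, error_file: str) -> None:
    # Assumed schema: a top-level "message" object carrying the formatted
    # exception, a Python callstack, and a timestamp.
    data = {
        "message": {
            "message": f"{type(e).__name__}: {e}",
            "extraInfo": {
                "py_callstack": "".join(
                    traceback.format_exception(type(e), e, e.__traceback__)),
                "timestamp": str(int(time.time())),
            },
        }
    }
    with open(error_file, "w") as f:
        json.dump(data, f)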
Example #3
def raise_child_failure_error_fn(name, child_error_file=""):
    if child_error_file:
        with mock.patch.dict(os.environ,
                             {"TORCHELASTIC_ERROR_FILE": child_error_file}):
            ErrorHandler().record_exception(SentinelError("foobar"))
    pf = ProcessFailure(local_rank=0,
                        pid=997,
                        exitcode=1,
                        error_file=child_error_file)
    raise ChildFailedError(name, {0: pf})
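
Helpers like the one above are typically exercised under the ``@record`` decorator, which intercepts ``ChildFailedError`` and copies the root-cause child error file into the parent's own error file. The wrapper and file paths below are illustrative assumptions, not part of the original test:

import os
from unittest import mock
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record

@record
def wrapped_main(name, child_error_file=""):
    raise_child_failure_error_fn(name, child_error_file)

try:
    # TORCHELASTIC_ERROR_FILE tells @record where to write the parent's copy
    with mock.patch.dict(os.environ,
                         {"TORCHELASTIC_ERROR_FILE": "/tmp/parent_error.json"}):
        wrapped_main("trainer", child_error_file="/tmp/child_error.json")
except ChildFailedError:
    pass  # re-raised after the root cause is recorded in the parent's file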
Example #4
    def run_agent(
        self,
        conf: Conf,
        agent_results: Optional[mp.Queue] = None,  # (role, agent_result)
        min_nodes=1,
        max_nodes=1,
        start_method: str = "spawn",
        max_restarts: int = 0,
        exit_barrier_timeout=5,
        master_addr_override: Optional[str] = None,
        master_port_override: Optional[int] = None,
        is_host=True,
    ) -> Optional[RunResult]:
        """
        Runs a single agent. This method can be called either on a separate
        process or on the main test process. When calling this method on a
        separate process, make sure to pass the ``agent_results``
        multiprocessing Queue so that the agent's run results can be returned.
        If ``agent_results`` is omitted, the run result is returned from the
        method.
        """

        spec = self.get_worker_spec(
            node_config=conf,
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            max_restarts=max_restarts,
            master_addr_override=master_addr_override,
            master_port_override=master_port_override,
            is_host=is_host,
        )
        agent = self.get_agent(
            spec=spec,
            start_method=start_method,
            exit_barrier_timeout=exit_barrier_timeout,
        )

        result = agent.run()
        spec.rdzv_handler.shutdown()

        if agent_results:
            agent_results.put((conf.role, result))

        if result.is_failed():
            raise ChildFailedError(spec.get_entrypoint_name(), result.failures)
        else:
            if not agent_results:
                return result
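
Per the docstring's contract, a hedged sketch of driving ``run_agent`` from a separate process; the pre-built ``conf`` and the surrounding test-case context (``self``) are assumptions:

import multiprocessing as mp

agent_results = mp.Queue()  # receives (role, RunResult) tuples
proc = mp.Process(
    target=self.run_agent,
    kwargs={"conf": conf, "agent_results": agent_results},
)
proc.start()
proc.join()
role, result = agent_results.get()  # run result comes back via the queue
assert not result.is_failed()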
Example #5
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generate a new one: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        cfg = metrics.MetricsConfig(
            config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent(spec=spec,
                                  start_method=config.start_method,
                                  log_dir=config.log_dir)

        result = agent.run()
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        else:
            return result.return_values
    except ChildFailedError:
        raise
    except Exception:
        if agent:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()
Example #6
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(
            f"config has no run_id, generated a random run_id: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
    )

    agent = LocalElasticAgent(spec=spec,
                              start_method=config.start_method,
                              log_dir=config.log_dir)

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()
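
In current PyTorch, ``launch_agent`` is normally reached through the public ``elastic_launch`` wrapper rather than called directly; a minimal sketch, where the trainer function and the rendezvous endpoint are assumptions:

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def trainer(x):
    # runs once per local worker; the return value is collected per rank
    return x * 2

config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",
)
# elastic_launch invokes launch_agent under the hood and returns
# {local_rank: return_value}; worker failures surface as ChildFailedError.
results = elastic_launch(config, trainer)(7)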
Example #7
def main(args=None):
    # If ``args`` is not passed, ``parse_args`` defaults to ``sys.argv[1:]``
    args = parse_args(args)
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    elastic_agent = None

    if args.standalone:
        etcd_server = EtcdServer()
        etcd_server.start()
        args.rdzv_backend = "etcd"
        args.rdzv_endpoint = etcd_server.get_endpoint()
        args.rdzv_id = str(uuid.uuid4())
        log.info(f"\n**************************************\n"
                 f"Rendezvous info:\n"
                 f"--rdzv_backend={args.rdzv_backend} "
                 f"--rdzv_endpoint={args.rdzv_endpoint} "
                 f"--rdzv_id={args.rdzv_id}\n"
                 f"**************************************\n")

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        print(
            f"*****************************************\n"
            f"Setting OMP_NUM_THREADS environment variable for each process to be "
            f"{omp_num_threads} in default, to avoid your system being overloaded, "
            f"please further tune the variable for optimal performance in "
            f"your application as needed. \n"
            f"*****************************************")
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    with_python = not args.no_python
    cmd = []
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        if args.module:
            raise ValueError("Don't use both the '--no_python' flag"
                             " and the '--module' flag at the same time.")

    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    rdzv_parameters = RendezvousParameters(
        backend=args.rdzv_backend,
        endpoint=args.rdzv_endpoint,
        run_id=args.rdzv_id,
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        **_parse_rendezvous_config(args.rdzv_conf),
    )

    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    try:
        spec = WorkerSpec(
            role=args.role,
            local_world_size=nproc_per_node,
            entrypoint=cmd[0],
            args=(*cmd[1:], ),
            rdzv_handler=rdzv_handler,
            max_restarts=args.max_restarts,
            monitor_interval=args.monitor_interval,
            redirects=Std.from_str(args.redirects),
            tee=Std.from_str(args.tee),
        )
        metrics.initialize_metrics()
        elastic_agent = LocalElasticAgent(spec=spec,
                                          start_method=args.start_method,
                                          log_dir=args.log_dir)
        run_result = elastic_agent.run(spec.role)
        events.record(
            elastic_agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if run_result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process
            raise ChildFailedError(
                name=args.training_script,
                failures=run_result.failures,
            )
    except ChildFailedError:
        raise
    except Exception:
        if elastic_agent:
            events.record(
                elastic_agent.get_agent_status_event(WorkerState.FAILED))
        else:
            events.record(_construct_event(args))
        raise
    finally:
        rdzv_handler.shutdown()
        if args.standalone:
            etcd_server.stop()
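
A hedged sketch of exercising this ``main`` directly with an explicit argv list; the flag names follow the parser this snippet assumes, and the script path is hypothetical:

main([
    "--standalone",         # brings up the ad-hoc etcd server shown above
    "--nnodes=1",
    "--nproc_per_node=2",
    "trainer.py",           # args.training_script (hypothetical path)
    "--batch-size=32",      # forwarded via args.training_script_args
])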