예제 #1
0
파일: api.py 프로젝트: talrasha/pytorch
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run ``entrypoint`` under a ``LocalElasticAgent`` built from ``config``.

    If ``config.run_id`` is empty, a random run id is generated, logged, and
    written back onto ``config`` (the passed object is mutated). The launch
    configuration is logged. A rendezvous handler is then created from the
    rendezvous settings in ``config``, and a ``WorkerSpec`` plus a
    ``LocalElasticAgent`` are built around it. The agent is run to
    completion.

    Args:
        config: Launch configuration.
        entrypoint: What each worker runs; passed through to ``WorkerSpec``.
        args: Arguments for ``entrypoint``; stored on the spec as a tuple.

    Returns:
        ``result.return_values`` from ``agent.run()`` when the run result
        does not report failure.

    Raises:
        ChildFailedError: When the run result reports failed workers.
        Exception: Any other error raised after the rendezvous handler has
            been created. It is re-raised after a failure event is recorded.

    Once created, the rendezvous handler is always shut down before this
    function exits, whether it returns or raises.
    """
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generate a new one: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    agent = None
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    try:
        # Resolved inside the ``try`` (it used to run before it). If address
        # or port resolution raises, the ``finally`` still shuts down the
        # handler created just above, and a failure event is recorded.
        master_addr, master_port = _get_addr_and_port(rdzv_parameters)

        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler,
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        # Only build a MetricsConfig when metrics settings were supplied;
        # otherwise metrics are initialized with ``None``.
        cfg = metrics.MetricsConfig(
            config.metrics_cfg) if config.metrics_cfg else None
        metrics.initialize_metrics(cfg)

        agent = LocalElasticAgent(spec=spec,
                                  start_method=config.start_method,
                                  log_dir=config.log_dir)

        result = agent.run()
        # Recorded before ``result`` is inspected: this marks that
        # ``agent.run()`` returned, not that every worker succeeded.
        events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )
        return result.return_values
    except ChildFailedError:
        # Propagate unchanged; the generic handler below must not record a
        # second (agent-failed) event for child failures.
        raise
    except Exception:
        if agent is not None:
            events.record(agent.get_agent_status_event(WorkerState.FAILED))
        else:
            # The agent was never constructed, so build the event from the
            # launch config instead.
            events.record(_construct_event(config))
        raise
    finally:
        rdzv_handler.shutdown()
예제 #2
0
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config`` and run ``entrypoint``.

    A missing ``config.run_id`` is replaced with a random one. The new value
    is logged and stored back on ``config``, so the object is mutated. The
    launch configuration is logged before the rendezvous handler, worker
    spec and agent are constructed.

    Args:
        config: Launch configuration.
        entrypoint: What each worker runs; forwarded to ``WorkerSpec``.
        args: Arguments for ``entrypoint``; stored on the spec as a tuple.

    Returns:
        ``return_values`` of the run result when it does not report failure.

    Raises:
        ChildFailedError: When the run result reports failed workers.
        SignalException: Re-raised after recording a failure event. The
            rendezvous handler is deliberately left open in this case.
        Exception: Any other error from the run phase is re-raised after
            recording a failure event.
    """
    if not config.run_id:
        generated_run_id = str(uuid.uuid4().int)
        logger.warning(
            f"config has no run_id, generated a random run_id: {generated_run_id}")
        config.run_id = generated_run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(f"Starting elastic_operator with launch configs:\n"
                f"  entrypoint       : {entrypoint_name}\n"
                f"  min_nodes        : {config.min_nodes}\n"
                f"  max_nodes        : {config.max_nodes}\n"
                f"  nproc_per_node   : {config.nproc_per_node}\n"
                f"  run_id           : {config.run_id}\n"
                f"  rdzv_backend     : {config.rdzv_backend}\n"
                f"  rdzv_endpoint    : {config.rdzv_endpoint}\n"
                f"  rdzv_configs     : {config.rdzv_configs}\n"
                f"  max_restarts     : {config.max_restarts}\n"
                f"  monitor_interval : {config.monitor_interval}\n"
                f"  log_dir          : {config.log_dir}\n"
                f"  metrics_cfg      : {config.metrics_cfg}\n")

    rdzv_params = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    # Address/port resolution happens before the handler is created, so a
    # failure here means no handler exists yet.
    addr, port = _get_addr_and_port(rdzv_params)
    worker_args = tuple(args)
    handler = rdzv_registry.get_rendezvous_handler(rdzv_params)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=worker_args,
        rdzv_handler=handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=addr,
        master_port=port,
    )

    agent = LocalElasticAgent(spec=spec,
                              start_method=config.start_method,
                              log_dir=config.log_dir)

    close_rendezvous = True
    try:
        metrics_config = metrics.MetricsConfig(config.metrics_cfg)
        metrics.initialize_metrics(metrics_config)

        run_result = agent.run()
        # This event means agent.run() returned; it says nothing about
        # whether the individual workers succeeded.
        events.record(agent.get_event_succeeded())

        if not run_result.is_failed():
            return run_result.return_values

        # @record handles ChildFailedError specially: when the failed
        # children's error files exist, it copies the first (root-cause)
        # error into the launcher process's error file.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=run_result.failures,
        )
    except ChildFailedError:
        raise
    except SignalException:
        # Keep the rendezvous open when the agent is killed by a signal:
        # shutting the handler down would close the rendezvous for this
        # run id permanently and block further scaling events.
        close_rendezvous = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if close_rendezvous:
            spec.rdzv_handler.shutdown()