コード例 #1
0
ファイル: config.py プロジェクト: rosinality/tensorfn
def elastic_config(args):
    """Assemble a ``LaunchConfig`` from the parsed launcher arguments.

    Args:
        args: Parsed argument namespace. This function reads ``n_node``,
            ``n_proc``, ``rdzv_conf``, ``rdzv_backend``, ``node_rank``,
            ``rdzv_id``, ``role``, ``max_restarts``, ``monitor_interval``,
            ``start_method``, ``redirects``, ``tee`` and ``log_dir``; the
            helpers it calls may read further attributes.

    Returns:
        The ``LaunchConfig`` built from those values.
    """
    node_lo, node_hi = parse_min_max_nodes(args.n_node)
    procs_per_node = local_world_size(args.n_proc)

    rendezvous_opts = _parse_rendezvous_config(args.rdzv_conf)
    if args.rdzv_backend == "static":
        # The static backend gets this node's rank injected under "rank".
        rendezvous_opts["rank"] = args.node_rank

    endpoint = get_rdzv_endpoint(args, node_hi)

    return LaunchConfig(
        min_nodes=node_lo,
        max_nodes=node_hi,
        nproc_per_node=procs_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rendezvous_opts,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        log_dir=args.log_dir,
    )
コード例 #2
0
ファイル: api_test.py プロジェクト: xsacha/pytorch
 def get_test_launch_config(
     self,
     min_nodes: int,
     max_nodes: int,
     nproc_per_node: int,
     run_id: str = "",
     rdzv_backend: str = "etcd",
     config: Optional[Dict[str, Any]] = None,
     rdzv_endpoint: Optional[str] = None,
 ) -> LaunchConfig:
     """Create a ``LaunchConfig`` with fixed test-friendly settings.

     Args:
         min_nodes: Passed through as ``min_nodes``.
         max_nodes: Passed through as ``max_nodes``.
         nproc_per_node: Passed through as ``nproc_per_node``.
         run_id: Passed through as ``run_id``.
         rdzv_backend: Passed through as ``rdzv_backend``.
         config: Optional rendezvous settings. They are copied into a new
             dict, so the caller's mapping is never handed to the config.
         rdzv_endpoint: Optional endpoint; when empty or ``None`` the
             instance's ``_etcd_endpoint`` is used instead.

     Returns:
         A ``LaunchConfig`` with ``monitor_interval=1``,
         ``start_method="spawn"`` and ``max_restarts=0``.
     """
     # Copy so that the returned config owns its rendezvous dict.
     rendezvous_opts = dict(config) if config else {}

     fallback_endpoint = self._etcd_endpoint
     chosen_endpoint = rdzv_endpoint if rdzv_endpoint else fallback_endpoint

     fixed_settings = dict(
         monitor_interval=1,
         start_method="spawn",
         max_restarts=0,
     )
     return LaunchConfig(
         min_nodes=min_nodes,
         max_nodes=max_nodes,
         nproc_per_node=nproc_per_node,
         run_id=run_id,
         rdzv_endpoint=chosen_endpoint,
         rdzv_backend=rdzv_backend,
         rdzv_configs=rendezvous_opts,
         **fixed_settings,
     )
コード例 #3
0
ファイル: api_test.py プロジェクト: yuhc/ava-pytorch
 def get_test_launch_config(
     self,
     min_nodes: int,
     max_nodes: int,
     nproc_per_node: int,
     run_id: str = "",
 ) -> LaunchConfig:
     """Create a ``LaunchConfig`` for tests using the etcd backend.

     Args:
         min_nodes: Passed through as ``min_nodes``.
         max_nodes: Passed through as ``max_nodes``.
         nproc_per_node: Passed through as ``nproc_per_node``.
         run_id: Passed through as ``run_id``.

     Returns:
         A ``LaunchConfig`` whose endpoint is the instance's
         ``_etcd_endpoint``, with ``rdzv_backend="etcd"``,
         ``start_method="fork"``, ``monitor_interval=1`` and
         ``max_restarts=0``.
     """
     # Settings that are the same for every config produced here.
     fixed_settings = dict(
         rdzv_endpoint=self._etcd_endpoint,
         monitor_interval=1,
         rdzv_backend="etcd",
         start_method="fork",
         max_restarts=0,
     )
     return LaunchConfig(
         min_nodes=min_nodes,
         max_nodes=max_nodes,
         nproc_per_node=nproc_per_node,
         run_id=run_id,
         **fixed_settings,
     )
コード例 #4
0
ファイル: run.py プロジェクト: zacker150/pytorch
def config_from_args(
        args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
    """Turn parsed launcher arguments into a launch config and a command.

    Args:
        args: Parsed argument namespace. This function reads ``nnodes``,
            ``nproc_per_node``, ``max_restarts``, ``rdzv_conf``,
            ``rdzv_backend``, ``node_rank``, ``rdzv_id``, ``role``,
            ``monitor_interval``, ``start_method``, ``redirects``, ``tee``,
            ``log_dir``, ``no_python``, ``run_path``, ``module``,
            ``training_script`` and ``training_script_args``; the helpers it
            calls may read further attributes.

    Returns:
        A ``(config, cmd, cmd_args)`` tuple: the ``LaunchConfig``, the
        entry point to run (a callable when ``run_path`` is set, otherwise
        a string), and the list of arguments to pass to it.

    Raises:
        AssertionError: If the node range is not ``0 < min <= max`` or
            ``max_restarts`` is negative.
        ValueError: If ``no_python`` and ``module`` are both set (and
            ``run_path`` is not).

    Side effects:
        When ``OMP_NUM_THREADS`` is absent from the environment and more
        than one process per node is requested, it is set to ``"1"`` in
        ``os.environ`` and a warning is logged.
    """
    nodes_lo, nodes_hi = parse_min_max_nnodes(args.nnodes)
    assert 0 < nodes_lo <= nodes_hi
    assert args.max_restarts >= 0

    workers_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and workers_per_node > 1:
        omp_num_threads = 1
        notice = (
            f"\n*****************************************\n"
            f"Setting OMP_NUM_THREADS environment variable for each process to be "
            f"{omp_num_threads} in default, to avoid your system being overloaded, "
            f"please further tune the variable for optimal performance in "
            f"your application as needed. \n"
            f"*****************************************"
        )
        log.warning(notice)
        # Written to the environment so that subprocesses inherit it.
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    rendezvous_opts = _parse_rendezvous_config(args.rdzv_conf)
    if args.rdzv_backend == "static":
        # The static backend gets this node's rank injected under "rank".
        rendezvous_opts["rank"] = args.node_rank

    endpoint = get_rdzv_endpoint(args)

    launch_config = LaunchConfig(
        min_nodes=nodes_lo,
        max_nodes=nodes_hi,
        nproc_per_node=workers_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rendezvous_opts,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        log_dir=args.log_dir,
    )

    launch_via_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    use_env = get_use_env(args)

    # Pick the entry point; the three launch modes are mutually exclusive.
    if args.run_path:
        cmd = run_script_path
        cmd_args.append(args.training_script)
    elif launch_via_python:
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
        cmd_args.extend(["-u", "-m"] if args.module else ["-u"])
        cmd_args.append(args.training_script)
    elif args.module:
        raise ValueError("Don't use both the '--no_python' flag"
                         " and the '--module' flag at the same time.")
    else:
        cmd = args.training_script

    if not use_env:
        cmd_args.append(f"--local_rank={macros.local_rank}")
    cmd_args.extend(args.training_script_args)

    return launch_config, cmd, cmd_args