コード例 #1
0
    def test_construct_rdzv(self):
        """``get_rendezvous`` should build a ``MockedRdzv`` for the "mocked-rdzv" backend."""
        params = parameters.RendezvousParameters("mocked-rdzv",
                                                 "localhost:8081", "1234", 1,
                                                 4)

        rdzv = parameters.get_rendezvous(params)
        # assertIs reports both objects on failure, unlike assertTrue(... is ...).
        # An exact class match is intended, so assertIsInstance (which would also
        # accept subclasses) is deliberately not used.
        self.assertIs(rdzv.__class__, MockedRdzv)
コード例 #2
0
def main(args=None):
    """Launch ``nproc_per_node`` local workers under a ``LocalElasticAgent``.

    Parses the command line, builds the worker command, creates the
    rendezvous handler from the ``--rdzv_*`` options via
    ``parameters.get_rendezvous`` and runs the agent for the ``"default"`` role.

    Args:
        args: argument list to parse; when ``None``, ``parse_args`` falls back
            to the process command line (presumably ``sys.argv[1:]``, the
            argparse default).

    Raises:
        ValueError: if ``--no_python`` is given without ``--use_env``, or
            together with ``--module``.
        AssertionError: if the node range or ``max_restarts`` is invalid.
    """
    # If ``args`` not passed, defaults to ``sys.argv[1:]``
    args = parse_args(args)

    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    # 0 is a legitimate value (run once, never restart); consistent with the
    # newer launcher, which also accepts ``max_restarts >= 0``.
    assert args.max_restarts >= 0

    omp_num_threads = None
    if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1:
        # Passed to every worker through ``wrapper_fn``'s args below.
        omp_num_threads = 1
        print(
            f"*****************************************\n"
            f"Setting OMP_NUM_THREADS environment variable for each process to be "
            f"{omp_num_threads} in default, to avoid your system being overloaded, "
            f"please further tune the variable for optimal performance in "
            f"your application as needed. \n"
            f"*****************************************")

    # Build and validate the worker command *before* creating the rendezvous
    # handler, so an invalid flag combination fails fast without any backend work.
    with_python = not args.no_python
    cmd = []
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        if not args.use_env:
            raise ValueError("When using the '--no_python' flag,"
                             " you must also set the '--use_env' flag.")
        if args.module:
            raise ValueError("Don't use both the '--no_python' flag"
                             " and the '--module' flag at the same time.")

    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    rdzv_parameters = parameters.RendezvousParameters(
        args.rdzv_backend,
        args.rdzv_endpoint,
        args.rdzv_id,
        min_nodes,
        max_nodes,
        args.rdzv_conf,
    )

    rdzv_handler = parameters.get_rendezvous(rdzv_parameters)

    spec = WorkerSpec(
        role="default",
        local_world_size=args.nproc_per_node,
        fn=wrapper_fn,
        args=(omp_num_threads, args.use_env, cmd),
        rdzv_handler=rdzv_handler,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
    )
    metrics.initialize_metrics()
    elastic_agent = LocalElasticAgent(spec, start_method=args.start_method)
    elastic_agent.run(spec.role)
コード例 #3
0
    def test_get_rdzv_url_no_conf(self):
        """Without a conf string the URL carries only min/max worker query params."""
        rdzv_params = parameters.RendezvousParameters(
            "etcd", "localhost:8081", "1234", 1, 4
        )

        expected = ("etcd://localhost:8081/1234"
                    "?min_workers=1"
                    "&max_workers=4")
        url = parameters._construct_rendezvous_url(rdzv_params)

        self.assertEqual(expected, url)
コード例 #4
0
    def test_construct_rdzv_url(self):
        """Comma-separated conf pairs become ``&``-joined query params, in order."""
        rdzv_conf = "timeout=60,protocol=https,key=/etc/kubernetes/certs/client.key"
        rdzv_params = parameters.RendezvousParameters(
            "etcd", "localhost:8081", "1234", 1, 4, rdzv_conf
        )

        url = parameters._construct_rendezvous_url(rdzv_params)

        self.assertEqual(
            "etcd://localhost:8081/1234"
            "?min_workers=1"
            "&max_workers=4"
            "&timeout=60"
            "&protocol=https"
            "&key=/etc/kubernetes/certs/client.key",
            url,
        )
コード例 #5
0
ファイル: launch.py プロジェクト: freegliboracle/elastic
def main(args=None):
    """Launch local workers under a ``LocalElasticAgent``.

    With ``--standalone``, a local ``EtcdServer`` is started and used as the
    ``etcd`` rendezvous backend under a freshly generated ``rdzv_id``;
    otherwise the ``--rdzv_*`` options are used as given. The rendezvous
    handler is always shut down, and the standalone etcd server (if any) is
    always stopped, even when launching or running the workers fails.

    Args:
        args: argument list to parse; when ``None``, ``parse_args`` falls back
            to the process command line (presumably ``sys.argv[1:]``, the
            argparse default).

    Raises:
        ValueError: if ``--no_python`` and ``--module`` are both given.
        AssertionError: if the node range or ``max_restarts`` is invalid.
    """
    # If ``args`` not passed, defaults to ``sys.argv[1:]``
    args = parse_args(args)

    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    # Build and validate the worker command before starting anything, so an
    # invalid flag combination cannot leave a standalone etcd server behind.
    with_python = not args.no_python
    cmd = []
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        if args.module:
            raise ValueError(
                "Don't use both the '--no_python' flag"
                " and the '--module' flag at the same time."
            )

    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    etcd_server = None
    if args.standalone:
        etcd_server = EtcdServer()
        etcd_server.start()

    # Everything after the server has started runs under try/finally so the
    # server is stopped on every exit path, not only on success.
    try:
        if etcd_server is not None:
            args.rdzv_backend = "etcd"
            args.rdzv_endpoint = etcd_server.get_endpoint()
            args.rdzv_id = str(uuid.uuid4())
            log.info(
                f"\n**************************************\n"
                f"Rendezvous info:\n"
                f"--rdzv_backend={args.rdzv_backend} "
                f"--rdzv_endpoint={args.rdzv_endpoint} "
                f"--rdzv_id={args.rdzv_id}\n"
                f"**************************************\n"
            )

        nproc_per_node = determine_local_world_size(args.nproc_per_node)
        omp_num_threads = None
        if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
            # Passed to every worker through ``wrapper_fn``'s args below.
            omp_num_threads = 1
            print(
                f"*****************************************\n"
                f"Setting OMP_NUM_THREADS environment variable for each process to be "
                f"{omp_num_threads} in default, to avoid your system being overloaded, "
                f"please further tune the variable for optimal performance in "
                f"your application as needed. \n"
                f"*****************************************"
            )

        rdzv_parameters = parameters.RendezvousParameters(
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
            min_nodes,
            max_nodes,
            args.rdzv_conf,
        )

        rdzv_handler = parameters.get_rendezvous(rdzv_parameters)

        try:
            spec = WorkerSpec(
                role="default",
                local_world_size=nproc_per_node,
                fn=wrapper_fn,
                args=(omp_num_threads, cmd),
                rdzv_handler=rdzv_handler,
                max_restarts=args.max_restarts,
                monitor_interval=args.monitor_interval,
            )
            metrics.initialize_metrics()
            elastic_agent = LocalElasticAgent(spec, start_method=args.start_method)
            elastic_agent.run(spec.role)
        finally:
            # Shut the handler down before its (possibly local) etcd backend stops.
            rdzv_handler.shutdown()
    finally:
        if etcd_server is not None:
            etcd_server.stop()