def test_construct_rdzv(self):
    """``get_rendezvous`` must map the "mocked-rdzv" backend to ``MockedRdzv``."""
    rdzv_params = parameters.RendezvousParameters(
        "mocked-rdzv", "localhost:8081", "1234", 1, 4
    )
    handler = parameters.get_rendezvous(rdzv_params)
    # Exact-class check (not isinstance): a subclass would not satisfy it.
    self.assertIs(handler.__class__, MockedRdzv)
def main(args=None):
    """Run the training script on this node under a ``LocalElasticAgent``.

    Parses the command line, builds the worker command (optionally prefixed
    with the current Python interpreter), creates a rendezvous handler from the
    ``--rdzv_*`` options and runs ``args.nproc_per_node`` workers through
    ``wrapper_fn``.

    Args:
        args: argument list forwarded to ``parse_args``. When ``None``,
            argparse presumably falls back to ``sys.argv[1:]`` (verify that
            ``parse_args`` forwards it unchanged).

    Raises:
        AssertionError: if the ``--nnodes`` range is not ``0 < min <= max`` or
            ``--max_restarts`` is negative.
        ValueError: if ``--no_python`` is used without ``--use_env`` or
            together with ``--module``.
    """
    # If ``args`` not passed, defaults to ``sys.argv[1:]``.
    args = parse_args(args)
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    # 0 (never restart) is a legitimate setting; only negatives are invalid.
    assert args.max_restarts >= 0

    omp_num_threads = None
    if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1:
        omp_num_threads = 1
        print(
            f"*****************************************\n"
            f"Setting OMP_NUM_THREADS environment variable for each process to be "
            f"{omp_num_threads} in default, to avoid your system being overloaded, "
            f"please further tune the variable for optimal performance in "
            f"your application as needed. \n"
            f"*****************************************")

    # Build and validate the worker command before touching the rendezvous
    # backend, so invalid flag combinations fail without any rendezvous setup.
    with_python = not args.no_python
    if with_python:
        cmd = [sys.executable, "-u"]
        if args.module:
            cmd.append("-m")
    else:
        cmd = []
        if not args.use_env:
            raise ValueError("When using the '--no_python' flag,"
                             " you must also set the '--use_env' flag.")
        if args.module:
            raise ValueError("Don't use both the '--no_python' flag"
                             " and the '--module' flag at the same time.")
    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

    rdzv_parameters = parameters.RendezvousParameters(
        args.rdzv_backend,
        args.rdzv_endpoint,
        args.rdzv_id,
        min_nodes,
        max_nodes,
        args.rdzv_conf,
    )
    rdzv_handler = parameters.get_rendezvous(rdzv_parameters)

    spec = WorkerSpec(
        role="default",
        local_world_size=args.nproc_per_node,
        fn=wrapper_fn,
        args=(omp_num_threads, args.use_env, cmd),
        rdzv_handler=rdzv_handler,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
    )
    metrics.initialize_metrics()
    elastic_agent = LocalElasticAgent(spec, start_method=args.start_method)
    elastic_agent.run(spec.role)
def test_get_rdzv_url_no_conf(self):
    """Without ``rdzv_conf`` only min/max worker query args appear in the URL."""
    rdzv_params = parameters.RendezvousParameters(
        "etcd", "localhost:8081", "1234", 1, 4
    )
    built = parameters._construct_rendezvous_url(rdzv_params)
    want = "etcd://localhost:8081/1234?min_workers=1&max_workers=4"
    self.assertEqual(want, built)
def test_construct_rdzv_url(self):
    """Comma-separated ``rdzv_conf`` entries follow the worker bounds as ``&key=value``."""
    extra_conf = "timeout=60,protocol=https,key=/etc/kubernetes/certs/client.key"
    rdzv_params = parameters.RendezvousParameters(
        "etcd", "localhost:8081", "1234", 1, 4, extra_conf
    )
    built = parameters._construct_rendezvous_url(rdzv_params)
    want = "".join((
        "etcd://localhost:8081/1234",
        "?min_workers=1",
        "&max_workers=4",
        "&timeout=60",
        "&protocol=https",
        "&key=/etc/kubernetes/certs/client.key",
    ))
    self.assertEqual(want, built)
def main(args=None):
    """Run the training script on this node under a ``LocalElasticAgent``.

    With ``--standalone`` a local ``EtcdServer`` is started and used as the
    rendezvous backend, with a freshly generated ``rdzv_id``. Otherwise the
    ``--rdzv_*`` options are used as given. The agent runs
    ``determine_local_world_size(args.nproc_per_node)`` workers through
    ``wrapper_fn``.

    The rendezvous handler is always shut down. The standalone etcd server, if
    any, is always stopped, including when setup fails part-way through.

    Args:
        args: argument list forwarded to ``parse_args``. When ``None``,
            argparse presumably falls back to ``sys.argv[1:]`` (verify that
            ``parse_args`` forwards it unchanged).

    Raises:
        AssertionError: if the ``--nnodes`` range is not ``0 < min <= max`` or
            ``--max_restarts`` is negative.
        ValueError: if ``--no_python`` and ``--module`` are both set.
    """
    # If ``args`` not passed, defaults to ``sys.argv[1:]``.
    args = parse_args(args)
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    etcd_server = None
    if args.standalone:
        etcd_server = EtcdServer()
        etcd_server.start()
        args.rdzv_backend = "etcd"
        args.rdzv_endpoint = etcd_server.get_endpoint()
        args.rdzv_id = str(uuid.uuid4())
        log.info(
            f"\n**************************************\n"
            f"Rendezvous info:\n"
            f"--rdzv_backend={args.rdzv_backend} "
            f"--rdzv_endpoint={args.rdzv_endpoint} "
            f"--rdzv_id={args.rdzv_id}\n"
            f"**************************************\n"
        )

    # Everything after the (optional) etcd server start is guarded, so the
    # server is stopped even if flag validation or rendezvous setup raises.
    try:
        nproc_per_node = determine_local_world_size(args.nproc_per_node)
        omp_num_threads = None
        if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
            omp_num_threads = 1
            print(
                f"*****************************************\n"
                f"Setting OMP_NUM_THREADS environment variable for each process to be "
                f"{omp_num_threads} in default, to avoid your system being overloaded, "
                f"please further tune the variable for optimal performance in "
                f"your application as needed. \n"
                f"*****************************************"
            )

        with_python = not args.no_python
        cmd = []
        if with_python:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
        elif args.module:
            raise ValueError(
                "Don't use both the '--no_python' flag"
                " and the '--module' flag at the same time."
            )
        cmd.append(args.training_script)
        cmd.extend(args.training_script_args)

        rdzv_parameters = parameters.RendezvousParameters(
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
            min_nodes,
            max_nodes,
            args.rdzv_conf,
        )
        rdzv_handler = parameters.get_rendezvous(rdzv_parameters)
        try:
            spec = WorkerSpec(
                role="default",
                local_world_size=nproc_per_node,
                fn=wrapper_fn,
                args=(omp_num_threads, cmd),
                rdzv_handler=rdzv_handler,
                max_restarts=args.max_restarts,
                monitor_interval=args.monitor_interval,
            )
            metrics.initialize_metrics()
            elastic_agent = LocalElasticAgent(spec, start_method=args.start_method)
            elastic_agent.run(spec.role)
        finally:
            # Shut the handler down before the backing etcd server goes away.
            # The outer ``finally`` still stops the server if this raises.
            rdzv_handler.shutdown()
    finally:
        if etcd_server is not None:
            etcd_server.stop()