def testRelaunchPsPod(self):
    num_ps = 3
    instance_manager = InstanceManager(
        task_d=None,
        job_name="test-relaunch-ps-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        ps_command=["sleep 10"],
        ps_args=[],
        namespace="default",
        num_ps=num_ps,
    )
    instance_manager.start_parameter_servers()

    # Check we also have ps services started
    for i in range(num_ps):
        service = instance_manager._k8s_client.get_ps_service(i)
        self.assertTrue(service.metadata.owner_references)
        owner = service.metadata.owner_references[0]
        self.assertEqual(owner.kind, "Pod")
        self.assertEqual(
            owner.name, instance_manager._k8s_client.get_ps_pod_name(i)
        )

    max_check_num = 60
    for _ in range(max_check_num):
        time.sleep(1)
        counters = instance_manager.get_ps_counter()
        if counters["Running"] + counters["Pending"] > 0:
            break

    # Note: There is a slight chance of race condition.
    # Hack to find a ps to remove
    all_current_ps = set()
    all_live_ps = set()
    with instance_manager._lock:
        for k, (_, phase) in instance_manager._ps_pods_phase.items():
            all_current_ps.add(k)
            if phase in ["Running", "Pending"]:
                all_live_ps.add(k)
    self.assertTrue(all_live_ps)

    ps_to_be_removed = all_live_ps.pop()
    all_current_ps.remove(ps_to_be_removed)
    instance_manager._remove_ps(ps_to_be_removed)

    # Verify a new ps gets launched
    found = False
    for _ in range(max_check_num):
        if found:
            break
        time.sleep(1)
        with instance_manager._lock:
            for k in instance_manager._ps_pods_phase:
                if k not in all_current_ps:
                    found = True
    else:
        # The `else` of a for loop runs only when the loop never breaks,
        # i.e. no new ps showed up within max_check_num checks.
        self.fail("Failed to find newly launched ps.")

    instance_manager.stop_relaunch_and_remove_all_ps()
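# The sleep/poll loops in these tests all share one shape. A minimal
# helper that could consolidate them (a hypothetical sketch, not part of
# InstanceManager or this test suite; it assumes only that `condition`
# is a zero-argument callable and reuses the module's `time` import):
def _wait_until(condition, max_checks, interval=1):
    """Poll `condition` up to `max_checks` times; True once it holds."""
    for _ in range(max_checks):
        time.sleep(interval)
        if condition():
            return True
    return False


# e.g. instead of the manual loop above:
# _wait_until(
#     lambda: instance_manager.get_ps_counter()["Running"] > 0, 60
# )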
def testFailedWorkerPod(self):
    """
    Start pods running a command destined to fail, with
    restart_policy="Never", to test the failed worker count.
    """
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    task_d.recover_tasks = MagicMock()
    instance_manager = InstanceManager(
        task_d,
        job_name="test-failed-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        worker_command=["badcommand"],
        worker_args=["badargs"],
        namespace="default",
        num_workers=3,
        restart_policy="Never",
    )
    instance_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if counters["Failed"] == 3:
            break

    instance_manager.stop_relaunch_and_remove_workers()
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if not counters:
            break

    task_d.recover_tasks.assert_has_calls(
        [call(0), call(1), call(2)], any_order=True
    )
def testCreateDeleteWorkerPod(self):
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    task_d.recover_tasks = MagicMock()
    instance_manager = InstanceManager(
        task_d,
        job_name="test-create-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        worker_command=["echo"],
        worker_args=[],
        namespace="default",
        num_workers=3,
    )
    instance_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if counters["Succeeded"] == 3:
            break

    instance_manager.stop_relaunch_and_remove_workers()
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if not counters:
            break

    task_d.recover_tasks.assert_has_calls(
        [call(0), call(1), call(2)], any_order=True
    )
def test_create_delete_worker_pod(self):
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    task_d.recover_tasks = MagicMock()
    instance_manager = InstanceManager(
        task_d,
        job_name="test-create-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="ubuntu:18.04",
        worker_command=["/bin/bash"],
        worker_args=["-c", "echo"],
        namespace="default",
        num_workers=3,
    )
    instance_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if counters["Succeeded"] == 3:
            break

    instance_manager.stop_relaunch_and_remove_workers()
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if not counters:
            break

    self.assertFalse(counters)
def test_relaunch_worker_pod(self):
    num_workers = 3
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    instance_manager = InstanceManager(
        task_d,
        job_name="test-relaunch-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="ubuntu:18.04",
        worker_command=["/bin/bash"],
        worker_args=["-c", "sleep 10 #"],
        namespace="default",
        num_workers=num_workers,
    )
    instance_manager.start_workers()

    max_check_num = 60
    for _ in range(max_check_num):
        time.sleep(1)
        counters = instance_manager.get_worker_counter()
        if counters["Running"] + counters["Pending"] > 0:
            break

    # Note: There is a slight chance of race condition.
    # Hack to find a worker to remove
    current_workers = set()
    live_workers = set()
    with instance_manager._lock:
        for (
            k,
            (_, _, phase),
        ) in instance_manager._worker_pods_ip_phase.items():
            current_workers.add(k)
            if phase in ["Running", "Pending"]:
                live_workers.add(k)
    self.assertTrue(live_workers)

    instance_manager._remove_worker(live_workers.pop())

    # Verify a new worker gets launched. Relaunched workers receive
    # fresh ids, so any key outside the original set is the new pod.
    # (The original check against range(num_workers, num_workers * 2)
    # was always true for the initial ids and could never fail.)
    found = False
    for _ in range(max_check_num):
        if found:
            break
        time.sleep(1)
        with instance_manager._lock:
            for k in instance_manager._worker_pods_ip_phase:
                if k not in current_workers:
                    found = True
    else:
        self.fail("Failed to find newly launched worker.")

    instance_manager.stop_relaunch_and_remove_workers()
def test_get_worker_addrs(self):
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    instance_manager = InstanceManager(
        task_d,
        job_name="test-create-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="ubuntu:18.04",
        worker_command=["/bin/bash"],
        worker_args=["-c", "sleep 5 #"],
        namespace="default",
        num_workers=3,
    )
    instance_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = instance_manager.get_worker_counter()
        if counters["Running"]:
            worker_addrs = instance_manager._get_alive_worker_addr()
            self.assertEqual(len(worker_addrs), counters["Running"])

    instance_manager.stop_relaunch_and_remove_workers()
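# The tests above index the result of get_worker_counter() by Kubernetes
# pod phase ("Pending", "Running", "Succeeded", "Failed") and rely on an
# empty result being falsy. A sketch of the assumed return shape
# (illustrative only; the concrete type is an InstanceManager detail,
# here modeled as a collections.Counter):
from collections import Counter

_example = Counter({"Pending": 1, "Running": 2})
assert _example["Running"] == 2
assert _example["Failed"] == 0  # missing phases count as zero
assert not Counter()  # empty => falsy, hence the `if not counters` checks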
def _create_instance_manager(self, args):
    instance_manager = None
    container_command = ["/bin/bash"]
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_client_command = (
            BashCommandTemplate.SET_PIPEFAIL
            + " python -m elasticdl.python.worker.main"
        )
        worker_args = [
            "--master_addr",
            self.master_addr,
            "--job_type",
            self.job_type,
        ]
        worker_args.extend(
            build_arguments_from_parsed_result(args, filter_args=["envs"])
        )
        worker_args = wrap_python_args_with_string(worker_args)
        worker_args.insert(0, worker_client_command)

        if args.use_go_ps:
            opt_type, opt_args = get_optimizer_info(self.optimizer)
            ps_command = "elasticdl_ps"
            ps_command_args = [
                "-job_name=" + args.job_name,
                "-namespace=" + args.namespace,
                "-master_addr=" + self.master_addr,
                "-port=2222",
                "-use_async=" + ("true" if args.use_async else "false"),
                "-grads_to_wait=" + str(args.grads_to_wait),
                "-lr_staleness_modulation="
                + ("true" if args.lr_staleness_modulation else "false"),
                "-sync_version_tolerance="
                + str(args.sync_version_tolerance),
                "-evaluation_steps=" + str(args.evaluation_steps),
                "-num_ps_pods=" + str(args.num_ps_pods),
                "-num_workers=" + str(args.num_workers),
                "-checkpoint_dir=" + str(args.checkpoint_dir),
                "-checkpoint_steps=" + str(args.checkpoint_steps),
                "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                "-checkpoint_dir_for_init="
                + str(args.checkpoint_dir_for_init),
                "-opt_type=" + opt_type,
                "-opt_args=" + opt_args,
            ]
            ps_command_args = wrap_go_args_with_string(ps_command_args)
            # Execute `source /root/.bashrc_elasticdl` to add the
            # directory of `elasticdl_ps` to the PATH environment
            # variable.
            ps_args = [
                "source",
                "/root/.bashrc_elasticdl",
                "&&",
                ps_command,
            ]
            ps_args.extend(ps_command_args)
        else:
            ps_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.ps.main"
            )
            ps_command_args = [
                "--grads_to_wait",
                str(args.grads_to_wait),
                "--lr_staleness_modulation",
                str(args.lr_staleness_modulation),
                "--sync_version_tolerance",
                str(args.sync_version_tolerance),
                "--use_async",
                str(args.use_async),
                "--model_zoo",
                args.model_zoo,
                "--model_def",
                args.model_def,
                "--job_name",
                args.job_name,
                "--port",
                "2222",
                "--master_addr",
                self.master_addr,
                "--namespace",
                args.namespace,
                "--evaluation_steps",
                str(args.evaluation_steps),
                "--checkpoint_dir",
                str(args.checkpoint_dir),
                "--checkpoint_steps",
                str(args.checkpoint_steps),
                "--keep_checkpoint_max",
                str(args.keep_checkpoint_max),
                "--num_ps_pods",
                str(args.num_ps_pods),
                "--checkpoint_dir_for_init",
                str(args.checkpoint_dir_for_init),
                "--num_workers",
                str(args.num_workers),
                "--log_level",
                str(args.log_level),
                "--minibatch_size",
                str(args.minibatch_size),
                "--num_minibatches_per_task",
                str(args.num_minibatches_per_task),
            ]
            ps_args = wrap_python_args_with_string(ps_command_args)
            ps_args.insert(0, ps_command)

        worker_args = ["-c", " ".join(worker_args)]
        ps_args = ["-c", " ".join(ps_args)]

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        kwargs = get_dict_from_params_str(args.aux_params)
        disable_relaunch = kwargs.get("disable_relaunch", False)
        cluster_spec = self._get_image_cluster_spec(args.cluster_spec)

        instance_manager = InstanceManager(
            self.task_d,
            rendezvous_server=self.rendezvous_server,
            job_name=args.job_name,
            image_name=args.worker_image,
            worker_command=container_command,
            worker_args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            worker_pod_priority=args.worker_pod_priority,
            num_ps=args.num_ps_pods,
            ps_command=container_command,
            ps_args=ps_args,
            ps_resource_request=args.ps_resource_request,
            ps_resource_limit=args.ps_resource_limit,
            ps_pod_priority=args.ps_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=cluster_spec,
            envs=env,
            disable_relaunch=disable_relaunch,
            log_file_path=args.log_file_path,
        )
    return instance_manager
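# Why the wrapping helpers above exist: the worker and ps argument lists
# are flattened with " ".join(...) and handed to `/bin/bash -c`, so each
# value must be quoted to survive shell word-splitting. A minimal sketch
# of that quoting behavior (an assumption about what
# wrap_python_args_with_string does, not its actual source; the helper
# name below is hypothetical):
def _quote_python_args(args):
    """["--port", "2222"] -> ["--port", "'2222'"]; flags pass through."""
    return [a if a.startswith("--") else "'{}'".format(a) for a in args]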
def _create_instance_manager(self, args):
    instance_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--master_addr",
            self.master_addr,
            "--job_type",
            self.job_type,
        ]
        worker_args.extend(build_arguments_from_parsed_result(args))

        ps_command = ["python"]
        ps_args = [
            "-m",
            "elasticdl.python.ps.main",
            "--grads_to_wait",
            str(args.grads_to_wait),
            "--lr_staleness_modulation",
            str(args.lr_staleness_modulation),
            "--use_async",
            str(args.use_async),
            "--minibatch_size",
            str(args.minibatch_size),
            "--model_zoo",
            args.model_zoo,
            "--model_def",
            args.model_def,
            "--job_name",
            args.job_name,
            "--num_minibatches_per_task",
            str(args.num_minibatches_per_task),
            "--port",
            "2222",
            "--master_addr",
            self.master_addr,
            "--namespace",
            args.namespace,
            "--evaluation_steps",
            str(args.evaluation_steps),
            "--checkpoint_dir",
            str(args.checkpoint_dir),
            "--checkpoint_steps",
            str(args.checkpoint_steps),
            "--keep_checkpoint_max",
            str(args.keep_checkpoint_max),
            "--num_ps_pods",
            str(args.num_ps_pods),
            "--checkpoint_dir_for_init",
            str(args.checkpoint_dir_for_init),
            "--num_workers",
            str(args.num_workers),
            "--log_level",
            str(args.log_level),
        ]

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        instance_manager = InstanceManager(
            self.task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            worker_command=worker_command,
            worker_args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            worker_pod_priority=args.worker_pod_priority,
            num_ps=args.num_ps_pods,
            ps_command=ps_command,
            ps_args=ps_args,
            ps_resource_request=args.ps_resource_request,
            ps_resource_limit=args.ps_resource_limit,
            ps_pod_priority=args.ps_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
    return instance_manager
def _create_instance_manager(self, args):
    instance_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--master_addr",
            self.master_addr,
            "--job_type",
            self.job_type,
        ]
        worker_args.extend(build_arguments_from_parsed_result(args))

        if args.use_go_ps:
            opt_type, opt_args = get_optimizer_info(self.optimizer)
            # TODO: rename the Go PS executable using a meaningful
            # filename
            ps_command = ["main"]
            ps_args = [
                "-job_name=" + args.job_name,
                "-namespace=" + args.namespace,
                "-master_addr=" + self.master_addr,
                "-port=2222",
                "-use_async=" + ("true" if args.use_async else "false"),
                "-grads_to_wait=" + str(args.grads_to_wait),
                "-lr_staleness_modulation="
                + ("true" if args.lr_staleness_modulation else "false"),
                "-sync_version_tolerance="
                + str(args.sync_version_tolerance),
                "-evaluation_steps=" + str(args.evaluation_steps),
                "-num_ps_pods=" + str(args.num_ps_pods),
                "-num_workers=" + str(args.num_workers),
                "-checkpoint_dir=" + str(args.checkpoint_dir),
                "-checkpoint_steps=" + str(args.checkpoint_steps),
                "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                "-checkpoint_dir_for_init="
                + str(args.checkpoint_dir_for_init),
                "-opt_type=" + opt_type,
                "-opt_args=" + opt_args,
            ]
        else:
            ps_command = ["python"]
            ps_args = [
                "-m",
                "elasticdl.python.ps.main",
                "--grads_to_wait",
                str(args.grads_to_wait),
                "--lr_staleness_modulation",
                str(args.lr_staleness_modulation),
                "--sync_version_tolerance",
                str(args.sync_version_tolerance),
                "--use_async",
                str(args.use_async),
                "--model_zoo",
                args.model_zoo,
                "--model_def",
                args.model_def,
                "--job_name",
                args.job_name,
                "--port",
                "2222",
                "--master_addr",
                self.master_addr,
                "--namespace",
                args.namespace,
                "--evaluation_steps",
                str(args.evaluation_steps),
                "--checkpoint_dir",
                str(args.checkpoint_dir),
                "--checkpoint_steps",
                str(args.checkpoint_steps),
                "--keep_checkpoint_max",
                str(args.keep_checkpoint_max),
                "--num_ps_pods",
                str(args.num_ps_pods),
                "--checkpoint_dir_for_init",
                str(args.checkpoint_dir_for_init),
                "--num_workers",
                str(args.num_workers),
                "--log_level",
                str(args.log_level),
                "--minibatch_size",
                str(args.minibatch_size),
                "--num_minibatches_per_task",
                str(args.num_minibatches_per_task),
            ]

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        instance_manager = InstanceManager(
            self.task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            worker_command=worker_command,
            worker_args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            worker_pod_priority=args.worker_pod_priority,
            num_ps=args.num_ps_pods,
            ps_command=ps_command,
            ps_args=ps_args,
            ps_resource_request=args.ps_resource_request,
            ps_resource_limit=args.ps_resource_limit,
            ps_pod_priority=args.ps_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
            expose_ports=self.distribution_strategy
            == DistributionStrategy.ALLREDUCE,
        )
    return instance_manager
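# Note the two boolean encodings used in the variants above: the Go PS
# flags expect lowercase "true"/"false", while the Python PS receives
# str(...) values such as "True". A minimal illustration of the
# difference:
use_async = True
go_flag = "-use_async=" + ("true" if use_async else "false")  # "-use_async=true"
py_flag_value = str(use_async)  # "True"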