예제 #1
0
 def get_ps_args(self, args):
     """Build the container arguments that start the `elasticdl_ps` binary.

     Args:
         args: Parsed job arguments. The attributes read here are
             `distribution_strategy`, `job_name`, `namespace`, `use_async`,
             `grads_to_wait`, `lr_staleness_modulation`,
             `sync_version_tolerance`, `evaluation_steps`, `num_ps_pods`,
             `num_workers`, `checkpoint_dir`, `checkpoint_steps`,
             `keep_checkpoint_max` and `checkpoint_dir_for_init`.

     Returns:
         A two-element list `["-c", "<command line>"]` when
         `args.distribution_strategy` equals
         `DistributionStrategy.PARAMETER_SERVER`; an empty list otherwise.
     """
     if args.distribution_strategy != DistributionStrategy.PARAMETER_SERVER:
         return []

     opt_type, opt_args = get_optimizer_info(self._optimizer)

     # (flag prefix, value) pairs, in the order they appear on the command
     # line. Boolean options are rendered as lowercase "true"/"false".
     flag_values = (
         ("-job_name=", args.job_name),
         ("-namespace=", args.namespace),
         ("-master_addr=", self.master_addr),
         ("-port=", "2222"),
         ("-use_async=", str(bool(args.use_async)).lower()),
         ("-grads_to_wait=", str(args.grads_to_wait)),
         (
             "-lr_staleness_modulation=",
             str(bool(args.lr_staleness_modulation)).lower(),
         ),
         ("-sync_version_tolerance=", str(args.sync_version_tolerance)),
         ("-evaluation_steps=", str(args.evaluation_steps)),
         ("-num_ps_pods=", str(args.num_ps_pods)),
         ("-num_workers=", str(args.num_workers)),
         ("-checkpoint_dir=", str(args.checkpoint_dir)),
         ("-checkpoint_steps=", str(args.checkpoint_steps)),
         ("-keep_checkpoint_max=", str(args.keep_checkpoint_max)),
         ("-checkpoint_dir_for_init=", str(args.checkpoint_dir_for_init)),
         ("-opt_type=", opt_type),
         ("-opt_args=", opt_args),
     )
     quoted_flags = wrap_go_args_with_string(
         [prefix + value for prefix, value in flag_values]
     )

     # Source /root/.bashrc_elasticdl first to add the file path of
     # `elasticdl_ps` into the PATH environment variable.
     command_line = " ".join(
         ["source", "/root/.bashrc_elasticdl", "&&", "elasticdl_ps",
          *quoted_flags]
     )
     return ["-c", command_line]
예제 #2
0
 def test_wrap_go_args_with_string(self):
     # Each `-flag=value` argument is expected to come back with its value
     # wrapped in single quotes, including values that contain `=` and `;`.
     raw_flags = [
         "-ps_id=0",
         "-job_name=test_args",
         "-opt_args=learning_rate=0.1;momentum=0.0;nesterov=False",
     ]
     quoted_flags = [
         "-ps_id='0'",
         "-job_name='test_args'",
         "-opt_args='learning_rate=0.1;momentum=0.0;nesterov=False'",
     ]
     self.assertListEqual(wrap_go_args_with_string(raw_flags), quoted_flags)
예제 #3
0
    def _create_instance_manager(self, args):
        instance_manager = None

        container_command = ["/bin/bash"]
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_client_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.worker.main"
            )
            worker_args = [
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            worker_args.extend(
                build_arguments_from_parsed_result(args, filter_args=["envs"])
            )
            worker_args = wrap_python_args_with_string(worker_args)
            worker_args.insert(0, worker_client_command)

            if args.use_go_ps:
                opt_type, opt_args = get_optimizer_info(self.optimizer)
                ps_command = "elasticdl_ps"
                ps_command_args = [
                    "-job_name=" + args.job_name,
                    "-namespace=" + args.namespace,
                    "-master_addr=" + self.master_addr,
                    "-port=2222",
                    "-use_async=" + ("true" if args.use_async else "false"),
                    "-grads_to_wait=" + str(args.grads_to_wait),
                    "-lr_staleness_modulation="
                    + ("true" if args.lr_staleness_modulation else "false"),
                    "-sync_version_tolerance="
                    + str(args.sync_version_tolerance),
                    "-evaluation_steps=" + str(args.evaluation_steps),
                    "-num_ps_pods=" + str(args.num_ps_pods),
                    "-num_workers=" + str(args.num_workers),
                    "-checkpoint_dir=" + str(args.checkpoint_dir),
                    "-checkpoint_steps=" + str(args.checkpoint_steps),
                    "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                    "-checkpoint_dir_for_init="
                    + str(args.checkpoint_dir_for_init),
                    "-opt_type=" + opt_type,
                    "-opt_args=" + opt_args,
                ]
                ps_command_args = wrap_go_args_with_string(ps_command_args)
                # Execute source /root/.bashrc to add the file path
                # of `elasticdl_ps` into the PATH environment variable.
                ps_args = [
                    "source",
                    "/root/.bashrc_elasticdl",
                    "&&",
                    ps_command,
                ]
                ps_args.extend(ps_command_args)
            else:
                ps_command = (
                    BashCommandTemplate.SET_PIPEFAIL
                    + " python -m elasticdl.python.ps.main"
                )
                ps_command_args = [
                    "--grads_to_wait",
                    str(args.grads_to_wait),
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                    "--use_async",
                    str(args.use_async),
                    "--model_zoo",
                    args.model_zoo,
                    "--model_def",
                    args.model_def,
                    "--job_name",
                    args.job_name,
                    "--port",
                    "2222",
                    "--master_addr",
                    self.master_addr,
                    "--namespace",
                    args.namespace,
                    "--evaluation_steps",
                    str(args.evaluation_steps),
                    "--checkpoint_dir",
                    str(args.checkpoint_dir),
                    "--checkpoint_steps",
                    str(args.checkpoint_steps),
                    "--keep_checkpoint_max",
                    str(args.keep_checkpoint_max),
                    "--num_ps_pods",
                    str(args.num_ps_pods),
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                    "--num_workers",
                    str(args.num_workers),
                    "--log_level",
                    str(args.log_level),
                    "--minibatch_size",
                    str(args.minibatch_size),
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ]
                ps_args = wrap_python_args_with_string(ps_command_args)
                ps_args.insert(0, ps_command)

            worker_args = ["-c", " ".join(worker_args)]
            ps_args = ["-c", " ".join(ps_args)]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))

            kwargs = get_dict_from_params_str(args.aux_params)
            disable_relaunch = kwargs.get("disable_relaunch", False)
            cluster_spec = self._get_image_cluster_spec(args.cluster_spec)

            instance_manager = InstanceManager(
                self.task_d,
                rendezvous_server=self.rendezvous_server,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=container_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=container_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=cluster_spec,
                envs=env,
                disable_relaunch=disable_relaunch,
                log_file_path=args.log_file_path,
            )

        return instance_manager