Ejemplo n.º 1
0
def _submit_job(image_name, client_args, container_args):
    """Submit the ElasticDL master pod, or dump its YAML instead.

    The container arguments are wrapped, prefixed with the master launch
    command, optionally suffixed with a log redirection, and collapsed into
    a single ``-c <command>`` argument pair.

    Args:
        image_name: Image name handed to the Kubernetes client.
        client_args: Parsed client arguments (namespace, job name, master
            resources, ``yaml`` path, ``log_file_path`` and so on).
        container_args: Arguments for the master process.
    """
    client = k8s.Client(
        image_name=image_name,
        namespace=client_args.namespace,
        job_name=client_args.job_name,
        event_callback=None,
        cluster_spec=client_args.cluster_spec,
        force_use_kube_config_file=client_args.force_use_kube_config_file,
    )

    # Assemble the shell command line: launcher first, then the wrapped
    # arguments, then (optionally) the log redirection.
    command_parts = wrap_python_args_with_string(container_args)
    launcher = (
        BashCommandTemplate.SET_PIPEFAIL
        + " python -m elasticdl.python.master.main"
    )
    command_parts.insert(0, launcher)
    log_path = client_args.log_file_path
    if log_path:
        redirection = BashCommandTemplate.REDIRECTION.format(log_path)
        command_parts.append(redirection)
    shell_args = ["-c", " ".join(command_parts)]

    # Both the dump path and the create path take the same master options.
    master_options = {
        "resource_requests": client_args.master_resource_request,
        "resource_limits": client_args.master_resource_limit,
        "args": shell_args,
        "pod_priority": client_args.master_pod_priority,
        "image_pull_policy": client_args.image_pull_policy,
        "restart_policy": client_args.restart_policy,
        "volume": client_args.volume,
        "envs": parse_envs(client_args.envs),
    }

    yaml_path = client_args.yaml
    if not yaml_path:
        client.create_master(**master_options)
        submitted = (client_args.job_name, client.get_master_pod_name())
        logger.info("ElasticDL job %s was successfully submitted. "
                    "The master pod is: %s." % submitted)
        return

    client.dump_master_yaml(yaml=yaml_path, **master_options)
    dumped = (client_args.job_name, yaml_path)
    logger.info(
        "ElasticDL job %s YAML has been dumped into file %s." % dumped)
Ejemplo n.º 2
0
def _submit_job(image_name, client_args, container_args):
    """Create the ElasticDL master pod, or write its YAML to a file.

    When ``client_args.yaml`` is truthy the master specification is dumped
    to that path; otherwise the master is created through the client.

    Args:
        image_name: Image name handed to the Kubernetes client.
        client_args: Parsed client arguments (namespace, job name, master
            resources, ``yaml`` path and so on).
        container_args: Arguments passed through unchanged as the master
            container ``args``.
    """
    client = k8s.Client(
        image_name=image_name,
        namespace=client_args.namespace,
        job_name=client_args.job_name,
        event_callback=None,
        cluster_spec=client_args.cluster_spec,
    )

    # Options common to both the "dump" and the "create" paths.
    shared_options = dict(
        resource_requests=client_args.master_resource_request,
        resource_limits=client_args.master_resource_limit,
        args=container_args,
        pod_priority=client_args.master_pod_priority,
        image_pull_policy=client_args.image_pull_policy,
        restart_policy=client_args.restart_policy,
        volume=client_args.volume,
        envs=parse_envs(client_args.envs),
    )

    dump_target = client_args.yaml
    if dump_target:
        client.dump_master_yaml(yaml=dump_target, **shared_options)
        message = "ElasticDL job %s YAML has been dumped into file %s." % (
            client_args.job_name,
            dump_target,
        )
    else:
        client.create_master(**shared_options)
        message = (
            "ElasticDL job %s was successfully submitted. "
            "The master pod is: %s."
            % (client_args.job_name, client.get_master_pod_name())
        )
    logger.info(message)
Ejemplo n.º 3
0
    def __init__(self, args):
        """Set up the model spec, optimizer and data reader from ``args``.

        Args:
            args: Parsed arguments providing the model definition fields,
                data locations, minibatch settings and ``envs``.
        """
        self._init_environment(parse_envs(args.envs))

        # Resolve every user-provided component of the model in one call,
        # then unpack the result onto the instance.
        model_spec = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor="",
            custom_data_reader=args.custom_data_reader,
            callbacks=args.callbacks,
        )
        (
            self.model_inst,
            self.dataset_fn,
            self.loss_fn,
            self.opt_fn,
            self.eval_metrics_fn,
            self.prediction_outputs_processor,
            self.custom_data_reader,
            self.callback_list,
        ) = model_spec

        self.opt = self.opt_fn()
        self.epoch = args.num_epochs
        self.evaluation_steps = args.evaluation_steps
        self.batch_size = args.minibatch_size
        reader_params = get_dict_from_params_str(args.data_reader_params)
        self.data_reader_params = reader_params
        per_task = args.minibatch_size * args.num_minibatches_per_task
        self.records_per_task = per_task

        # Fall back to the default reader factory when the model spec did
        # not supply a custom one.
        reader_factory = self.custom_data_reader
        if reader_factory is None:
            reader_factory = create_data_reader
        self.data_reader = reader_factory(
            data_origin=args.training_data,
            records_per_task=self.records_per_task,
            **self.data_reader_params
        )

        self.training_data = args.training_data
        self.validation_data = args.validation_data
        self.save_model_dir = args.output
Ejemplo n.º 4
0
def main():
    """Run the master process from start-up to shutdown.

    Steps, in order: parse arguments, optionally start TensorBoard, build
    the task dispatcher, load the model and optimizer, derive the job type,
    optionally start the checkpoint, evaluation and embedding services,
    start the gRPC master server, optionally launch workers, then poll
    until the task dispatcher reports that it is finished and shut the
    services down.
    """
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    # Falls back to "localhost" when the MY_POD_IP environment variable
    # is not set.
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start TensorBoard service if requested
    if args.tensorboard_log_dir:
        logger.info(
            "Starting TensorBoard service with log directory %s",
            args.tensorboard_log_dir,
        )
        # Start TensorBoard CLI
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    # Load the user's model module; the optimizer is looked up by name in
    # the module namespace and called with no arguments.
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)).__dict__
    model_inst = load_model_from_module(args.model_def, model_module,
                                        args.model_params)
    optimizer = model_module[args.optimizer]()

    # Derive the job type from which data directories were provided.
    # Training plus evaluation data, together with either evaluation
    # trigger (throttle seconds or steps), means training with evaluation.
    if all((
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
    )):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all((
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
    )):
        job_type = JobType.EVALUATION_ONLY
    elif all((
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
    )):
        job_type = JobType.PREDICTION_ONLY
    else:
        # Every other combination is treated as training only.
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    # It is also created for training-with-evaluation jobs even when
    # checkpoint_steps is not set.
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (job_type == JobType.TRAINING_WITH_EVALUATION
            or job_type == JobType.EVALUATION_ONLY):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    embedding_dims = {}
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        # NOTE(review): the embedding service is given the *master*
        # resource request/limit but the *worker* image and pod priority;
        # confirm this mix is intended.
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info("Embedding service start succeeded. The endpoint is %s." %
                    str(embedding_service_endpoint))
        # Map each embedding layer's name to its output dimension.
        embedding_dims = dict([(layer.name, layer.output_dim)
                               for layer in layers])

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    # The model's trainable variables are passed as init_var only when the
    # model is already built; otherwise an empty list is passed.
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        init_var=model_inst.trainable_variables if model_inst.built else [],
        embedding_dims=embedding_dims,
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
        lr_staleness_modulation=args.lr_staleness_modulation,
        use_async=args.use_async,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    # Workers are only launched when num_workers is truthy.
    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        # NOTE: when no embedding layer was found, the endpoint below is
        # str(None), i.e. the literal string "None".
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
            "--get_model_steps",
            str(args.get_model_steps),
        ]

        # Convert the parsed environment mapping into V1EnvVar objects.
        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # Start TensorBoard k8s Service if requested
    if tb_service:
        TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        ).start_tensorboard_service()

    # Poll the task dispatcher every 30 seconds until it reports finished.
    # A KeyboardInterrupt ends the wait early and falls through to the
    # shutdown steps below.
    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    # The loop below blocks, checking every 10 seconds, for as long as the
    # TensorBoard service reports itself active.
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running...")
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning("Unable to keep TensorBoard running. "
                               "It has already terminated")
                break
    logger.info("Master stopped")
Ejemplo n.º 5
0
    def _create_instance_manager(self, args):
        """Build the manager that launches worker and PS instances.

        Args:
            args: Parsed master arguments.

        Returns:
            An ``InstanceManager`` when ``args.num_workers`` is truthy,
            otherwise ``None``.
        """
        if not args.num_workers:
            return None
        assert args.worker_image, "Worker image cannot be empty"

        # Worker command line: module entry point, master address, job type,
        # followed by everything rebuilt from the parsed arguments.
        worker_command = ["python"]
        worker_args = ["-m", "elasticdl.python.worker.main"]
        worker_args += ["--master_addr", self.master_addr]
        worker_args += ["--job_type", self.job_type]
        worker_args += build_arguments_from_parsed_result(args)

        # Parameter-server command line, expressed as (flag, value) pairs
        # and flattened in declaration order.
        ps_command = ["python"]
        ps_options = [
            ("--grads_to_wait", str(args.grads_to_wait)),
            ("--lr_staleness_modulation", str(args.lr_staleness_modulation)),
            ("--use_async", str(args.use_async)),
            ("--minibatch_size", str(args.minibatch_size)),
            ("--model_zoo", args.model_zoo),
            ("--model_def", args.model_def),
            ("--job_name", args.job_name),
            ("--num_minibatches_per_task", str(args.num_minibatches_per_task)),
            ("--port", "2222"),
            ("--master_addr", self.master_addr),
            ("--namespace", args.namespace),
            ("--evaluation_steps", str(args.evaluation_steps)),
            ("--checkpoint_dir", str(args.checkpoint_dir)),
            ("--checkpoint_steps", str(args.checkpoint_steps)),
            ("--keep_checkpoint_max", str(args.keep_checkpoint_max)),
            ("--num_ps_pods", str(args.num_ps_pods)),
            ("--checkpoint_dir_for_init", str(args.checkpoint_dir_for_init)),
            ("--num_workers", str(args.num_workers)),
            ("--log_level", str(args.log_level)),
        ]
        ps_args = ["-m", "elasticdl.python.ps.main"]
        for flag, value in ps_options:
            ps_args.extend((flag, value))

        # Environment variables are forwarded as V1EnvVar objects.
        parsed_envs = parse_envs(args.envs)
        env_vars = [
            V1EnvVar(name=name, value=parsed_envs[name])
            for name in parsed_envs
        ]

        return InstanceManager(
            self.task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            worker_command=worker_command,
            worker_args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            worker_pod_priority=args.worker_pod_priority,
            num_ps=args.num_ps_pods,
            ps_command=ps_command,
            ps_args=ps_args,
            ps_resource_request=args.ps_resource_request,
            ps_resource_limit=args.ps_resource_limit,
            ps_pod_priority=args.ps_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env_vars,
        )
Ejemplo n.º 6
0
    def _create_instance_manager(self, args):
        """Build the manager that launches worker and PS instances via bash.

        Worker and PS command lines are each collapsed into a single
        ``-c <command>`` argument pair run by ``/bin/bash``. The PS command
        is either the Go executable or the Python module, depending on
        ``args.use_go_ps``.

        Args:
            args: Parsed master arguments.

        Returns:
            An ``InstanceManager`` when ``args.num_workers`` is truthy,
            otherwise ``None``.
        """
        # The same list object is handed over as both the worker and the
        # PS container command.
        shell = ["/bin/bash"]
        if not args.num_workers:
            return None
        assert args.worker_image, "Worker image cannot be empty"

        # ---- worker command line ----
        worker_launcher = (
            BashCommandTemplate.SET_PIPEFAIL
            + " python -m elasticdl.python.worker.main"
        )
        worker_flags = ["--master_addr", self.master_addr]
        worker_flags += ["--job_type", self.job_type]
        worker_flags.extend(
            build_arguments_from_parsed_result(args, filter_args=["envs"])
        )
        worker_parts = wrap_python_args_with_string(worker_flags)
        worker_parts.insert(0, worker_launcher)

        # ---- parameter-server command line ----
        if args.use_go_ps:

            def _go_bool(flag):
                # Booleans are spelled as lowercase "true"/"false" here.
                if flag:
                    return "true"
                return "false"

            opt_type, opt_args = get_optimizer_info(self.optimizer)
            # TODO: rename the Go PS executable using a meaningful filename
            ps_launcher = "main"
            go_options = [
                ("-job_name=", args.job_name),
                ("-namespace=", args.namespace),
                ("-master_addr=", self.master_addr),
                ("-port=", "2222"),
                ("-use_async=", _go_bool(args.use_async)),
                ("-grads_to_wait=", str(args.grads_to_wait)),
                (
                    "-lr_staleness_modulation=",
                    _go_bool(args.lr_staleness_modulation),
                ),
                (
                    "-sync_version_tolerance=",
                    str(args.sync_version_tolerance),
                ),
                ("-evaluation_steps=", str(args.evaluation_steps)),
                ("-num_ps_pods=", str(args.num_ps_pods)),
                ("-num_workers=", str(args.num_workers)),
                ("-checkpoint_dir=", str(args.checkpoint_dir)),
                ("-checkpoint_steps=", str(args.checkpoint_steps)),
                ("-keep_checkpoint_max=", str(args.keep_checkpoint_max)),
                (
                    "-checkpoint_dir_for_init=",
                    str(args.checkpoint_dir_for_init),
                ),
                ("-opt_type=", opt_type),
                ("-opt_args=", opt_args),
            ]
            go_flags = [prefix + value for prefix, value in go_options]
            ps_parts = wrap_go_args_with_string(go_flags)
        else:
            ps_launcher = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.ps.main"
            )
            py_options = [
                ("--grads_to_wait", str(args.grads_to_wait)),
                (
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                ),
                (
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                ),
                ("--use_async", str(args.use_async)),
                ("--model_zoo", args.model_zoo),
                ("--model_def", args.model_def),
                ("--job_name", args.job_name),
                ("--port", "2222"),
                ("--master_addr", self.master_addr),
                ("--namespace", args.namespace),
                ("--evaluation_steps", str(args.evaluation_steps)),
                ("--checkpoint_dir", str(args.checkpoint_dir)),
                ("--checkpoint_steps", str(args.checkpoint_steps)),
                ("--keep_checkpoint_max", str(args.keep_checkpoint_max)),
                ("--num_ps_pods", str(args.num_ps_pods)),
                (
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                ),
                ("--num_workers", str(args.num_workers)),
                ("--log_level", str(args.log_level)),
                ("--minibatch_size", str(args.minibatch_size)),
                (
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ),
            ]
            py_flags = []
            for flag, value in py_options:
                py_flags.extend((flag, value))
            ps_parts = wrap_python_args_with_string(py_flags)
        ps_parts.insert(0, ps_launcher)

        # Collapse each command line into a single "-c <command>" pair.
        worker_args = ["-c", " ".join(worker_parts)]
        ps_args = ["-c", " ".join(ps_parts)]

        # Environment variables are forwarded as V1EnvVar objects.
        parsed_envs = parse_envs(args.envs)
        env_vars = [
            V1EnvVar(name=name, value=parsed_envs[name])
            for name in parsed_envs
        ]

        aux_options = get_dict_from_params_str(args.aux_params)
        disable_relaunch = aux_options.get("disable_relaunch", False)
        uses_allreduce = (
            self.distribution_strategy == DistributionStrategy.ALLREDUCE
        )

        return InstanceManager(
            self.task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            worker_command=shell,
            worker_args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            worker_pod_priority=args.worker_pod_priority,
            num_ps=args.num_ps_pods,
            ps_command=shell,
            ps_args=ps_args,
            ps_resource_request=args.ps_resource_request,
            ps_resource_limit=args.ps_resource_limit,
            ps_pod_priority=args.ps_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env_vars,
            expose_ports=uses_allreduce,
            disable_relaunch=disable_relaunch,
            log_file_path=args.log_file_path,
        )