Example #1
    def start_tensorboard_service(self):
        tb_client = TensorBoardClient(self._k8s_client)
        tb_client.create_tensorboard_service()
        logger.info("Waiting for the URL for TensorBoard service...")
        tb_url = tb_client.get_tensorboard_url()
        if tb_url:
            logger.info("TensorBoard service is available at: %s" % tb_url)
        else:
            logger.warning("Unable to get the URL for TensorBoard service")
    def test_create_tensorboard_service(self):
        tb_client = TensorBoardClient(
            image_name=None,
            namespace="default",
            job_name="test-job-%d-%d" %
            (int(time.time()), random.randint(1, 101)),
            event_callback=None,
        )
        tb_client._create_tensorboard_service(
            port=80, service_type="LoadBalancer"
        )
        time.sleep(1)
        service = tb_client._get_tensorboard_service()
        self.assertTrue("load_balancer" in service["status"])
        self.assertEqual(service["spec"]["ports"][0]["port"], 80)
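
Taken together, the snippets above suggest the usual call pattern for TensorBoardClient: construct it with the job metadata, create the Kubernetes service, then resolve the external URL. Below is a minimal sketch of that pattern; the constructor arguments and method names are copied from the examples on this page (which may span different ElasticDL versions), and the import path plus the job/image names are assumptions for illustration only.

# Sketch only: import path, job name, and image name are assumed, not a
# verified API; the method calls mirror the examples above.
from elasticdl.python.master.k8s_tensorboard_client import TensorBoardClient  # assumed path

tb_client = TensorBoardClient(
    job_name="my-training-job",     # hypothetical job name
    image_name="elasticdl:worker",  # hypothetical worker image
    namespace="default",
)
tb_client.start_tensorboard_service()     # create the TensorBoard k8s service
tb_url = tb_client.get_tensorboard_url()  # as in the first snippet above
if tb_url:
    print("TensorBoard service is available at: %s" % tb_url)
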
Example #3
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )

        self.task_d.add_deferred_callback_create_train_end_task()
        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        self._should_stop = False
        self._exit_code = 0
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()
Example #4
class Master(object):
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )

        self.task_d.add_deferred_callback_create_train_end_task()
        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        self._should_stop = False
        self._exit_code = 0
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()

    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        if not checkpoint_dir_for_init:
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            raise ValueError(
                "Invalid checkpoint directory {}".format(
                    checkpoint_dir_for_init
                )
            )

        model_version = CheckpointSaver.get_version_from_checkpoint(
            checkpoint_dir_for_init
        )
        for callback in self.callbacks_list.callbacks:
            if isinstance(callback, MaxStepsStopping):
                callback.set_completed_steps(model_version)

    def request_stop(self, err_msg=None):
        """Request master to quit"""
        self._should_stop = True
        if err_msg:
            self.logger.error(err_msg)
            # TODO (chengfu.wcy) create meaningful status codes
            self._exit_code = -1

    def prepare(self):
        """
        Start the components one by one and make sure each is ready to run.
        """
        # Start the evaluation service if requested
        if self.evaluation_service:
            self.logger.info("Starting evaluation service")
            self.evaluation_service.start()
            self.logger.info("Evaluation service started")

        # Start the master GRPC server
        self.logger.info("Starting master RPC server")
        self.server.start()
        self.logger.info("Master RPC server started")

        # Start the worker manager if requested
        if self.instance_manager:
            self.instance_manager.update_status(InstanceManagerStatus.PENDING)
            if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
                # Start rendezvous server for workers to initialize Horovod
                self.rendezvous_server.start()
            else:
                self.instance_manager.start_parameter_servers()
            self.instance_manager.start_workers()
            self.instance_manager.update_status(InstanceManagerStatus.RUNNING)

        # Start TensorBoard k8s Service if requested
        if self.tb_service and self.tb_client:
            self.logger.info("Starting tensorboard service")
            self.tb_service.start()
            self.tb_client.start_tensorboard_service()
            self.logger.info("Tensorboard service started")

    def run(self):
        """
        The main loop of master.
        Dispatch the tasks to the workers until all the tasks are completed.
        """
        try:
            while True:
                if self.instance_manager.all_workers_failed:
                    raise Exception(
                        "All workers failed with unrecoverable errors"
                    )
                if self.task_d.finished():
                    if self.instance_manager:
                        self.instance_manager.update_status(
                            InstanceManagerStatus.FINISHED
                        )
                    break
                if self._should_stop:
                    break
                time.sleep(30)
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")
        finally:
            self._stop()
        return self._exit_code

    def _stop(self):
        """
        Stop all the components.
        Make sure that the created services and components are shut down.
        """
        self.logger.info("Stopping master")

        if self.evaluation_service:
            self.logger.info("Stopping evaluation service")
            self.evaluation_service.stop()
            self.logger.info("Evaluation service stopped")

        self.logger.info("Stopping RPC server")
        self.server.stop(None)  # grace = None
        self.logger.info("RPC server stopped")

        # Keep TensorBoard running when all the tasks are finished
        if self.tb_service:
            self.logger.info(
                "All tasks finished. Keeping TensorBoard service running..."
            )
            while True:
                if self.tb_service.is_active():
                    time.sleep(10)
                else:
                    self.logger.warning(
                        "Unable to keep TensorBoard running. "
                        "It has already terminated"
                    )
                    break
        self.logger.info("Master stopped")

    @staticmethod
    def _get_job_type(args):
        if all(
            (
                args.training_data,
                args.validation_data,
                args.evaluation_throttle_secs or args.evaluation_steps,
            )
        ):
            job_type = JobType.TRAINING_WITH_EVALUATION
        elif all(
            (
                args.validation_data,
                not args.training_data,
                not args.prediction_data,
            )
        ):
            job_type = JobType.EVALUATION_ONLY
        elif all(
            (
                args.prediction_data,
                not args.validation_data,
                not args.training_data,
            )
        ):
            job_type = JobType.PREDICTION_ONLY
        else:
            job_type = JobType.TRAINING_ONLY

        return job_type

    def _create_tensorboard_service(self, tensorboard_log_dir, master_ip):
        tb_service = None
        if tensorboard_log_dir:
            self.logger.info(
                "Creating TensorBoard service with log directory %s",
                tensorboard_log_dir,
            )
            # Start TensorBoard CLI
            tb_service = TensorboardService(tensorboard_log_dir, master_ip)

        return tb_service

    def _create_evaluation_service(self, args):
        evaluation_service = None
        if (
            self.job_type == JobType.TRAINING_WITH_EVALUATION
            or self.job_type == JobType.EVALUATION_ONLY
        ):
            self.logger.info(
                "Creating evaluation service with throttle seconds %d "
                " and evaluation steps %d",
                args.evaluation_throttle_secs,
                args.evaluation_steps,
            )
            evaluation_service = EvaluationService(
                self.tb_service,
                self.task_d,
                args.evaluation_start_delay_secs,
                args.evaluation_throttle_secs,
                args.evaluation_steps,
                self.job_type == JobType.EVALUATION_ONLY,
                self.model_module[args.eval_metrics_fn],
            )
            self.task_d.set_evaluation_service(evaluation_service)

        return evaluation_service

    def _create_master_service(self, args):
        self.logger.info("Creating master service")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=64),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        master_servicer = MasterServicer(
            args.minibatch_size,
            evaluation_service=self.evaluation_service,
            master=self,
        )
        elasticdl_pb2_grpc.add_MasterServicer_to_server(
            master_servicer, server
        )
        server.add_insecure_port("[::]:{}".format(args.port))
        self.logger.info("The port of the master server is: %d", args.port)

        return master_servicer, server

    def _create_instance_manager(self, args):
        instance_manager = None

        container_command = ["/bin/bash"]
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_client_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.worker.main"
            )
            worker_args = [
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            worker_args.extend(
                build_arguments_from_parsed_result(args, filter_args=["envs"])
            )
            worker_args = wrap_python_args_with_string(worker_args)
            worker_args.insert(0, worker_client_command)

            if args.use_go_ps:
                opt_type, opt_args = get_optimizer_info(self.optimizer)
                ps_command = "elasticdl_ps"
                ps_command_args = [
                    "-job_name=" + args.job_name,
                    "-namespace=" + args.namespace,
                    "-master_addr=" + self.master_addr,
                    "-port=2222",
                    "-use_async=" + ("true" if args.use_async else "false"),
                    "-grads_to_wait=" + str(args.grads_to_wait),
                    "-lr_staleness_modulation="
                    + ("true" if args.lr_staleness_modulation else "false"),
                    "-sync_version_tolerance="
                    + str(args.sync_version_tolerance),
                    "-evaluation_steps=" + str(args.evaluation_steps),
                    "-num_ps_pods=" + str(args.num_ps_pods),
                    "-num_workers=" + str(args.num_workers),
                    "-checkpoint_dir=" + str(args.checkpoint_dir),
                    "-checkpoint_steps=" + str(args.checkpoint_steps),
                    "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                    "-checkpoint_dir_for_init="
                    + str(args.checkpoint_dir_for_init),
                    "-opt_type=" + opt_type,
                    "-opt_args=" + opt_args,
                ]
                ps_command_args = wrap_go_args_with_string(ps_command_args)
                # Execute `source /root/.bashrc_elasticdl` to add the file
                # path of `elasticdl_ps` to the PATH environment variable.
                ps_args = [
                    "source",
                    "/root/.bashrc_elasticdl",
                    "&&",
                    ps_command,
                ]
                ps_args.extend(ps_command_args)
            else:
                ps_command = (
                    BashCommandTemplate.SET_PIPEFAIL
                    + " python -m elasticdl.python.ps.main"
                )
                ps_command_args = [
                    "--grads_to_wait",
                    str(args.grads_to_wait),
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                    "--use_async",
                    str(args.use_async),
                    "--model_zoo",
                    args.model_zoo,
                    "--model_def",
                    args.model_def,
                    "--job_name",
                    args.job_name,
                    "--port",
                    "2222",
                    "--master_addr",
                    self.master_addr,
                    "--namespace",
                    args.namespace,
                    "--evaluation_steps",
                    str(args.evaluation_steps),
                    "--checkpoint_dir",
                    str(args.checkpoint_dir),
                    "--checkpoint_steps",
                    str(args.checkpoint_steps),
                    "--keep_checkpoint_max",
                    str(args.keep_checkpoint_max),
                    "--num_ps_pods",
                    str(args.num_ps_pods),
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                    "--num_workers",
                    str(args.num_workers),
                    "--log_level",
                    str(args.log_level),
                    "--minibatch_size",
                    str(args.minibatch_size),
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ]
                ps_args = wrap_python_args_with_string(ps_command_args)
                ps_args.insert(0, ps_command)

            worker_args = ["-c", " ".join(worker_args)]
            ps_args = ["-c", " ".join(ps_args)]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))

            kwargs = get_dict_from_params_str(args.aux_params)
            disable_relaunch = kwargs.get("disable_relaunch", False)
            cluster_spec = self._get_image_cluster_spec(args.cluster_spec)

            instance_manager = InstanceManager(
                self.task_d,
                rendezvous_server=self.rendezvous_server,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=container_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=container_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=cluster_spec,
                envs=env,
                disable_relaunch=disable_relaunch,
                log_file_path=args.log_file_path,
            )

        return instance_manager

    def _get_image_cluster_spec(self, cluster_spec):
        if cluster_spec:
            filename = os.path.basename(cluster_spec)
            image_cluster_spec = os.path.join(
                ClusterSpecConfig.CLUSTER_SPEC_DIR, filename
            )
            return image_cluster_spec
        return cluster_spec

    def _check_timeout_tasks(self):
        while True:
            doing_tasks = self.task_d._doing.copy()
            cur_time = time.time()
            avg_time = self.master_servicer.get_average_task_complete_time()
            for task_id, (worker_id, task, start_time) in doing_tasks.items():
                if task.type == elasticdl_pb2.TRAINING:
                    start_time = self.master_servicer.get_worker_liveness_time(
                        worker_id
                    )
                if task.type in [
                    elasticdl_pb2.TRAINING,
                    elasticdl_pb2.EVALUATION,
                ]:
                    if (cur_time - start_time) > 3 * avg_time[task.type]:
                        self.logger.info(
                            "worker %d timeout, relaunch it" % worker_id
                        )
                        self.task_d.recover_tasks(worker_id)
                        # TODO: save worker logs before removing it
                        self.instance_manager._remove_worker(worker_id)
                        break
            time.sleep(30)
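
Example #4 shows the full lifecycle of the Master object: __init__ wires up the services, prepare() starts them, and run() blocks until the tasks finish and returns an exit code. The driver below is a minimal sketch of that lifecycle, not the actual ElasticDL entry point; the import paths are assumed, and parse_args() is taken from Example #5 below.

import sys

# Sketch only: module paths are assumed rather than verified.
from elasticdl.python.master.args import parse_args  # assumed location
from elasticdl.python.master.master import Master    # assumed location


def main():
    args = parse_args()
    master = Master(args)
    # Start evaluation service, RPC server, worker/ps pods, and TensorBoard.
    master.prepare()
    # Block until all tasks finish (or a stop is requested) and propagate
    # the exit code set by request_stop().
    sys.exit(master.run())


if __name__ == "__main__":
    main()
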
Example #5
def main():
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start TensorBoard service if requested
    if args.tensorboard_log_dir:
        logger.info(
            "Starting TensorBoard service with log directory %s",
            args.tensorboard_log_dir,
        )
        # Start TensorBoard CLI
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)).__dict__
    model_inst = load_model_from_module(args.model_def, model_module,
                                        args.model_params)
    optimizer = model_module[args.optimizer]()

    if all((
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
    )):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all((
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
    )):
        job_type = JobType.EVALUATION_ONLY
    elif all((
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
    )):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (job_type == JobType.TRAINING_WITH_EVALUATION
            or job_type == JobType.EVALUATION_ONLY):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    embedding_dims = {}
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info("Embedding service start succeeded. The endpoint is %s." %
                    str(embedding_service_endpoint))
        embedding_dims = dict([(layer.name, layer.output_dim)
                               for layer in layers])

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        init_var=model_inst.trainable_variables if model_inst.built else [],
        embedding_dims=embedding_dims,
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
        lr_staleness_modulation=args.lr_staleness_modulation,
        use_async=args.use_async,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
            "--get_model_steps",
            str(args.get_model_steps),
        ]

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # Start TensorBoard k8s Service if requested
    if tb_service:
        TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        ).start_tensorboard_service()

    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running...")
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning("Unable to keep TensorBoard running. "
                               "It has already terminated")
                break
    logger.info("Master stopped")
Example #6
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip)
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.model_inst = load_model_from_module(args.model_def,
                                                 self.model_module,
                                                 args.model_params)
        model_handler = ModelHandler.get_model_handler(
            args.distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        self.model_inst = model_handler.get_model_to_train(self.model_inst)
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader]

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
        )

        saved_model_path = args.output
        if saved_model_path is not None and self.job_type in [
                JobType.TRAINING_ONLY,
                JobType.TRAINING_WITH_EVALUATION,
        ]:
            self.task_d.add_deferred_callback_create_save_model_task(
                saved_model_path)

        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        self._should_stop = False
        self._exit_code = 0
Example #7
class Master(object):
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip)
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.model_inst = load_model_from_module(args.model_def,
                                                 self.model_module,
                                                 args.model_params)
        model_handler = ModelHandler.get_model_handler(
            args.distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        self.model_inst = model_handler.get_model_to_train(self.model_inst)
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader]

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
        )

        saved_model_path = args.output
        if saved_model_path is not None and self.job_type in [
                JobType.TRAINING_ONLY,
                JobType.TRAINING_WITH_EVALUATION,
        ]:
            self.task_d.add_deferred_callback_create_save_model_task(
                saved_model_path)

        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        self._should_stop = False
        self._exit_code = 0

    def request_stop(self, err_msg=None):
        """Request master to quit"""
        self._should_stop = True
        if err_msg:
            self.logger.error(err_msg)
            # TODO (chengfu.wcy) create meaningful status codes
            self._exit_code = -1

    def prepare(self):
        """
        Start the components one by one and make sure each is ready to run.
        """
        # Start the evaluation service if requested
        if self.evaluation_service:
            self.logger.info("Starting evaluation service")
            self.evaluation_service.start()
            self.logger.info("Evaluation service started")

        # Start the master GRPC server
        self.logger.info("Starting master RPC server")
        self.server.start()
        self.logger.info("Master RPC server started")

        # Start the worker manager if requested
        if self.instance_manager:
            self.instance_manager.update_status(InstanceManagerStatus.PENDING)
            self.instance_manager.start_parameter_servers()
            self.instance_manager.start_workers()
            self.instance_manager.update_status(InstanceManagerStatus.RUNNING)

        # Start TensorBoard k8s Service if requested
        if self.tb_service and self.tb_client:
            self.logger.info("Starting tensorboard service")
            self.tb_service.start()
            self.tb_client.start_tensorboard_service()
            self.logger.info("Tensorboard service started")

    def run(self):
        """
        The main loop of master.
        Dispatch the tasks to the workers until all the tasks are completed.
        """
        try:
            while True:
                if self.task_d.finished():
                    if self.instance_manager:
                        self.instance_manager.update_status(
                            InstanceManagerStatus.FINISHED)
                    break
                if self._should_stop:
                    break
                time.sleep(30)
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")
        finally:
            self._stop()
        return self._exit_code

    def _stop(self):
        """
        Stop all the components.
        Make sure that the created services and components are shut down.
        """
        self.logger.info("Stopping master")

        if self.evaluation_service:
            self.logger.info("Stopping evaluation service")
            self.evaluation_service.stop()
            self.logger.info("Evaluation service stopped")

        self.logger.info("Stopping RPC server")
        self.server.stop(None)  # grace = None
        self.logger.info("RPC server stopped")

        # Keep TensorBoard running when all the tasks are finished
        if self.tb_service:
            self.logger.info(
                "All tasks finished. Keeping TensorBoard service running...")
            while True:
                if self.tb_service.is_active():
                    time.sleep(10)
                else:
                    self.logger.warning("Unable to keep TensorBoard running. "
                                        "It has already terminated")
                    break
        self.logger.info("Master stopped")

    @staticmethod
    def _get_job_type(args):
        if all((
                args.training_data,
                args.validation_data,
                args.evaluation_throttle_secs or args.evaluation_steps,
        )):
            job_type = JobType.TRAINING_WITH_EVALUATION
        elif all((
                args.validation_data,
                not args.training_data,
                not args.prediction_data,
        )):
            job_type = JobType.EVALUATION_ONLY
        elif all((
                args.prediction_data,
                not args.validation_data,
                not args.training_data,
        )):
            job_type = JobType.PREDICTION_ONLY
        else:
            job_type = JobType.TRAINING_ONLY

        return job_type

    def _create_tensorboard_service(self, tensorboard_log_dir, master_ip):
        tb_service = None
        if tensorboard_log_dir:
            self.logger.info(
                "Creating TensorBoard service with log directory %s",
                tensorboard_log_dir,
            )
            # Start TensorBoard CLI
            tb_service = TensorboardService(tensorboard_log_dir, master_ip)

        return tb_service

    def _create_evaluation_service(self, args):
        evaluation_service = None
        if (self.job_type == JobType.TRAINING_WITH_EVALUATION
                or self.job_type == JobType.EVALUATION_ONLY):
            self.logger.info(
                "Creating evaluation service with throttle seconds %d "
                " and evaluation steps %d",
                args.evaluation_throttle_secs,
                args.evaluation_steps,
            )
            evaluation_service = EvaluationService(
                self.tb_service,
                self.task_d,
                args.evaluation_start_delay_secs,
                args.evaluation_throttle_secs,
                args.evaluation_steps,
                self.job_type == JobType.EVALUATION_ONLY,
                self.model_module[args.eval_metrics_fn],
            )
            self.task_d.set_evaluation_service(evaluation_service)

        return evaluation_service

    def _create_master_service(self, args):
        self.logger.info("Creating master service")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=64),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        master_servicer = MasterServicer(
            args.minibatch_size,
            self.task_d,
            evaluation_service=self.evaluation_service,
        )
        elasticdl_pb2_grpc.add_MasterServicer_to_server(
            master_servicer, server)
        server.add_insecure_port("[::]:{}".format(args.port))
        self.logger.info("The port of the master server is: %d", args.port)

        return master_servicer, server

    def _create_instance_manager(self, args):
        instance_manager = None
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_command = ["python"]
            worker_args = [
                "-m",
                "elasticdl.python.worker.main",
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            worker_args.extend(build_arguments_from_parsed_result(args))

            ps_command = ["python"]
            ps_args = [
                "-m",
                "elasticdl.python.ps.main",
                "--grads_to_wait",
                str(args.grads_to_wait),
                "--lr_staleness_modulation",
                str(args.lr_staleness_modulation),
                "--use_async",
                str(args.use_async),
                "--minibatch_size",
                str(args.minibatch_size),
                "--model_zoo",
                args.model_zoo,
                "--model_def",
                args.model_def,
                "--job_name",
                args.job_name,
                "--num_minibatches_per_task",
                str(args.num_minibatches_per_task),
                "--port",
                "2222",
                "--master_addr",
                self.master_addr,
                "--namespace",
                args.namespace,
                "--evaluation_steps",
                str(args.evaluation_steps),
                "--checkpoint_dir",
                str(args.checkpoint_dir),
                "--checkpoint_steps",
                str(args.checkpoint_steps),
                "--keep_checkpoint_max",
                str(args.keep_checkpoint_max),
                "--num_ps_pods",
                str(args.num_ps_pods),
                "--checkpoint_dir_for_init",
                str(args.checkpoint_dir_for_init),
                "--num_workers",
                str(args.num_workers),
                "--log_level",
                str(args.log_level),
            ]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))

            instance_manager = InstanceManager(
                self.task_d,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=worker_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=ps_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=args.cluster_spec,
                envs=env,
            )

        return instance_manager