def start_tensorboard_service(self):
    """Create the TensorBoard k8s service and report its URL.

    Creates the service through a ``TensorBoardClient`` bound to this
    object's k8s client, then waits for the service URL. Logs the URL
    when it becomes available, or a warning when it cannot be obtained.
    """
    tb_client = TensorBoardClient(self._k8s_client)
    tb_client.create_tensorboard_service()
    logger.info("Waiting for the URL for TensorBoard service...")
    tb_url = tb_client.get_tensorboard_url()
    if tb_url:
        # Fix: pass the URL as a lazy logging argument instead of eagerly
        # formatting with the % operator, consistent with the rest of the
        # file's logging style.
        logger.info("TensorBoard service is available at: %s", tb_url)
    else:
        logger.warning("Unable to get the URL for TensorBoard service")
def test_create_tensorboard_service(self):
    """Verify the created TensorBoard service is a LoadBalancer on port 80."""
    # Unique per-run job name so repeated test runs don't collide in k8s.
    unique_job_name = "test-job-%d-%d" % (
        int(time.time()),
        random.randint(1, 101),
    )
    client = TensorBoardClient(
        image_name=None,
        namespace="default",
        job_name=unique_job_name,
        event_callback=None,
    )
    client._create_tensorboard_service(port=80, service_type="LoadBalancer")
    # Give the cluster a moment to materialize the service object.
    time.sleep(1)
    created = client._get_tensorboard_service()
    self.assertTrue("load_balancer" in created["status"])
    self.assertEqual(created["spec"]["ports"][0]["port"], 80)
def __init__(self, args):
    """Wire up every master-side component from the parsed CLI arguments.

    Builds, strictly in order: logging, master address, the optional
    Horovod rendezvous server (AllReduce only), the optional TensorBoard
    service/client, the model / optimizer / data-reader loaded from the
    model zoo, callbacks, the task dispatcher, the evaluation service,
    the instance manager, and the gRPC master service. Finally starts a
    daemon thread that watches for timed-out tasks.

    Args:
        args: parsed master command-line arguments namespace.
    """
    self.logger = get_logger("master", level=args.log_level.upper())

    self.num_ps_pods = args.num_ps_pods
    self.checkpoint_output_path = args.checkpoint_dir
    self.distribution_strategy = args.distribution_strategy

    # Master addr: pod IP inside k8s, localhost fallback for local runs.
    master_ip = os.getenv("MY_POD_IP", "localhost")
    self.master_addr = "%s:%d" % (master_ip, args.port)
    self.job_type = Master._get_job_type(args)
    self.rendezvous_server = None
    if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
        # AllReduce jobs need a rendezvous point for workers to set up
        # Horovod communication.
        self.rendezvous_server = HorovodRendezvousServer(master_ip)

    # Initialize TensorBoard service if requested
    self.tb_service = self._create_tensorboard_service(
        args.tensorboard_log_dir, master_ip
    )
    if self.tb_service:
        # NOTE(review): tb_client only exists when a TensorBoard log dir
        # was given; accesses elsewhere must be guarded by tb_service.
        self.tb_client = TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        )

    # Initialize the components from the model definition
    self.model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    self.model_inst = load_model_from_module(
        args.model_def, self.model_module, args.model_params
    )
    self.optimizer = self.model_module[args.optimizer]()
    # Default data reader unless the model zoo module provides its own.
    self._create_data_reader_fn = create_data_reader
    if args.custom_data_reader in self.model_module:
        self._create_data_reader_fn = self.model_module[
            args.custom_data_reader
        ]

    # Initialize the callbacks
    self.callbacks_list = load_callbacks_from_module(
        args.callbacks, self.model_module
    )
    self.callbacks_list.set_model(self.model_inst)
    set_callback_parameters(
        self.callbacks_list,
        batch_size=args.minibatch_size,
        saved_model_path=args.output,
        checkpoint_path=args.checkpoint_dir,
    )
    # Resume completed-step counters when restarting from a checkpoint.
    self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

    # Start task queue
    records_per_task = args.minibatch_size * args.num_minibatches_per_task
    self.task_d = _make_task_dispatcher(
        args.training_data,
        args.validation_data,
        args.prediction_data,
        records_per_task,
        args.num_epochs,
        args.data_reader_params,
        self._create_data_reader_fn,
        self.callbacks_list,
    )
    self.task_d.add_deferred_callback_create_train_end_task()

    self.evaluation_service = self._create_evaluation_service(args)

    # Initialize instance manager
    self.instance_manager = self._create_instance_manager(args)

    # Initialize master service
    self.master_servicer, self.server = self._create_master_service(args)
    self._should_stop = False
    self._exit_code = 0
    # Background watchdog: relaunches workers whose tasks time out.
    threading.Thread(
        target=self._check_timeout_tasks,
        name="check_timeout_tasks",
        daemon=True,
    ).start()
class Master(object):
    """Coordinator process of an ElasticDL job.

    Owns the task dispatcher, the gRPC master service, the instance
    manager that launches worker/PS pods, and the optional TensorBoard,
    evaluation and Horovod rendezvous services. ``prepare()`` starts the
    components, ``run()`` blocks until the job finishes, ``_stop()``
    shuts everything down.
    """

    def __init__(self, args):
        """Build all master components from parsed CLI args, in order.

        Args:
            args: parsed master command-line arguments namespace.
        """
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr: pod IP inside k8s, localhost fallback for local runs.
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            # AllReduce jobs need a rendezvous point for Horovod workers.
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            # NOTE(review): tb_client only exists when a TensorBoard log
            # dir was given; accesses are guarded by tb_service elsewhere.
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        # Default data reader unless the model zoo module provides its own.
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        # Resume completed-step counters when restarting from a checkpoint.
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )
        self.task_d.add_deferred_callback_create_train_end_task()

        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)
        self._should_stop = False
        self._exit_code = 0
        # Background watchdog: relaunches workers whose tasks time out.
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()

    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        """Seed MaxStepsStopping callbacks with the checkpoint's version.

        No-op when no init checkpoint dir is given.

        Raises:
            ValueError: if the checkpoint directory is invalid.
        """
        if not checkpoint_dir_for_init:
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            raise ValueError(
                "Invalid checkpoint directory {}".format(
                    checkpoint_dir_for_init
                )
            )

        # NOTE(review): "model_verion" is a pre-existing misspelling of
        # "model_version"; kept as-is (local name only).
        model_verion = CheckpointSaver.get_version_from_checkpoint(
            checkpoint_dir_for_init
        )
        for callback in self.callbacks_list.callbacks:
            if isinstance(callback, MaxStepsStopping):
                callback.set_completed_steps(model_verion)

    def request_stop(self, err_msg=None):
        """Request master to quit"""
        self._should_stop = True
        if err_msg:
            self.logger.error(err_msg)
            # TODO (chengfu.wcy) create meaningful status codes
            self._exit_code = -1

    def prepare(self):
        """
        Start the components one by one. Make sure that
        it is ready to run.
        """
        # Start the evaluation service if requested
        if self.evaluation_service:
            self.logger.info("Starting evaluation service")
            self.evaluation_service.start()
            self.logger.info("Evaluation service started")

        # Start the master GRPC server
        self.logger.info("Starting master RPC server")
        self.server.start()
        self.logger.info("Master RPC server started")

        # Start the worker manager if requested
        if self.instance_manager:
            self.instance_manager.update_status(InstanceManagerStatus.PENDING)
            if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
                # Start rendezvous server for workers to initialize Horovod
                self.rendezvous_server.start()
            else:
                # Parameter-server strategy: PS pods must exist before the
                # workers that connect to them.
                self.instance_manager.start_parameter_servers()
            self.instance_manager.start_workers()
            self.instance_manager.update_status(InstanceManagerStatus.RUNNING)

        # Start TensorBoard k8s Service if requested
        if self.tb_service and self.tb_client:
            self.logger.info("Starting tensorboard service")
            self.tb_service.start()
            self.tb_client.start_tensorboard_service()
            self.logger.info("Tensorboard service started")

    def run(self):
        """
        The main loop of master.
        Dispatch the tasks to the workers until all the tasks are completed.

        Returns:
            int: the master's exit code (0 on success, -1 after
            ``request_stop`` with an error message).
        """
        try:
            while True:
                if self.instance_manager.all_workers_failed:
                    raise Exception(
                        "All workers fail with unrecoverable errors"
                    )
                    # NOTE(review): this break is unreachable — the raise
                    # above already exits the loop.
                    break
                if self.task_d.finished():
                    if self.instance_manager:
                        self.instance_manager.update_status(
                            InstanceManagerStatus.FINISHED
                        )
                    break
                if self._should_stop:
                    break
                # Poll job state every 30 seconds.
                time.sleep(30)
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")
        finally:
            self._stop()
        return self._exit_code

    def _stop(self):
        """
        Stop all the components.
        Make sure that the created services and components are shut down.
        """
        self.logger.info("Stopping master")

        if self.evaluation_service:
            self.logger.info("Stopping evaluation service")
            self.evaluation_service.stop()
            self.logger.info("Evaluation service stopped")

        self.logger.info("Stopping RPC server")
        self.server.stop(None)  # grace = None
        self.logger.info("RPC server stopped")

        # Keep TensorBoard running when all the tasks are finished
        if self.tb_service:
            self.logger.info(
                "All tasks finished. Keeping TensorBoard service running..."
            )
            while True:
                if self.tb_service.is_active():
                    time.sleep(10)
                else:
                    self.logger.warning(
                        "Unable to keep TensorBoard running. "
                        "It has already terminated"
                    )
                    break
        self.logger.info("Master stopped")

    @staticmethod
    def _get_job_type(args):
        """Derive the job type from which data args are set.

        Training+evaluation needs training and validation data plus an
        evaluation trigger; evaluation-only / prediction-only need
        exactly their own data source; anything else is training-only.
        """
        if all(
            (
                args.training_data,
                args.validation_data,
                args.evaluation_throttle_secs or args.evaluation_steps,
            )
        ):
            job_type = JobType.TRAINING_WITH_EVALUATION
        elif all(
            (
                args.validation_data,
                not args.training_data,
                not args.prediction_data,
            )
        ):
            job_type = JobType.EVALUATION_ONLY
        elif all(
            (
                args.prediction_data,
                not args.validation_data,
                not args.training_data,
            )
        ):
            job_type = JobType.PREDICTION_ONLY
        else:
            job_type = JobType.TRAINING_ONLY
        return job_type

    def _create_tensorboard_service(self, tensorboard_log_dir, master_ip):
        """Return a TensorboardService for the log dir, or None if unset."""
        tb_service = None
        if tensorboard_log_dir:
            self.logger.info(
                "Creating TensorBoard service with log directory %s",
                tensorboard_log_dir,
            )
            # Start TensorBoard CLI
            tb_service = TensorboardService(tensorboard_log_dir, master_ip)
        return tb_service

    def _create_evaluation_service(self, args):
        """Create and register the evaluation service when the job type
        involves evaluation; return it (or None)."""
        evaluation_service = None
        if (
            self.job_type == JobType.TRAINING_WITH_EVALUATION
            or self.job_type == JobType.EVALUATION_ONLY
        ):
            self.logger.info(
                "Creating evaluation service with throttle seconds %d "
                " and evaluation steps %d",
                args.evaluation_throttle_secs,
                args.evaluation_steps,
            )
            evaluation_service = EvaluationService(
                self.tb_service,
                self.task_d,
                args.evaluation_start_delay_secs,
                args.evaluation_throttle_secs,
                args.evaluation_steps,
                self.job_type == JobType.EVALUATION_ONLY,
                self.model_module[args.eval_metrics_fn],
            )
            # The dispatcher needs the service to create evaluation tasks.
            self.task_d.set_evaluation_service(evaluation_service)
        return evaluation_service

    def _create_master_service(self, args):
        """Build the gRPC server plus servicer (not yet started).

        Returns:
            tuple: (MasterServicer, grpc.Server).
        """
        self.logger.info("Creating master service")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=64),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        master_servicer = MasterServicer(
            args.minibatch_size,
            evaluation_service=self.evaluation_service,
            master=self,
        )
        elasticdl_pb2_grpc.add_MasterServicer_to_server(
            master_servicer, server
        )
        server.add_insecure_port("[::]:{}".format(args.port))
        self.logger.info("The port of the master server is: %d", args.port)
        return master_servicer, server

    def _create_instance_manager(self, args):
        """Build the InstanceManager with worker and PS pod commands.

        Both pods run through ``/bin/bash -c`` so the pipefail wrapper and
        (for the Go PS) the ``source``/``&&`` shell syntax work. Returns
        None when no workers are requested.
        """
        instance_manager = None

        # Both worker and PS containers are launched through bash.
        container_command = ["/bin/bash"]
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_client_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.worker.main"
            )
            worker_args = [
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            # Forward all parsed args except envs (envs go in as pod env
            # vars below, not CLI flags).
            worker_args.extend(
                build_arguments_from_parsed_result(args, filter_args=["envs"])
            )
            worker_args = wrap_python_args_with_string(worker_args)
            worker_args.insert(0, worker_client_command)

            if args.use_go_ps:
                # Go-based parameter server: single-dash flags, booleans
                # spelled out as "true"/"false".
                opt_type, opt_args = get_optimizer_info(self.optimizer)
                ps_command = "elasticdl_ps"
                ps_command_args = [
                    "-job_name=" + args.job_name,
                    "-namespace=" + args.namespace,
                    "-master_addr=" + self.master_addr,
                    "-port=2222",
                    "-use_async=" + ("true" if args.use_async else "false"),
                    "-grads_to_wait=" + str(args.grads_to_wait),
                    "-lr_staleness_modulation="
                    + ("true" if args.lr_staleness_modulation else "false"),
                    "-sync_version_tolerance="
                    + str(args.sync_version_tolerance),
                    "-evaluation_steps=" + str(args.evaluation_steps),
                    "-num_ps_pods=" + str(args.num_ps_pods),
                    "-num_workers=" + str(args.num_workers),
                    "-checkpoint_dir=" + str(args.checkpoint_dir),
                    "-checkpoint_steps=" + str(args.checkpoint_steps),
                    "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                    "-checkpoint_dir_for_init="
                    + str(args.checkpoint_dir_for_init),
                    "-opt_type=" + opt_type,
                    "-opt_args=" + opt_args,
                ]
                ps_command_args = wrap_go_args_with_string(ps_command_args)
                # Execute source /root/.bashrc to add the file path
                # of `elasticdl_ps` into the PATH environment variable.
                ps_args = [
                    "source",
                    "/root/.bashrc_elasticdl",
                    "&&",
                    ps_command,
                ]
                ps_args.extend(ps_command_args)
            else:
                # Python-based parameter server: double-dash argparse flags.
                ps_command = (
                    BashCommandTemplate.SET_PIPEFAIL
                    + " python -m elasticdl.python.ps.main"
                )
                ps_command_args = [
                    "--grads_to_wait",
                    str(args.grads_to_wait),
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                    "--use_async",
                    str(args.use_async),
                    "--model_zoo",
                    args.model_zoo,
                    "--model_def",
                    args.model_def,
                    "--job_name",
                    args.job_name,
                    "--port",
                    "2222",
                    "--master_addr",
                    self.master_addr,
                    "--namespace",
                    args.namespace,
                    "--evaluation_steps",
                    str(args.evaluation_steps),
                    "--checkpoint_dir",
                    str(args.checkpoint_dir),
                    "--checkpoint_steps",
                    str(args.checkpoint_steps),
                    "--keep_checkpoint_max",
                    str(args.keep_checkpoint_max),
                    "--num_ps_pods",
                    str(args.num_ps_pods),
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                    "--num_workers",
                    str(args.num_workers),
                    "--log_level",
                    str(args.log_level),
                    "--minibatch_size",
                    str(args.minibatch_size),
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ]
                ps_args = wrap_python_args_with_string(ps_command_args)
                ps_args.insert(0, ps_command)

            # Collapse each arg list into a single `bash -c "<cmd>"` string.
            worker_args = ["-c", " ".join(worker_args)]
            ps_args = ["-c", " ".join(ps_args)]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))
            kwargs = get_dict_from_params_str(args.aux_params)
            disable_relaunch = kwargs.get("disable_relaunch", False)
            cluster_spec = self._get_image_cluster_spec(args.cluster_spec)
            instance_manager = InstanceManager(
                self.task_d,
                rendezvous_server=self.rendezvous_server,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=container_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=container_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=cluster_spec,
                envs=env,
                disable_relaunch=disable_relaunch,
                log_file_path=args.log_file_path,
            )
        return instance_manager

    def _get_image_cluster_spec(self, cluster_spec):
        """Map a local cluster-spec path to its path inside the image."""
        if cluster_spec:
            filename = os.path.basename(cluster_spec)
            image_cluster_spec = os.path.join(
                ClusterSpecConfig.CLUSTER_SPEC_DIR, filename
            )
            return image_cluster_spec
        return cluster_spec

    def _check_timeout_tasks(self):
        """Daemon loop: every 30s, relaunch workers whose in-flight task
        exceeds 3x the average completion time for its task type.

        For training tasks the worker's last liveness timestamp is used
        instead of the task start time.
        """
        while True:
            # Copy to avoid mutation while the dispatcher updates _doing.
            doing_tasks = self.task_d._doing.copy()
            cur_time = time.time()
            avg_time = self.master_servicer.get_average_task_complete_time()
            for task_id, (worker_id, task, start_time) in doing_tasks.items():
                if task.type == elasticdl_pb2.TRAINING:
                    start_time = self.master_servicer.get_worker_liveness_time(
                        worker_id
                    )
                if task.type in [
                    elasticdl_pb2.TRAINING,
                    elasticdl_pb2.EVALUATION,
                ]:
                    if (cur_time - start_time) > 3 * avg_time[task.type]:
                        self.logger.info(
                            "worker %d timeout, relaunch it" % worker_id
                        )
                        self.task_d.recover_tasks(worker_id)
                        # TODO: save worker logs before remove it
                        self.instance_manager._remove_worker(worker_id)
                        # At most one relaunch per sweep.
                        break
            time.sleep(30)
def main():
    """Entry point of the (legacy) ElasticDL master process.

    Sequentially: parses args, optionally starts TensorBoard, builds the
    task queue and model, derives the job type, starts checkpoint /
    evaluation / embedding services as needed, launches the gRPC master
    server and worker pods, then blocks until all tasks finish before
    tearing everything down.
    """
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr: pod IP inside k8s, localhost fallback for local runs.
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start TensorBoard service if requested
    if args.tensorboard_log_dir:
        logger.info(
            "Starting TensorBoard service with log directory %s",
            args.tensorboard_log_dir,
        )
        # Start TensorBoard CLI
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    # Load the model definition and optimizer from the model zoo module.
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    model_inst = load_model_from_module(
        args.model_def, model_module, args.model_params
    )
    optimizer = model_module[args.optimizer]()

    # Derive job type from which data directories are configured.
    if all(
        (
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
        )
    ):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all(
        (
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
        )
    ):
        job_type = JobType.EVALUATION_ONLY
    elif all(
        (
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
        )
    ):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (
        job_type == JobType.TRAINING_WITH_EVALUATION
        or job_type == JobType.EVALUATION_ONLY
    ):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    embedding_dims = {}
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info(
            "Embedding service start succeeded. The endpoint is %s."
            % str(embedding_service_endpoint)
        )
        embedding_dims = dict(
            [(layer.name, layer.output_dim) for layer in layers]
        )

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        # Only a built model has trainable variables to hand over.
        init_var=model_inst.trainable_variables if model_inst.built else [],
        embedding_dims=embedding_dims,
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
        lr_staleness_modulation=args.lr_staleness_modulation,
        use_async=args.use_async,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"
        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
            "--get_model_steps",
            str(args.get_model_steps),
        ]
        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))
        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # Start TensorBoard k8s Service if requested
    if tb_service:
        TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        ).start_tensorboard_service()

    # Main wait loop: poll until all tasks are dispatched and done.
    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running..."
        )
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning(
                    "Unable to keep TensorBoard running. "
                    "It has already terminated"
                )
                break
    logger.info("Master stopped")
def __init__(self, args):
    """Wire up the master components from parsed CLI args (legacy layout).

    Builds, strictly in order: logging, master address, the optional
    TensorBoard service/client, the model (wrapped by a distribution-
    strategy ModelHandler) and optimizer from the model zoo, the data
    reader, the task dispatcher (with a deferred save-model task for
    training jobs), the evaluation service, the gRPC master service and
    the instance manager.

    Args:
        args: parsed master command-line arguments namespace.
    """
    self.logger = get_logger("master", level=args.log_level.upper())

    self.num_ps_pods = args.num_ps_pods
    self.checkpoint_output_path = args.checkpoint_dir

    # Master addr: pod IP inside k8s, localhost fallback for local runs.
    master_ip = os.getenv("MY_POD_IP", "localhost")
    self.master_addr = "%s:%d" % (master_ip, args.port)
    self.job_type = Master._get_job_type(args)

    # Initialize TensorBoard service if requested
    self.tb_service = self._create_tensorboard_service(
        args.tensorboard_log_dir, master_ip
    )
    if self.tb_service:
        # NOTE(review): tb_client only exists when a TensorBoard log dir
        # was given; accesses elsewhere must be guarded by tb_service.
        self.tb_client = TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        )

    # Initialize the components from the model definition
    self.model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    self.model_inst = load_model_from_module(
        args.model_def, self.model_module, args.model_params
    )
    # Adapt the model for the chosen distribution strategy (e.g. swap
    # embedding layers) before training.
    model_handler = ModelHandler.get_model_handler(
        args.distribution_strategy, checkpoint_dir=args.checkpoint_dir
    )
    self.model_inst = model_handler.get_model_to_train(self.model_inst)
    self.optimizer = self.model_module[args.optimizer]()
    # Default data reader unless the model zoo module provides its own.
    self._create_data_reader_fn = create_data_reader
    if args.custom_data_reader in self.model_module:
        self._create_data_reader_fn = self.model_module[
            args.custom_data_reader
        ]

    # Start task queue
    records_per_task = args.minibatch_size * args.num_minibatches_per_task
    self.task_d = _make_task_dispatcher(
        args.training_data,
        args.validation_data,
        args.prediction_data,
        records_per_task,
        args.num_epochs,
        args.data_reader_params,
        self._create_data_reader_fn,
    )
    # Queue a save-model task to run after training completes.
    saved_model_path = args.output
    if saved_model_path is not None and self.job_type in [
        JobType.TRAINING_ONLY,
        JobType.TRAINING_WITH_EVALUATION,
    ]:
        self.task_d.add_deferred_callback_create_save_model_task(
            saved_model_path
        )

    self.evaluation_service = self._create_evaluation_service(args)

    # Initialize master service
    self.master_servicer, self.server = self._create_master_service(args)

    # Initialize instance manager
    self.instance_manager = self._create_instance_manager(args)

    self._should_stop = False
    self._exit_code = 0
class Master(object):
    """Coordinator process of an ElasticDL job (legacy version).

    Owns the task dispatcher, the gRPC master service, the instance
    manager that launches worker/PS pods, and the optional TensorBoard
    and evaluation services. ``prepare()`` starts the components,
    ``run()`` blocks until the job finishes, ``_stop()`` shuts down.
    """

    def __init__(self, args):
        """Build all master components from parsed CLI args, in order.

        Args:
            args: parsed master command-line arguments namespace.
        """
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master addr: pod IP inside k8s, localhost fallback for local runs.
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            # NOTE(review): tb_client only exists when a TensorBoard log
            # dir was given; accesses are guarded by tb_service elsewhere.
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        # Adapt the model for the chosen distribution strategy before
        # training.
        model_handler = ModelHandler.get_model_handler(
            args.distribution_strategy, checkpoint_dir=args.checkpoint_dir
        )
        self.model_inst = model_handler.get_model_to_train(self.model_inst)
        self.optimizer = self.model_module[args.optimizer]()
        # Default data reader unless the model zoo module provides its own.
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
        )
        # Queue a save-model task to run after training completes.
        saved_model_path = args.output
        if saved_model_path is not None and self.job_type in [
            JobType.TRAINING_ONLY,
            JobType.TRAINING_WITH_EVALUATION,
        ]:
            self.task_d.add_deferred_callback_create_save_model_task(
                saved_model_path
            )

        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        self._should_stop = False
        self._exit_code = 0

    def request_stop(self, err_msg=None):
        """Request master to quit"""
        self._should_stop = True
        if err_msg:
            self.logger.error(err_msg)
            # TODO (chengfu.wcy) create meaningful status codes
            self._exit_code = -1

    def prepare(self):
        """
        Start the components one by one. Make sure that
        it is ready to run.
        """
        # Start the evaluation service if requested
        if self.evaluation_service:
            self.logger.info("Starting evaluation service")
            self.evaluation_service.start()
            self.logger.info("Evaluation service started")

        # Start the master GRPC server
        self.logger.info("Starting master RPC server")
        self.server.start()
        self.logger.info("Master RPC server started")

        # Start the worker manager if requested
        if self.instance_manager:
            # PS pods must exist before the workers that connect to them.
            self.instance_manager.update_status(InstanceManagerStatus.PENDING)
            self.instance_manager.start_parameter_servers()
            self.instance_manager.start_workers()
            self.instance_manager.update_status(InstanceManagerStatus.RUNNING)

        # Start TensorBoard k8s Service if requested
        if self.tb_service and self.tb_client:
            self.logger.info("Starting tensorboard service")
            self.tb_service.start()
            self.tb_client.start_tensorboard_service()
            self.logger.info("Tensorboard service started")

    def run(self):
        """
        The main loop of master.
        Dispatch the tasks to the workers until all the tasks are completed.

        Returns:
            int: the master's exit code (0 on success, -1 after
            ``request_stop`` with an error message).
        """
        try:
            while True:
                if self.task_d.finished():
                    if self.instance_manager:
                        self.instance_manager.update_status(
                            InstanceManagerStatus.FINISHED
                        )
                    break
                if self._should_stop:
                    break
                # Poll job state every 30 seconds.
                time.sleep(30)
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")
        finally:
            self._stop()
        return self._exit_code

    def _stop(self):
        """
        Stop all the components.
        Make sure that the created services and components are shut down.
        """
        self.logger.info("Stopping master")

        if self.evaluation_service:
            self.logger.info("Stopping evaluation service")
            self.evaluation_service.stop()
            self.logger.info("Evaluation service stopped")

        self.logger.info("Stopping RPC server")
        self.server.stop(None)  # grace = None
        self.logger.info("RPC server stopped")

        # Keep TensorBoard running when all the tasks are finished
        if self.tb_service:
            self.logger.info(
                "All tasks finished. Keeping TensorBoard service running..."
            )
            while True:
                if self.tb_service.is_active():
                    time.sleep(10)
                else:
                    self.logger.warning(
                        "Unable to keep TensorBoard running. "
                        "It has already terminated"
                    )
                    break
        self.logger.info("Master stopped")

    @staticmethod
    def _get_job_type(args):
        """Derive the job type from which data args are set.

        Training+evaluation needs training and validation data plus an
        evaluation trigger; evaluation-only / prediction-only need
        exactly their own data source; anything else is training-only.
        """
        if all(
            (
                args.training_data,
                args.validation_data,
                args.evaluation_throttle_secs or args.evaluation_steps,
            )
        ):
            job_type = JobType.TRAINING_WITH_EVALUATION
        elif all(
            (
                args.validation_data,
                not args.training_data,
                not args.prediction_data,
            )
        ):
            job_type = JobType.EVALUATION_ONLY
        elif all(
            (
                args.prediction_data,
                not args.validation_data,
                not args.training_data,
            )
        ):
            job_type = JobType.PREDICTION_ONLY
        else:
            job_type = JobType.TRAINING_ONLY
        return job_type

    def _create_tensorboard_service(self, tensorboard_log_dir, master_ip):
        """Return a TensorboardService for the log dir, or None if unset."""
        tb_service = None
        if tensorboard_log_dir:
            self.logger.info(
                "Creating TensorBoard service with log directory %s",
                tensorboard_log_dir,
            )
            # Start TensorBoard CLI
            tb_service = TensorboardService(tensorboard_log_dir, master_ip)
        return tb_service

    def _create_evaluation_service(self, args):
        """Create and register the evaluation service when the job type
        involves evaluation; return it (or None)."""
        evaluation_service = None
        if (
            self.job_type == JobType.TRAINING_WITH_EVALUATION
            or self.job_type == JobType.EVALUATION_ONLY
        ):
            self.logger.info(
                "Creating evaluation service with throttle seconds %d "
                " and evaluation steps %d",
                args.evaluation_throttle_secs,
                args.evaluation_steps,
            )
            evaluation_service = EvaluationService(
                self.tb_service,
                self.task_d,
                args.evaluation_start_delay_secs,
                args.evaluation_throttle_secs,
                args.evaluation_steps,
                self.job_type == JobType.EVALUATION_ONLY,
                self.model_module[args.eval_metrics_fn],
            )
            # The dispatcher needs the service to create evaluation tasks.
            self.task_d.set_evaluation_service(evaluation_service)
        return evaluation_service

    def _create_master_service(self, args):
        """Build the gRPC server plus servicer (not yet started).

        Returns:
            tuple: (MasterServicer, grpc.Server).
        """
        self.logger.info("Creating master service")
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=64),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        master_servicer = MasterServicer(
            args.minibatch_size,
            self.task_d,
            evaluation_service=self.evaluation_service,
        )
        elasticdl_pb2_grpc.add_MasterServicer_to_server(
            master_servicer, server
        )
        server.add_insecure_port("[::]:{}".format(args.port))
        self.logger.info("The port of the master server is: %d", args.port)
        return master_servicer, server

    def _create_instance_manager(self, args):
        """Build the InstanceManager with Python worker and PS commands.

        Returns None when no workers are requested.
        """
        instance_manager = None
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_command = ["python"]
            worker_args = [
                "-m",
                "elasticdl.python.worker.main",
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            # Forward all parsed args to the worker process.
            worker_args.extend(build_arguments_from_parsed_result(args))

            ps_command = ["python"]
            ps_args = [
                "-m",
                "elasticdl.python.ps.main",
                "--grads_to_wait",
                str(args.grads_to_wait),
                "--lr_staleness_modulation",
                str(args.lr_staleness_modulation),
                "--use_async",
                str(args.use_async),
                "--minibatch_size",
                str(args.minibatch_size),
                "--model_zoo",
                args.model_zoo,
                "--model_def",
                args.model_def,
                "--job_name",
                args.job_name,
                "--num_minibatches_per_task",
                str(args.num_minibatches_per_task),
                "--port",
                "2222",
                "--master_addr",
                self.master_addr,
                "--namespace",
                args.namespace,
                "--evaluation_steps",
                str(args.evaluation_steps),
                "--checkpoint_dir",
                str(args.checkpoint_dir),
                "--checkpoint_steps",
                str(args.checkpoint_steps),
                "--keep_checkpoint_max",
                str(args.keep_checkpoint_max),
                "--num_ps_pods",
                str(args.num_ps_pods),
                "--checkpoint_dir_for_init",
                str(args.checkpoint_dir_for_init),
                "--num_workers",
                str(args.num_workers),
                "--log_level",
                str(args.log_level),
            ]
            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))
            instance_manager = InstanceManager(
                self.task_d,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=worker_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=ps_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=args.cluster_spec,
                envs=env,
            )
        return instance_manager