def testRelaunchWorkerPod(self):
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    worker_manager = WorkerManager(
        task_d,
        job_name="test-relaunch-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        command=["sleep 10"],
        args=[],
        namespace="default",
        num_workers=3,
    )
    worker_manager.start_workers()

    # Wait until at least one worker pod is alive.
    max_check_num = 60
    for _ in range(max_check_num):
        time.sleep(1)
        counters = worker_manager.get_counters()
        print(counters)
        if counters["Running"] + counters["Pending"] > 0:
            break

    # Note: there is a slight chance of a race condition here.
    # Hack to find a worker to remove.
    current_workers = set()
    live_workers = set()
    with worker_manager._lock:
        for k, (_, phase) in worker_manager._pods_phase.items():
            current_workers.add(k)
            if phase in ["Running", "Pending"]:
                live_workers.add(k)
    self.assertTrue(live_workers)

    worker_manager._remove_worker(live_workers.pop())

    # Verify that a new worker gets launched to replace the removed one.
    found = False
    print(current_workers)
    for _ in range(max_check_num):
        time.sleep(1)
        counters = worker_manager.get_counters()
        print(counters)
        with worker_manager._lock:
            for k in worker_manager._pods_phase:
                if k not in current_workers:
                    found = True
        if found:
            break
    self.assertTrue(found, "Failed to find newly launched worker.")

    worker_manager.stop_relaunch_and_remove_workers()

def testCreateDeleteWorkerPod(self):
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    task_d.recover_tasks = MagicMock()
    worker_manager = WorkerManager(
        task_d,
        job_name="test-create-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        command=["echo"],
        args=[],
        namespace="default",
        num_workers=3,
    )
    worker_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = worker_manager.get_counters()
        print(counters)
        if counters["Succeeded"] == 3:
            break

    worker_manager.stop_relaunch_and_remove_workers()
    for _ in range(max_check_num):
        time.sleep(3)
        counters = worker_manager.get_counters()
        print(counters)
        if not counters:
            break

    task_d.recover_tasks.assert_has_calls(
        [call(0), call(1), call(2)], any_order=True
    )

def testFailedWorkerPod(self):
    """
    Start a pod running a python program destined to fail with
    restart_policy="Never" to test failed_worker_count.
    """
    task_d = _TaskDispatcher({"f": (0, 10)}, {}, {}, 1, 1)
    task_d.recover_tasks = MagicMock()
    worker_manager = WorkerManager(
        task_d,
        job_name="test-failed-worker-pod-%d-%d"
        % (int(time.time()), random.randint(1, 101)),
        image_name="gcr.io/google-samples/hello-app:1.0",
        command=["badcommand"],
        args=["badargs"],
        namespace="default",
        num_workers=3,
        restart_policy="Never",
    )
    worker_manager.start_workers()

    max_check_num = 20
    for _ in range(max_check_num):
        time.sleep(3)
        counters = worker_manager.get_counters()
        print(counters)
        if counters["Failed"] == 3:
            break

    worker_manager.stop_relaunch_and_remove_workers()
    for _ in range(max_check_num):
        time.sleep(3)
        counters = worker_manager.get_counters()
        print(counters)
        if not counters:
            break

    task_d.recover_tasks.assert_has_calls(
        [call(0), call(1), call(2)], any_order=True
    )
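
# The three tests above repeat the same poll-until-condition loop. A helper
# along the lines of the sketch below could factor that out; ``_wait_for`` is
# a hypothetical name, not part of the original test class.
def _wait_for(self, worker_manager, predicate, max_check_num, interval):
    """Poll pod counters until ``predicate(counters)`` holds or the check
    budget is exhausted; returns the last observed counters."""
    counters = worker_manager.get_counters()
    for _ in range(max_check_num):
        time.sleep(interval)
        counters = worker_manager.get_counters()
        print(counters)
        if predicate(counters):
            break
    return counters

# Example usage inside a test:
#     counters = self._wait_for(
#         worker_manager, lambda c: c["Succeeded"] == 3, 20, 3
#     )
#     self.assertEqual(counters["Succeeded"], 3)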

def main():
    args = parse_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start TensorBoard service if requested
    if args.tensorboard_log_dir:
        logger.info(
            "Starting TensorBoard service with log directory %s",
            args.tensorboard_log_dir,
        )
        # Start TensorBoard CLI
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        args.records_per_task,
        args.num_epochs,
    )
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    model_inst = load_model_from_module(
        args.model_def, model_module, args.model_params
    )
    optimizer = model_module[args.optimizer]()

    if all(
        (
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
        )
    ):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all(
        (
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
        )
    ):
        job_type = JobType.EVALUATION_ONLY
    elif all(
        (
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
        )
    ):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (
        job_type == JobType.TRAINING_WITH_EVALUATION
        or job_type == JobType.EVALUATION_ONLY
    ):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            "and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    embedding_dims = {}
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info(
            "Embedding service start succeeded. The endpoint is %s.",
            str(embedding_service_endpoint),
        )
        embedding_dims = {layer.name: layer.output_dim for layer in layers}

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        init_var=model_inst.trainable_variables if model_inst.built else [],
        embedding_dims=embedding_dims,
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
        lr_staleness_modulation=args.lr_staleness_modulation,
        use_async=args.use_async,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--model_zoo",
            args.model_zoo,
            "--master_addr",
            master_addr,
            "--log_level",
            args.log_level,
            "--dataset_fn",
            args.dataset_fn,
            "--loss",
            args.loss,
            "--optimizer",
            args.optimizer,
            "--eval_metrics_fn",
            args.eval_metrics_fn,
            "--model_def",
            args.model_def,
            "--job_type",
            job_type,
            "--minibatch_size",
            str(args.minibatch_size),
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
            "--get_model_steps",
            str(args.get_model_steps),
        ]

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # Start TensorBoard k8s Service if requested
    if tb_service:
        TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        ).start_tensorboard_service()

    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running..."
        )
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning(
                    "Unable to keep TensorBoard running. "
                    "It has already terminated"
                )
                break
    logger.info("Master stopped")
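
# Standard entry-point guard, assumed here so the excerpt runs as a script.
# The worker counterpart above is launched with
# ``python -m elasticdl.python.worker.main``; the master module is assumed
# to be invoked the same way.
if __name__ == "__main__":
    main()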