def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")
        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)
            # Wait until the channel is ready using a Future object.
            grpc.channel_ready_future(channel).result()
            logger.info(
                "grpc channel %s to connect pod %s is ready"
                % (addr, addr.split(".")[0])
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
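# `build_channel` is not defined in these snippets; the variants below that
# build channels inline suggest what it does. A minimal sketch, assuming the
# GRPC message-size constants already referenced by the surrounding module:
import grpc


def build_channel(addr):
    # Build an insecure gRPC channel to "host:port" with enlarged
    # message-size limits.
    return grpc.insecure_channel(
        addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )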
def main():
    args = parse_worker_args()
    channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    worker = Worker(
        args.worker_id,
        args.job_type,
        args.minibatch_size,
        args.model_zoo,
        channel=channel,
        embedding_service_endpoint=eval(args.embedding_service_endpoint),
        dataset_fn=args.dataset_fn,
        loss=args.loss,
        optimizer=args.optimizer,
        eval_metrics_fn=args.eval_metrics_fn,
        model_def=args.model_def,
        model_params=args.model_params,
        get_model_steps=args.get_model_steps,
    )
    worker.run()
def main():
    args = parse_worker_args()
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )

    # TODO: create PS channels here.
    ps_addrs = args.ps_addrs.split(",")
    # Print ps_addrs to avoid a flake8 unused-variable failure. This print
    # can be removed once ps_channels is initialized from ps_addrs.
    print("Parameter server addresses are %s" % ps_addrs)
    ps_channels = None

    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    worker = Worker(args, channel=channel, ps_channels=ps_channels)
    worker.run()
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        # TODO: use ps_addrs from the master directly once the PS service is
        # working. For now, resolve each PS pod IP for the gRPC connection.
        ps_addrs = args.ps_addrs.split(",")
        config.load_incluster_config()
        api = client.CoreV1Api()
        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            addr_splitted = addr.split(".")
            while True:
                pod = api.read_namespaced_pod(
                    namespace=addr_splitted[1], name=addr_splitted[0]
                )
                if pod.status.pod_ip:
                    break
                # The PS pod is not ready yet; sleep 2 seconds and try again.
                time.sleep(2)
            addr = pod.status.pod_ip + ":" + addr.split(":")[-1]
            channel = grpc.insecure_channel(
                addr,
                options=[
                    (
                        "grpc.max_send_message_length",
                        GRPC.MAX_SEND_MESSAGE_LENGTH,
                    ),
                    (
                        "grpc.max_receive_message_length",
                        GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                    ),
                ],
            )
            # Wait until the channel is ready using a Future object.
            grpc.channel_ready_future(channel).result()
            logger.info(
                "grpc channel %s to connect pod %s is ready"
                % (addr, pod.metadata.name)
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    master_channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )

    ps_channels = []
    if args.ps_addrs:
        # TODO: use ps_addrs from the master directly once the PS service is
        # working. For now, resolve each PS pod IP for the gRPC connection.
        ps_addrs = args.ps_addrs.split(",")

        from kubernetes import client, config

        config.load_incluster_config()
        api = client.CoreV1Api()
        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            addr_splitted = addr.split(".")
            pod = api.read_namespaced_pod(
                namespace=addr_splitted[1], name=addr_splitted[0]
            )
            addr = pod.status.pod_ip + ":" + addr.split(":")[-1]
            channel = grpc.insecure_channel(
                addr,
                options=[
                    (
                        "grpc.max_send_message_length",
                        GRPC.MAX_SEND_MESSAGE_LENGTH,
                    ),
                    (
                        "grpc.max_receive_message_length",
                        GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                    ),
                ],
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")
        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)
            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT
                    )
                    logger.info(
                        "grpc channel %s to connect pod %s is ready"
                        % (addr, addr.split(".")[0])
                    )
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning(
                        "Failed to connect pod %s in retry %d"
                        % (addr.split(".")[0], i)
                    )
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries"
                    % (addr.split(".")[0], CONNECT_PS_MAX_RETRIES)
                )

    if args.distribution_strategy == DistributionStrategy.ALLREDUCE:
        logger.info(
            "Wait for %s seconds for the FTLib consensus service to "
            "detect the worker pod" % str(_ALLREDUCE_STRATEGY_WARM_UP_SECS)
        )
        time.sleep(_ALLREDUCE_STRATEGY_WARM_UP_SECS)

    worker = Worker(
        args,
        channel=master_channel,
        ps_channels=ps_channels,
        set_parallelism=True,
    )
    worker.run()
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    master_client = MasterClient(
        build_channel(args.master_addr), args.worker_id
    )

    ps_client = None
    if (
        args.distribution_strategy == DistributionStrategy.PARAMETER_SERVER
        and args.ps_addrs
    ):
        ps_channels = []
        ps_addrs = args.ps_addrs.split(",")
        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)
            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT
                    )
                    logger.info(
                        "grpc channel %s to connect pod %s is ready"
                        % (addr, addr.split(".")[0])
                    )
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning(
                        "Failed to connect pod %s in retry %d"
                        % (addr.split(".")[0], i)
                    )
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries"
                    % (addr.split(".")[0], CONNECT_PS_MAX_RETRIES)
                )
        ps_client = PSClient(ps_channels)

    worker = Worker(
        args,
        master_client=master_client,
        ps_client=ps_client,
        set_parallelism=True,
    )
    worker.run()
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    master_addr = args.master_addr
    worker_id = int(args.worker_id)
    logger.info("Starting worker %d", worker_id)
    master_client = MasterClient(build_channel(master_addr), worker_id)

    logger.info("Building PS connection...")
    ps_client = (
        build_ps_client(args.ps_addrs, logger)
        if args.distribution_strategy == DistributionStrategy.PARAMETER_SERVER
        else None
    )
    logger.info("PS connection built.")

    worker = Worker(
        args,
        master_client=master_client,
        ps_client=ps_client,
        set_parallelism=True,
    )
    worker.run()
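# `build_ps_client` is not shown in these snippets. A hedged sketch, assuming
# it factors out the per-address connect-with-retry loop from the previous
# variant (CONNECT_PS_MAX_RETRIES and CONNECT_PS_TIMEOUT are the constants
# that variant already uses):
def build_ps_client(ps_addrs, logger):
    # Build a PSClient over ready channels to every PS pod, or return None
    # when no addresses are given.
    if not ps_addrs:
        return None
    ps_channels = []
    for addr in ps_addrs.split(","):
        # addr is of the form "ps-pod-name.namespace.svc:port"
        channel = build_channel(addr)
        for i in range(CONNECT_PS_MAX_RETRIES):
            try:
                grpc.channel_ready_future(channel).result(
                    timeout=CONNECT_PS_TIMEOUT
                )
                logger.info(
                    "grpc channel %s to connect pod %s is ready"
                    % (addr, addr.split(".")[0])
                )
                ps_channels.append(channel)
                break
            except grpc.FutureTimeoutError:
                logger.warning(
                    "Failed to connect pod %s in retry %d"
                    % (addr.split(".")[0], i)
                )
        else:
            # The for loop finished without a successful connection.
            raise TimeoutError(
                "Timed out connecting to pod %s after %d retries"
                % (addr.split(".")[0], CONNECT_PS_MAX_RETRIES)
            )
    return PSClient(ps_channels)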
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    loss="loss",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    callback_classes=[],
    use_async=False,
    get_model_steps=1,
    ps_channels=None,
    pservers=None,
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
    """Runs distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.

    Args:
        feature_shape: The shape of the model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that
            will be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        loss: The name of the loss function defined in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for job
            type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        callback_classes: A list of callbacks that will be called at given
            stages of the training procedure.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: The worker performs `get_model` from the parameter
            server every this many steps.
        ps_channels: A list of channels to all parameter server pods.
        pservers: A list of parameter server pods.
        distribution_strategy: The distribution strategy used by workers,
            e.g. DistributionStrategy.PARAMETER_SERVER or
            DistributionStrategy.AllreduceStrategy.

    Returns:
        An integer indicating the model version after the distributed
        training and evaluation.
    """
    job_type = (
        JobType.TRAINING_WITH_EVALUATION
        if training
        else JobType.EVALUATION_ONLY
    )
    evaluation_steps = (
        1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
    )
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    pservers = pservers or []
    ps_channels = ps_channels or []

    model_module = load_module(
        get_module_file_path(model_zoo_path, model_def)
    ).__dict__

    for channel in ps_channels:
        grpc.channel_ready_future(channel).result()
    worker_arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--loss",
        loss,
        "--get_model_steps",
        get_model_steps,
        "--distribution_strategy",
        distribution_strategy,
    ]
    args = parse_worker_args(worker_arguments)
    worker = Worker(args, ps_channels=ps_channels)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    if training:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            False,
            model_module[eval_metrics_fn],
        )
    else:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            True,
            model_module[eval_metrics_fn],
        )
    task_d.set_evaluation_service(evaluation_service)

    master = MasterServicer(
        batch_size,
        task_d,
        evaluation_service=evaluation_service,
    )
    callbacks = [
        callback_class(master, worker) for callback_class in callback_classes
    ]
    in_process_master = InProcessMaster(master, callbacks)
    worker._stub = in_process_master
    for pservicer in pservers:
        pservicer._master_stub = in_process_master

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.get_task(req, None)
    # There should be no more tasks.
    if task.shard_name:
        raise RuntimeError(
            "There are some tasks unfinished after the worker exits."
        )
    return master._version
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    callback_classes=[],
    use_async=False,
    get_model_steps=1,
):
    """Runs distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.

    Args:
        feature_shape: The shape of the model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that
            will be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for job
            type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        callback_classes: A list of callbacks that will be called at given
            stages of the training procedure.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: The worker performs `get_model` from the parameter
            server every this many steps.

    Returns:
        An integer indicating the model version after the distributed
        training and evaluation.
    """
    job_type = (
        JobType.TRAINING_WITH_EVALUATION
        if training
        else JobType.EVALUATION_ONLY
    )
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--get_model_steps",
        get_model_steps,
    ]
    args = parse_worker_args(arguments)
    worker = Worker(args)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    model_module = load_module(
        get_module_file_path(model_zoo_path, model_def)
    ).__dict__
    checkpoint_service = CheckpointService("", 0, 0, True)
    if training:
        evaluation_service = EvaluationService(
            checkpoint_service,
            None,
            task_d,
            0,
            0,
            1,
            False,
            model_module[eval_metrics_fn],
        )
    else:
        evaluation_service = EvaluationService(
            checkpoint_service,
            None,
            task_d,
            0,
            0,
            0,
            True,
            model_module[eval_metrics_fn],
        )
    task_d.set_evaluation_service(evaluation_service)
    grads_to_wait = 1 if use_async else 2
    master = MasterServicer(
        grads_to_wait,
        batch_size,
        worker._opt_fn(),
        task_d,
        init_var=[],
        checkpoint_filename_for_init="",
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        use_async=use_async,
    )
    callbacks = [
        callback_class(master, worker) for callback_class in callback_classes
    ]
    worker._stub = InProcessMaster(master, callbacks)

    for var in worker._model.trainable_variables:
        master.set_model_var(var.name, var.numpy())

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.GetTask(req, None)
    # There should be no more tasks.
    if task.shard_name:
        raise RuntimeError(
            "There are some tasks unfinished after the worker exits."
        )
    return master._version
def distributed_train_and_evaluate(
    self,
    feature_shape,
    model_def,
    model_params="",
    training=True,
    dataset="",
):
    """
    Run distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.
    """
    job_type = (
        JobType.TRAINING_ONLY if training else JobType.EVALUATION_ONLY
    )
    batch_size = 16
    worker = Worker(
        1,
        job_type,
        batch_size,
        _model_zoo_path,
        model_def=model_def,
        model_params=model_params,
        channel=None,
    )

    if dataset == "imagenet":
        batch_size = 8
        shards = {create_imagenet_recordio_file(8, feature_shape): (0, 8)}
    elif dataset == "frappe":
        shards = {
            create_frappe_recordio_file(16, feature_shape, 5383): (0, 16)
        }
    else:
        shards = {create_recordio_file(128, feature_shape): (0, 128)}

    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )
    # Initialize the checkpoint service.
    checkpoint_service = CheckpointService("", 0, 0, True)
    if training:
        evaluation_service = EvaluationService(
            checkpoint_service, None, task_d, 0, 0, 1, False
        )
    else:
        evaluation_service = EvaluationService(
            checkpoint_service, None, task_d, 0, 0, 0, True
        )
    task_d.set_evaluation_service(evaluation_service)
    # The master service.
    master = MasterServicer(
        2,
        batch_size,
        worker._opt_fn(),
        task_d,
        init_var=[],
        checkpoint_filename_for_init="",
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
    )
    worker._stub = InProcessMaster(master)

    for var in worker._model.trainable_variables:
        master.set_model_var(var.name, var.numpy())

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.GetTask(req, None)
    # There should be no more tasks.
    self.assertTrue(not task.shard_name)
def distributed_train_and_evaluate(self, training=True):
    """
    Run distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.
    """

    class _Master(InProcessMaster):
        def ReportGradient(self, req):
            if 2 < self._m._version < 80:
                # Test retraining when a gradient is not accepted:
                # increase the master version to reject the gradient.
                self._m._version += 1
            return self._m.ReportGradient(req, None)

        def ReportEvaluationMetrics(self, req):
            if 2 < self._m._version < 80:
                # Test evaluation retries: increase the master version
                # so the evaluation metrics will not be accepted.
                self._m._version += 1
            return self._m.ReportEvaluationMetrics(req, None)

    job_type = (
        JobType.TRAINING_ONLY if training else JobType.EVALUATION_ONLY
    )
    batch_size = 16
    worker = Worker(
        1,
        job_type,
        batch_size,
        _model_zoo_path,
        model_def="test_module.custom_model",
        channel=None,
    )

    shards = {create_recordio_file(128): 128}
    if training:
        training_shards = shards
        evaluation_shards = {}
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )
    if not training:
        evaluation_service = EvaluationService(
            None, None, task_d, 0, 0, 0, True
        )
        task_d.set_evaluation_service(evaluation_service)
    else:
        evaluation_service = None
    master = MasterServicer(
        2,
        batch_size,
        worker._opt_fn(),
        task_d,
        init_var=[],
        checkpoint_filename_for_init="",
        checkpoint_service=None,
        evaluation_service=evaluation_service,
    )
    worker._stub = _Master(master)

    for var in worker._model.trainable_variables:
        master.set_model_var(var.name, var.numpy())

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.GetTask(req, None)
    # There should be no more tasks.
    self.assertTrue(not task.shard_name)
def testMaxCheckpointVersions(self):
    with tempfile.TemporaryDirectory() as tempdir:
        chkp_dir = os.path.join(tempdir, "testMaxCheckpointVersions")
        os.makedirs(chkp_dir)
        # Save a checkpoint every 2 steps and keep at most 5 checkpoints.
        checkpointer = CheckpointService(chkp_dir, 2, 5, False)
        self.assertTrue(checkpointer.is_enabled())

        batch_size = 2
        # Launch the training.
        arguments = [
            "--worker_id",
            1,
            "--job_type",
            JobType.TRAINING_ONLY,
            "--minibatch_size",
            batch_size,
            "--model_zoo",
            _model_zoo_path,
            "--model_def",
            "test_module.custom_model",
        ]
        args = parse_worker_args(arguments)
        worker = Worker(args)

        filename = create_recordio_file(128, DatasetName.TEST_MODULE, 1)
        task_d = _TaskDispatcher(
            {filename: (0, 128)}, {}, {}, records_per_task=64, num_epochs=1
        )
        master = MasterServicer(
            2,
            batch_size,
            worker._opt_fn(),
            task_d,
            init_var=worker._model.trainable_variables,
            checkpoint_filename_for_init="",
            checkpoint_service=checkpointer,
            evaluation_service=None,
        )
        worker._stub = InProcessMaster(master)
        worker.run()

        # We should have 5 checkpoints when the training finishes: 128
        # records in minibatches of 2 give 64 minibatches, which with
        # grads_to_wait=2 should advance the model version to 32; saving
        # every 2 versions and keeping at most 5 files leaves v24-v32.
        checkpoint_files = sorted(os.listdir(checkpointer._directory))
        self.assertEqual(
            checkpoint_files,
            [
                "model_v24.chkpt",
                "model_v26.chkpt",
                "model_v28.chkpt",
                "model_v30.chkpt",
                "model_v32.chkpt",
            ],
        )
        # The latest version should be 32.
        self.assertEqual(32, checkpointer.get_latest_checkpoint_version())
        # Check all checkpoints.
        for version in [24, 26, 28, 30, 32]:
            model = checkpointer.get_checkpoint_model(version)
            self.assertEqual(version, model.version)
        # Checkpoint not found.
        self.assertRaisesRegex(
            RuntimeError,
            "Failed to read model checkpoint from file",
            checkpointer.get_checkpoint_model,
            100,
        )
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    loss="loss",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    use_async=False,
    get_model_steps=1,
    ps_channels=None,
    pservers=None,
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
    """Runs distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.

    Args:
        feature_shape: The shape of the model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that
            will be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        loss: The name of the loss function defined in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for job
            type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: The worker performs `get_model` from the parameter
            server every this many steps.
        ps_channels: A list of channels to all parameter server pods.
        pservers: A list of parameter server pods.
        distribution_strategy: The distribution strategy used by workers,
            e.g. DistributionStrategy.PARAMETER_SERVER or
            DistributionStrategy.AllreduceStrategy.

    Returns:
        An integer indicating the model version after the distributed
        training and evaluation.
    """
    job_type = (
        JobType.TRAINING_WITH_EVALUATION
        if training
        else JobType.EVALUATION_ONLY
    )
    evaluation_steps = (
        1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
    )
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    pservers = pservers or []
    ps_channels = ps_channels or []

    model_module = load_module(
        get_module_file_path(model_zoo_path, model_def)
    ).__dict__

    worker_arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--loss",
        loss,
        "--get_model_steps",
        get_model_steps,
        "--distribution_strategy",
        distribution_strategy,
    ]
    args = parse_worker_args(worker_arguments)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    if training:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            False,
            model_module[eval_metrics_fn],
        )
    else:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            True,
            model_module[eval_metrics_fn],
        )
    task_d.set_evaluation_service(evaluation_service)

    master = Mock(
        task_d=task_d,
        instance_manager=None,
        distribution_strategy=None,
    )

    def master_creator():
        return MasterServicer(
            batch_size,
            evaluation_service=evaluation_service,
            master=master,
        )

    svc, port = _server(master_creator)
    mc = MasterClient(build_channel("localhost:%d" % port), 1)
    worker = Worker(args, master_client=mc, ps_client=PSClient(ps_channels))

    for pservicer in pservers:
        # FIXME(yancey1989): decouple the pserver and the master client.
        pservicer._master_stub = mc

    worker.run()

    task = mc.get_task()
    # Stop the master servicer.
    svc.stop(0)
    # There should be no more tasks.
    if task.shard_name:
        raise RuntimeError(
            "There are some tasks unfinished after the worker exits."
        )
    return task.model_version
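# Example invocation, as a unit test might call the helper above. The
# feature shape and model-zoo path are hypothetical; the model_def value is
# the example given in the docstring:
#
#   model_version = distributed_train_and_evaluate(
#       feature_shape=[32, 32, 3],
#       model_zoo_path=_model_zoo_path,
#       model_def="cifar10_subclass.CustomModel",
#       training=True,
#       dataset_name=DatasetName.IMAGE_DEFAULT,
#   )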
def distributed_train_and_evaluate(
    self,
    training=True,
    callback_classes=[],
    use_async=False,
    grads_to_wait=2,
    get_model_steps=1,
):
    """
    Run distributed training and evaluation with a local master.
    gRPC calls are mocked by local master calls.
    """
    if use_async and grads_to_wait > 1:
        raise ValueError(
            "grads_to_wait should be 1 when using asynchronous SGD."
        )
    job_type = (
        JobType.TRAINING_ONLY if training else JobType.EVALUATION_ONLY
    )
    batch_size = 16
    worker = Worker(
        1,
        job_type,
        batch_size,
        _model_zoo_path,
        model_def="test_module.custom_model",
        channel=None,
        get_model_steps=get_model_steps,
    )

    shards = {create_recordio_file(128): (0, 128)}
    if training:
        training_shards = shards
        evaluation_shards = {}
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )
    if not training:
        evaluation_service = EvaluationService(
            None, None, task_d, 0, 0, 0, True
        )
        task_d.set_evaluation_service(evaluation_service)
    else:
        evaluation_service = None
    master = MasterServicer(
        grads_to_wait,
        batch_size,
        worker._opt_fn(),
        task_d,
        init_var=[],
        checkpoint_filename_for_init="",
        checkpoint_service=None,
        evaluation_service=evaluation_service,
        use_async=use_async,
    )
    callbacks = [
        callback_class(master, worker, self)
        for callback_class in callback_classes
    ]
    worker._stub = InProcessMaster(master, callbacks)

    for var in worker._model.trainable_variables:
        master.set_model_var(var.name, var.numpy())

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.GetTask(req, None)
    # There should be no more tasks.
    self.assertTrue(not task.shard_name)