def testGetEmptyTask(self):
    """A servicer backed by an empty dispatcher hands out shard-less
    tasks that carry its current model version."""
    dispatcher = _TaskDispatcher({}, {}, {}, records_per_task=3, num_epochs=2)
    servicer = MasterServicer(3, dispatcher, evaluation_service=None)

    request = elasticdl_pb2.GetTaskRequest()
    request.worker_id = 1

    # No task yet, make sure the returned versions are as expected.
    first = servicer.get_task(request, None)
    self.assertEqual("", first.shard_name)
    self.assertEqual(0, first.model_version)

    # After bumping the internal version the empty task must reflect it.
    servicer._version = 1
    second = servicer.get_task(request, None)
    self.assertEqual("", second.shard_name)
    self.assertEqual(1, second.model_version)
def test_get_empty_task(self):
    """With no shards registered, get_task returns an empty-named shard
    and the servicer's current model version."""
    self.master.task_manager = create_task_manager([], [])
    servicer = MasterServicer(
        self.master.task_manager,
        self.master.instance_manager,
        None,
        None,
    )

    request = elasticai_api_pb2.GetTaskRequest()
    request.worker_id = 1

    def check_empty_task(expected_version):
        # Each fetch should yield a shard-less task at the given version.
        fetched = servicer.get_task(request, None)
        self.assertEqual("", fetched.shard.name)
        self.assertEqual(expected_version, fetched.model_version)

    # No task yet, make sure the returned versions are as expected.
    check_empty_task(0)
    servicer._version = 1
    check_empty_task(1)
def testReportTaskResult(self):
    """Drains the dispatcher over two epochs; the first attempt of every
    start-of-shard task is failed on purpose, so those tasks run once
    more than the rest."""
    dispatcher = _TaskDispatcher(
        {"shard_1": (0, 10), "shard_2": (0, 9)},
        {},
        {},
        records_per_task=3,
        num_epochs=2,
    )
    servicer = MasterServicer(3, dispatcher, evaluation_service=None)

    # Maps (shard, start, end) to how many times that task was run.
    run_counts = defaultdict(int)
    while True:
        request = elasticdl_pb2.GetTaskRequest()
        request.worker_id = random.randint(1, 10)
        fetched = servicer.get_task(request, None)
        if not fetched.shard_name:
            # Dispatcher drained: every task has been completed.
            break

        # The dispatcher must track which worker holds the task.
        self.assertEqual(
            dispatcher._doing[fetched.task_id][0], request.worker_id
        )
        key = (fetched.shard_name, fetched.start, fetched.end)
        run_counts[key] += 1

        result = elasticdl_pb2.ReportTaskResultRequest()
        result.task_id = fetched.task_id
        if fetched.start == 0 and run_counts[key] == 1:
            # Simulate error reports.
            result.err_message = "Worker error"
        servicer.report_task_result(result, None)

    self.assertDictEqual(
        {
            ("shard_1", 0, 3): 3,
            ("shard_1", 3, 6): 2,
            ("shard_1", 6, 9): 2,
            ("shard_1", 9, 10): 2,
            ("shard_2", 0, 3): 3,
            ("shard_2", 3, 6): 2,
            ("shard_2", 6, 9): 2,
        },
        run_counts,
    )
def test_report_task_result(self):
    """Drains the task manager over two epochs; the first attempt of
    every start-of-shard task is failed on purpose, so those tasks run
    once more than the rest."""
    self.master.task_manager = create_task_manager(
        [("shard_1", 0, 10), ("shard_2", 0, 9)], [], 2
    )
    servicer = MasterServicer(
        self.master.task_manager,
        self.master.instance_manager,
        None,
        None,
    )

    # Maps (shard, start, end) to how many times that task was run.
    run_counts = defaultdict(int)
    while True:
        request = elasticai_api_pb2.GetTaskRequest()
        request.worker_id = random.randint(1, 10)
        fetched = servicer.get_task(request, None)
        if not fetched.shard.name:
            # Task manager drained: every task has been completed.
            break

        # The task manager must track which worker holds the task.
        self.assertEqual(
            self.master.task_manager._doing[fetched.task_id][0],
            request.worker_id,
        )
        key = (fetched.shard.name, fetched.shard.start, fetched.shard.end)
        run_counts[key] += 1

        result = elasticai_api_pb2.ReportTaskResultRequest()
        result.task_id = fetched.task_id
        if fetched.shard.start == 0 and run_counts[key] == 1:
            # Simulate error reports.
            result.err_message = "Worker error"
        servicer.report_task_result(result, None)

    self.assertDictEqual(
        {
            ("shard_1", 0, 3): 3,
            ("shard_1", 3, 6): 2,
            ("shard_1", 6, 9): 2,
            ("shard_1", 9, 10): 2,
            ("shard_2", 0, 3): 3,
            ("shard_2", 3, 6): 2,
            ("shard_2", 6, 9): 2,
        },
        run_counts,
    )
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    loss="loss",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    callback_classes=None,
    use_async=False,
    get_model_steps=1,
    ps_channels=None,
    pservers=None,
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
    """Runs distributed training and evaluation with a local master.
    Grpc calls are mocked by local master call.

    Args:
        feature_shape: The shape of model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that will
            be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        loss: The name of the loss function defined in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for
            job type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        callback_classes: A list of callback classes that will be instantiated
            with (master, worker) and called at given stages of the training
            procedure. Defaults to no callbacks.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: Worker will perform `get_model` from the parameter
            server every this many steps.
        ps_channels: A channel list to all parameter server pods.
        pservers: A list of parameter server pods.
        distribution_strategy: The distribution strategy used by workers, e.g.
            DistributionStrategy.PARAMETER_SERVER or
            DistributionStrategy.AllreduceStrategy.

    Returns:
        An integer indicating the model version after the distributed training
        and evaluation.

    Raises:
        RuntimeError: If any task is still pending after the worker exits.
    """
    job_type = (
        JobType.TRAINING_WITH_EVALUATION if training else JobType.EVALUATION_ONLY
    )
    evaluation_steps = 1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    # Normalize optional sequence arguments. `callback_classes` previously
    # used a mutable default ([]); a None default avoids sharing one list
    # object across calls, matching how pservers/ps_channels are handled.
    callback_classes = callback_classes or []
    pservers = pservers or []
    ps_channels = ps_channels or []

    model_module = load_module(
        get_module_file_path(model_zoo_path, model_def)
    ).__dict__

    # Wait until every parameter-server channel is ready before starting.
    for channel in ps_channels:
        grpc.channel_ready_future(channel).result()

    worker_arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--loss",
        loss,
        "--get_model_steps",
        get_model_steps,
        "--distribution_strategy",
        distribution_strategy,
    ]
    args = parse_worker_args(worker_arguments)
    worker = Worker(args, ps_channels=ps_channels)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    # The only difference between the train-with-eval and eval-only services
    # is the eval-only flag, so build it with a single call.
    evaluation_service = EvaluationService(
        None,
        task_d,
        0,
        0,
        evaluation_steps,
        not training,
        model_module[eval_metrics_fn],
    )
    task_d.set_evaluation_service(evaluation_service)

    master = MasterServicer(
        batch_size,
        task_d,
        evaluation_service=evaluation_service,
    )
    callbacks = [
        callback_class(master, worker) for callback_class in callback_classes
    ]
    # Replace the worker's gRPC stub with an in-process master so no real
    # network calls are made during the test.
    in_process_master = InProcessMaster(master, callbacks)
    worker._stub = in_process_master
    for pservicer in pservers:
        pservicer._master_stub = in_process_master

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.get_task(req, None)
    # No more task.
    if task.shard_name:
        raise RuntimeError(
            "There are some tasks unfinished after worker exits.")
    return master._version