def testReportTaskResult(self):
    """Drive task dispatch and result reporting through MasterServicer.

    Two shards are cut into tasks of ``records_per_task=3`` records for
    ``num_epochs=2``. The first time any task starting at record 0 is
    handed out, its result is reported with an error message. The expected
    table below therefore has those tasks handed out three times and every
    other task twice.
    """
    dispatcher = _TaskDispatcher(
        {"shard_1": (0, 10), "shard_2": (0, 9)},
        {},
        {},
        records_per_task=3,
        num_epochs=2,
    )
    servicer = MasterServicer(3, dispatcher, evaluation_service=None)

    # (shard_name, start, end) -> number of times the task was handed out.
    run_counts = defaultdict(int)
    while True:
        request = elasticdl_pb2.GetTaskRequest()
        request.worker_id = random.randint(1, 10)
        task = servicer.get_task(request, None)
        # An empty shard name ends the loop: no more tasks are handed out.
        if not task.shard_name:
            break
        # The in-progress entry for this task must name the requesting worker.
        self.assertEqual(dispatcher._doing[task.task_id][0], request.worker_id)
        key = (task.shard_name, task.start, task.end)
        run_counts[key] += 1

        result = elasticdl_pb2.ReportTaskResultRequest()
        result.task_id = task.task_id
        first_run_at_shard_start = task.start == 0 and run_counts[key] == 1
        if first_run_at_shard_start:
            # Simulate error reports.
            result.err_message = "Worker error"
        servicer.report_task_result(result, None)

    expected = {
        ("shard_1", 0, 3): 3,
        ("shard_1", 3, 6): 2,
        ("shard_1", 6, 9): 2,
        ("shard_1", 9, 10): 2,
        ("shard_2", 0, 3): 3,
        ("shard_2", 3, 6): 2,
        ("shard_2", 6, 9): 2,
    }
    self.assertDictEqual(expected, run_counts)
def test_report_task_result(self):
    """Drive task dispatch and result reporting via the master's task manager.

    A fresh task manager for two shards replaces ``self.master.task_manager``.
    The final argument to ``create_task_manager`` is ``2``; the expected table
    suggests it is the epoch count (TODO confirm). The first time any task
    starting at record 0 is handed out, its result is reported with an error
    message. The expected table has those tasks handed out three times and
    every other task twice.
    """
    shards = [("shard_1", 0, 10), ("shard_2", 0, 9)]
    self.master.task_manager = create_task_manager(shards, [], 2)
    servicer = MasterServicer(
        self.master.task_manager,
        self.master.instance_manager,
        None,
        None,
    )

    # (shard name, start, end) -> number of times the task was handed out.
    run_counts = defaultdict(int)
    while True:
        request = elasticai_api_pb2.GetTaskRequest()
        request.worker_id = random.randint(1, 10)
        task = servicer.get_task(request, None)
        shard = task.shard
        # An empty shard name ends the loop: no more tasks are handed out.
        if not shard.name:
            break
        # The in-progress entry for this task must name the requesting worker.
        self.assertEqual(
            self.master.task_manager._doing[task.task_id][0], request.worker_id
        )
        key = (shard.name, shard.start, shard.end)
        run_counts[key] += 1

        result = elasticai_api_pb2.ReportTaskResultRequest()
        result.task_id = task.task_id
        first_run_at_shard_start = shard.start == 0 and run_counts[key] == 1
        if first_run_at_shard_start:
            # Simulate error reports.
            result.err_message = "Worker error"
        servicer.report_task_result(result, None)

    expected = {
        ("shard_1", 0, 3): 3,
        ("shard_1", 3, 6): 2,
        ("shard_1", 6, 9): 2,
        ("shard_1", 9, 10): 2,
        ("shard_2", 0, 3): 3,
        ("shard_2", 3, 6): 2,
        ("shard_2", 6, 9): 2,
    }
    self.assertDictEqual(expected, run_counts)