def test_create_tasks_with_zero_start_ind(self):
    """Exercise task creation, assignment, reporting and recovery.

    Both shards start at record 0. The test hands out every task, checks
    the IDs and task descriptions, reports successes and a failure,
    recovers the last in-flight task from its worker, and finally drains
    the re-queued tasks until the manager reports it is finished.
    """
    task_d = create_task_manager([("f1", 0, 10), ("f2", 0, 10)], [])
    # The expected split below implies 3 records per task, with a short
    # final task covering the remainder of each 10-record shard.
    chunk_bounds = ((0, 3), (3, 6), (6, 9), (9, 10))
    expected_tasks = [
        (shard, start, end, elasticai_api_pb2.TRAINING, -1)
        for shard in ("f1", "f2")
        for start, end in chunk_bounds
    ]

    # Hand out all 8 tasks: workers 0..3 each receive two in a row.
    fetched = [task_d.get(worker) for worker in range(4) for _ in range(2)]
    # Task IDs are expected to be 1 through 8.
    self.assertEqual(
        list(range(1, 9)), [task_id for task_id, _ in fetched]
    )
    self.assertEqual(
        sorted([task._info() for _, task in fetched]), expected_tasks
    )
    # Nothing is left to hand out, so get() yields (-1, None).
    self.assertEqual((-1, None), task_d.get(10))

    request = elasticai_api_pb2.ReportTaskResultRequest()
    # Six successful reports leave two tasks in progress.
    for task_id in (1, 3, 5, 7, 2, 8):
        request.task_id = task_id
        task_d.report(request, True)
    self.assertEqual(2, len(task_d._doing))

    # Fail one of the in-progress tasks; it leaves _doing.
    request.task_id = next(iter(task_d._doing))
    task_d.report(request, False)
    self.assertEqual(1, len(task_d._doing))

    # The _doing entry's first element is the owning worker's ID; recover
    # that worker's tasks as if it had died.
    remaining_worker = next(iter(task_d._doing.values()))[0]
    task_d.recover_tasks(remaining_worker)
    self.assertEqual(0, len(task_d._doing))
    self.assertEqual(2, len(task_d._todo))

    # Fetch both re-queued tasks first, then report them as successful.
    retried_ids = [task_d.get(worker)[0] for worker in (11, 12)]
    for task_id in retried_ids:
        request.task_id = task_id
        task_d.report(request, True)
    self.assertTrue(task_d.finished())
def recover_tasks(self, worker_id):
    """Recover the in-progress tasks of a dead worker, if needed.

    Does nothing unless ``self.support_fault_tolerance`` is truthy.
    Otherwise every entry of ``self._doing`` whose first element (the
    assigned worker ID) equals ``worker_id`` is reported as failed through
    ``self.report(request, False)``.

    Args:
        worker_id: ID of the worker whose tasks should be recovered.
    """
    if not self.support_fault_tolerance:
        return

    # Lazy %-style args: the message is formatted only if INFO is enabled,
    # and a bad argument type cannot raise from this call.
    logger.info("Recover the tasks assigned to worker %d", worker_id)
    # Snapshot the matching task IDs while holding the lock. self.report is
    # called after releasing it.
    # NOTE(review): presumably report() acquires self._lock itself; confirm.
    with self._lock:
        task_ids = [
            task_id
            for task_id, (assigned_worker, _, _) in self._doing.items()
            if assigned_worker == worker_id
        ]
    request = elasticai_api_pb2.ReportTaskResultRequest()
    for task_id in task_ids:
        request.task_id = task_id
        self.report(request, False)
def test_report_task_result(self):
    """Run every task to completion through MasterServicer and count runs.

    Tasks are requested by random worker IDs until the servicer returns a
    task with an empty shard name. The first run of each task whose shard
    starts at record 0 reports an error; every other run succeeds. The
    final per-task run counts are compared against a fixed expectation.
    """
    # NOTE(review): the trailing 2 presumably is the number of epochs. That
    # would explain why each task runs twice, and three times when its
    # first run was reported as an error.
    self.master.task_manager = create_task_manager(
        [("shard_1", 0, 10), ("shard_2", 0, 9)], [], 2
    )
    task_manager = self.master.task_manager
    servicer = MasterServicer(
        task_manager,
        self.master.instance_manager,
        None,
        None,
    )

    # task to number of runs.
    runs_per_task = defaultdict(int)
    while True:
        get_req = elasticai_api_pb2.GetTaskRequest()
        get_req.worker_id = random.randint(1, 10)
        task = servicer.get_task(get_req, None)
        shard = task.shard
        if not shard.name:
            # The loop treats an empty shard name as "no task left".
            break
        # The task must be recorded as owned by the requesting worker.
        owner = task_manager._doing[task.task_id][0]
        self.assertEqual(owner, get_req.worker_id)

        key = (shard.name, shard.start, shard.end)
        runs_per_task[key] += 1
        fail_this_run = shard.start == 0 and runs_per_task[key] == 1

        result = elasticai_api_pb2.ReportTaskResultRequest()
        result.task_id = task.task_id
        if fail_this_run:
            # Simulate error reports.
            result.err_message = "Worker error"
        servicer.report_task_result(result, None)

    expected_runs = {
        ("shard_1", 0, 3): 3,
        ("shard_1", 3, 6): 2,
        ("shard_1", 6, 9): 2,
        ("shard_1", 9, 10): 2,
        ("shard_2", 0, 3): 3,
        ("shard_2", 3, 6): 2,
        ("shard_2", 6, 9): 2,
    }
    self.assertDictEqual(expected_runs, runs_per_task)
def report_task_result(self, task_id, err_msg, exec_counters=None):
    """Send the outcome of a task to the master.

    Args:
        task_id: int, the task ID assigned by the master.
        err_msg: string, the error message from training. It is copied into
            the request's ``err_message`` field unchanged.
        exec_counters: dict of statistics for the executed task. The
            counters are merged into the request only when this is a
            ``dict``; any other value (including the default ``None``) is
            ignored.

    Returns:
        Whatever ``self._stub.report_task_result`` returns for the request.
    """
    result_request = elasticai_api_pb2.ReportTaskResultRequest()
    result_request.task_id = task_id
    result_request.err_message = err_msg
    has_counters = isinstance(exec_counters, dict)
    if has_counters:
        result_request.exec_counters.update(exec_counters)
    return self._stub.report_task_result(result_request)