示例#1
0
    def testReportTaskResult(self):
        """Check how often each task is handed out when some runs fail.

        Two shards are split into tasks of ``records_per_task=3`` records and
        dispatched for ``num_epochs=2`` epochs. The first run of every task
        whose start is 0 is reported with an error message. Such tasks are
        expected to be handed out three times; every other task twice.
        """
        task_d = _TaskDispatcher(
            {
                "shard_1": (0, 10),
                "shard_2": (0, 9),
            },
            {},
            {},
            records_per_task=3,
            num_epochs=2,
        )
        master = MasterServicer(
            3,
            task_d,
            evaluation_service=None,
        )

        # Safety cap on get_task calls so a dispatcher that never drains
        # fails the test instead of hanging it. The expected total number of
        # task runs is 16, so a passing run never reaches this cap.
        max_requests = 100
        # (shard_name, start, end) -> number of times the task was handed out.
        tasks = defaultdict(int)
        for _ in range(max_requests):
            req = elasticdl_pb2.GetTaskRequest()
            req.worker_id = random.randint(1, 10)
            task = master.get_task(req, None)
            # The test treats an empty shard name as "no more tasks".
            if not task.shard_name:
                break
            # _doing[task_id][0] must hold the id of the requesting worker.
            self.assertEqual(task_d._doing[task.task_id][0], req.worker_id)
            task_key = (task.shard_name, task.start, task.end)
            tasks[task_key] += 1
            report = elasticdl_pb2.ReportTaskResultRequest()
            report.task_id = task.task_id
            if task.start == 0 and tasks[task_key] == 1:
                # Simulate error reports.
                report.err_message = "Worker error"
            master.report_task_result(report, None)
        else:
            self.fail(
                "Task dispatcher did not drain after %d requests"
                % max_requests
            )

        self.assertDictEqual(
            {
                ("shard_1", 0, 3): 3,
                ("shard_1", 3, 6): 2,
                ("shard_1", 6, 9): 2,
                ("shard_1", 9, 10): 2,
                ("shard_2", 0, 3): 3,
                ("shard_2", 3, 6): 2,
                ("shard_2", 6, 9): 2,
            },
            tasks,
        )
示例#2
0
    def test_report_task_result(self):
        """Check how often each task is handed out when some runs fail.

        The task manager is built from two shards for 2 epochs. The first run
        of every task whose shard starts at 0 is reported with an error
        message. Such tasks are expected to be handed out three times; every
        other task twice.
        """
        self.master.task_manager = create_task_manager(
            [("shard_1", 0, 10), ("shard_2", 0, 9)],
            [],
            2,
        )
        master = MasterServicer(
            self.master.task_manager,
            self.master.instance_manager,
            None,
            None,
        )

        # Safety cap on get_task calls so a task manager that never drains
        # fails the test instead of hanging it. The expected total number of
        # task runs is 16, so a passing run never reaches this cap.
        max_requests = 100
        # (shard name, start, end) -> number of times the task was handed out.
        tasks = defaultdict(int)
        for _ in range(max_requests):
            req = elasticai_api_pb2.GetTaskRequest()
            req.worker_id = random.randint(1, 10)
            task = master.get_task(req, None)
            # The test treats an empty shard name as "no more tasks".
            if not task.shard.name:
                break
            # _doing[task_id][0] must hold the id of the requesting worker.
            self.assertEqual(
                self.master.task_manager._doing[task.task_id][0],
                req.worker_id,
            )
            task_key = (task.shard.name, task.shard.start, task.shard.end)
            tasks[task_key] += 1
            report = elasticai_api_pb2.ReportTaskResultRequest()
            report.task_id = task.task_id
            if task.shard.start == 0 and tasks[task_key] == 1:
                # Simulate error reports.
                report.err_message = "Worker error"
            master.report_task_result(report, None)
        else:
            self.fail(
                "Task manager did not drain after %d requests" % max_requests
            )

        self.assertDictEqual(
            {
                ("shard_1", 0, 3): 3,
                ("shard_1", 3, 6): 2,
                ("shard_1", 6, 9): 2,
                ("shard_1", 9, 10): 2,
                ("shard_2", 0, 3): 3,
                ("shard_2", 3, 6): 2,
                ("shard_2", 6, 9): 2,
            },
            tasks,
        )