Пример #1
0
    def test_update_job_infos(self):
        """Inserted job infos can later be updated with a result_id.

        Inserts three job infos that have no result, assigns each one a fresh
        result_id, runs the update, then checks that the update returns the
        same ids as the insert and that every stored record now holds the
        assigned result_id.
        """
        # Three job infos with no result attached yet.
        job_infos = [
            SummarizeJobInfo(job_id=str(uuid.uuid4()), result_id=None)
            for _ in range(3)
        ]

        # Insert them.
        inserted_ids = self.db_instance.insert_summarize_job_info(
            job_infos=job_infos)

        # Give every inserted job info a new result_id to write back.
        for info in job_infos:
            info.result_id = str(uuid.uuid4())

        # Run the update.
        updated_ids = self.db_instance.update_summarize_job_info(
            job_infos=job_infos)

        # Verify the stored values.
        self.assertEqual(len(inserted_ids), len(updated_ids))
        for idx, job_id in enumerate(updated_ids):
            self.assertEqual(inserted_ids[idx], job_id)
            stored = self.db_instance.fetch_summarize_job_info(job_id=job_id)
            self.assertIsNotNone(stored)
            self.assertIsNotNone(stored.result_id)
            self.assertEqual(job_infos[idx].result_id, stored.result_id)
Пример #2
0
    def test_fetch_summarize_job_info(self):
        """fetch_summarize_job_info returns the record for the requested job_id.

        Two job infos are inserted. The first one is fetched by its job_id
        and compared with the object that was inserted.
        """
        # Insert two job infos; the first is the one expected back.
        target = SummarizeJobInfo(job_id=uuid.uuid4(), result_id=uuid.uuid4())
        other = SummarizeJobInfo(job_id=uuid.uuid4(), result_id=uuid.uuid4())
        self.db_instance.insert_summarize_job_info(job_infos=[target, other])

        # Fetch by the target's job_id and compare with what was inserted.
        fetched = self.db_instance.fetch_summarize_job_info(
            job_id=target.job_id)
        self.assertEqual(fetched, target)
Пример #3
0
    def test_summarize_loop_process(self):
        """One pass of loop_process over a single queued message completes.

        Seeds the DB with a body info and a job info that has no result yet,
        puts a matching message on the test queue, runs loop_process with the
        test prediction client, and expects SummarizerProcessResult.complete.
        """
        print("test_summarizer_loop_process")

        # Fixtures: one job id, its body text and the matching queue message.
        job_id = uuid.uuid4()
        body = "これはテストの本文データです。試しに要約してみてね。"
        queued_message = {"id": str(job_id), "body": body}
        stored_body = BodyInfo(body=body, created_at=datetime.now())
        stored_body.id = 1
        pending_job = SummarizeJobInfo(job_id=job_id, result_id=None)

        self.db_instance.insert_body_infos(body_infos=[stored_body])
        self.db_instance.insert_summarize_job_info(job_infos=[pending_job])
        self.queue._add_data(messages=[queued_message])

        # Test prediction client; host/port and GCP settings are left blank.
        client_params = {
            "local_host": "",
            "local_port": "",
            "local_request_name": "predict",
            "gcp_project_id": "",
            "gcp_location": "",
            "gcp_endpoint": "",
        }
        api_client = PredictionApiClientForTest(params=client_params)

        # Run one loop iteration and check the outcome.
        outcome = loop_process(
            api_client=api_client,
            queue_consumer=self.queue,
            db_instance=self.db_instance,
            logger=self.logger,
        )
        self.assertEqual(outcome, SummarizerProcessResult.complete)
Пример #4
0
    def test_insert_summarize_job_infos(self):
        """Inserting job infos returns one id per record and stores them intact."""
        job_infos = [
            SummarizeJobInfo(job_id=uuid.uuid4(), result_id=uuid.uuid4())
            for _ in range(2)
        ]

        # Insert, and check that an id came back for every job info.
        returned_ids = self.db_instance.insert_summarize_job_info(
            job_infos=job_infos)
        self.assertEqual(len(job_infos), len(returned_ids))

        # Each stored record must match the job info it was created from.
        # zip is safe here because the lengths were asserted equal above.
        for expected, job_id in zip(job_infos, returned_ids):
            stored = self.db_instance.fetch_summarize_job_info(job_id=job_id)
            self.assertIsNotNone(stored)
            self.assertEqual(expected.job_id, stored.job_id)
            self.assertEqual(expected.result_id, stored.result_id)
Пример #5
0
    def test_set_correct_summarize_result_complete(self):
        """Posting a corrected text for a completed job stores it as the label.

        Builds a test DB holding one completed SummarizeResult with its job
        info and job log, and posts a corrected summary to
        ``/set_correct_summarize_result/``. It then checks that the response
        reports ``complete_job`` and that the stored result's ``label_text``
        equals the corrected text.
        """
        # Build the test DB with one completed result, its job info and log.
        job_id = uuid.uuid4()
        result_info = SummarizeResult(
            body_id=1,
            inference_status=InferenceStatus.complete.value,
            predicted_text="てすとです",
            label_text=None,
        )
        job_info = SummarizeJobInfo(job_id=job_id, result_id=result_info.id)
        job_log = SummarizeJobLog(job_id=job_id)
        db_instance = DBForTest(
            config=self.db_config,
            log_instance=self.logger,
            dummy_result_infos=[result_info],
            dummy_job_infos=[job_info],
            dummy_job_logs=[job_log],
        )

        # Build the test queue producer.
        queue_producer = QueueProducerForTest(config=self.queue_config,
                                              logger=self.logger)

        # Build the test API client.
        client = self.__create_client(
            queue_producer=queue_producer,
            db_instance=db_instance,
            logger=self.logger,
        )

        # Post the corrected text and check the response.
        corrected_text = "これは正しい要約です。"
        response = client.post(
            "/set_correct_summarize_result/",
            json={
                "job_id": str(job_id),
                "corrected_text": corrected_text
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertEqual(
            result["status_code"],
            ResponseSetCorrectedResult.complete_job.get_id(),
        )
        self.assertEqual(
            result["status_detail"],
            ResponseSetCorrectedResult.complete_job.get_detail(),
        )

        # Check that the corrected text was saved to the DB.
        actual_result_info = db_instance.fetch_summarize_result_by_id(
            result_id=result_info.id)
        # Fail with an assertion, not an AttributeError, if nothing was found.
        self.assertIsNotNone(actual_result_info)
        self.assertEqual(actual_result_info.label_text, corrected_text)
Пример #6
0
    def test_get_summarize_result_complete(self):
        """GET /summarize_result/ for a completed job returns its prediction.

        The test DB holds one completed SummarizeResult with its job info and
        job log. The response must echo the job_id, report ``complete_job``
        and carry the stored predicted text.
        """
        # Test DB containing a single completed result plus its job info/log.
        job_id = uuid.uuid4()
        result_info = SummarizeResult(
            body_id=1,
            inference_status=InferenceStatus.complete.value,
            predicted_text="てすとです",
            label_text=None,
        )
        db_instance = DBForTest(
            config=self.db_config,
            log_instance=self.logger,
            dummy_result_infos=[result_info],
            dummy_job_infos=[
                SummarizeJobInfo(job_id=job_id, result_id=result_info.id)
            ],
            dummy_job_logs=[SummarizeJobLog(job_id=job_id)],
        )

        # API client wired to the test DB and a test queue producer.
        client = self.__create_client(
            queue_producer=QueueProducerForTest(config=self.queue_config,
                                                logger=self.logger),
            db_instance=db_instance,
            logger=self.logger,
        )

        # Request the result and check the response body.
        response = client.get("/summarize_result/?job_id={}".format(job_id))
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        expected_status = ResponseInferenceStatus.complete_job
        self.assertEqual(payload["job_id"], str(job_id))
        self.assertEqual(payload["status_code"], expected_status.get_id())
        self.assertEqual(payload["status_detail"],
                         expected_status.get_detail())
        self.assertEqual(payload["predicted_text"], result_info.predicted_text)
Пример #7
0
def loop_process(
    queue_consumer: AbstractQueueConsumer,
    db_instance: AbstractDB,
    api_client: AbstractPredictionApiClient,
    logger: AbstractLogger,
):
    """Run one iteration of the summarization worker.

    Consumes pending messages from the queue and stores their bodies in the
    DB. It then requests summaries from the prediction API, stores them as
    SummarizeResult rows, and writes a SummarizeJobInfo linking each job id
    to its result.

    Args:
        queue_consumer: Source of summarize requests. Each message is read as
            a mapping with an ``"id"`` (used as the job id) and a ``"body"``
            (the text to summarize).
        db_instance: Persistence layer for bodies, results and job infos.
        api_client: Client whose ``post_summarize_body`` returns one
            predicted text per input body.
        logger: Logger for progress and error reporting.

    Returns:
        SummarizerProcessResult:
            - ``error_of_queue`` if consuming from the queue raised QueueError.
            - ``queue_is_empty`` if there was nothing to process.
            - ``error_of_summarizer`` if the API raised ApiClientError or
              returned a different number of results than bodies sent.
            - ``complete`` otherwise.
    """
    # Fetch messages from the queue; a queue failure ends this iteration.
    try:
        messages = queue_consumer.consume()
    except QueueError as e:
        logger.error(e)
        return SummarizerProcessResult.error_of_queue

    if not messages:
        logger.info(
            "summarize request is not found. summarize process is not called.")
        return SummarizerProcessResult.queue_is_empty

    # Build one BodyInfo per job id. Keyed by job id, so if the same id
    # appears more than once only its last message is kept.
    body_infos_with_id = {}
    for message in messages:
        job_id = message["id"]
        body = message["body"]
        logger.info("input text is {}".format(message))
        # TODO: use the time the API request was made instead (include it in
        # the message JSON).
        body_infos_with_id[job_id] = BodyInfo(body=body,
                                              created_at=datetime.now())
    # NOTE(review): body_info.id is read below as the result's body_id; this
    # presumably relies on insert_body_infos populating it — confirm.
    db_instance.insert_body_infos(body_infos=list(body_infos_with_id.values()))

    # Request the summaries; only the API call itself is guarded.
    input_texts = [info.get_body() for info in body_infos_with_id.values()]
    try:
        results = api_client.post_summarize_body(body_texts=input_texts)
    except ApiClientError as e:
        logger.error(e)
        return SummarizerProcessResult.error_of_summarizer

    # Results are matched to bodies by position, so the counts must agree.
    if len(results) != len(body_infos_with_id):
        logger.error(
            "summarize result count(= {}) is not equal to request body count(= {}). summarize result is not saved into DB."
            .format(len(results), len(body_infos_with_id)))
        return SummarizerProcessResult.error_of_summarizer

    # Store the summaries.
    summarize_results = [
        SummarizeResult(
            body_id=body_info.id,
            inference_status=InferenceStatus.complete.value,
            predicted_text=predicted_text,
            label_text=None,
        )
        for body_info, predicted_text in zip(body_infos_with_id.values(),
                                             results)
    ]
    db_instance.insert_summarize_results(result_infos=summarize_results)

    # Link each job id to its stored result.
    # NOTE(review): this step is meant to update the job status, but it
    # inserts a new job info. If the job info already exists when the request
    # is enqueued, update_summarize_job_info may be the intended call —
    # confirm against the producer side.
    job_infos = [
        SummarizeJobInfo(job_id=message_id, result_id=result.id)
        for message_id, result in zip(body_infos_with_id.keys(),
                                      summarize_results)
    ]
    db_instance.insert_summarize_job_info(job_infos=job_infos)

    return SummarizerProcessResult.complete