def test_fetch_summarize_result_by_id(self):
    """Fetch a stored SummarizeResult by id and compare it to what was inserted.

    Two BodyInfo objects and two SummarizeResult objects (one per body) are
    inserted through ``self.db_instance``; the first result is then fetched
    by the id returned from the insert call and must equal the first
    expected object.
    """
    # Insert the body records that the summarize results will refer to.
    now = datetime.now
    bodies = [
        BodyInfo(body=text, created_at=now())
        for text in ("てすとです", "てすとです2")
    ]
    ids = self.db_instance.insert_body_infos(body_infos=bodies)

    # Insert one summarize result per body; label text mirrors the prediction.
    expected_predicted_texts = ["予測です", "予測です2"]
    expected_summarizer_results = [
        SummarizeResult(
            body_id=ids[position],
            inference_status=InferenceStatus.complete.value,
            predicted_text=text,
            label_text=text,
        )
        for position, text in enumerate(expected_predicted_texts)
    ]
    result_ids = self.db_instance.insert_summarize_results(
        result_infos=expected_summarizer_results)

    # Fetch the first result back by id and verify it.
    first_result_id = result_ids[0]
    fetched = self.db_instance.fetch_summarize_result_by_id(
        result_id=first_result_id)
    self.assertEqual(fetched, expected_summarizer_results[0])
def test_summarize_loop_process(self):
    """Run ``loop_process`` once against test doubles and expect ``complete``.

    One body record, one job record and one queue message are prepared, then
    ``loop_process`` is called with a ``PredictionApiClientForTest`` and the
    fixture queue / DB / logger held on ``self``.
    """
    print("test_summarizer_loop_process")

    # Build the test data: one message plus its matching DB records.
    job_id = uuid.uuid4()
    body = "これはテストの本文データです。試しに要約してみてね。"
    queued_message = {"id": str(job_id), "body": body}
    stored_body = BodyInfo(body=body, created_at=datetime.now())
    stored_body.id = 1
    pending_job = SummarizeJobInfo(job_id=job_id, result_id=None)

    self.db_instance.insert_body_infos(body_infos=[stored_body])
    self.db_instance.insert_summarize_job_info(job_infos=[pending_job])
    self.queue._add_data(messages=[queued_message])

    # Create the summarizer API client used for testing.
    client_params = {
        "local_host": "",
        "local_port": "",
        "local_request_name": "predict",
        "gcp_project_id": "",
        "gcp_location": "",
        "gcp_endpoint": "",
    }
    test_client = PredictionApiClientForTest(params=client_params)

    # Execute a single pass of the loop.
    outcome = loop_process(
        queue_consumer=self.queue,
        db_instance=self.db_instance,
        api_client=test_client,
        logger=self.logger,
    )

    # Verify the pass completed.
    self.assertEqual(outcome, SummarizerProcessResult.complete)
def test_insert_body_infos(self):
    """Insert two BodyInfo objects and read each one back by its returned id.

    Checks that one id is returned per inserted object and that the fetched
    ``body`` and ``created_at`` values match what was inserted.
    """
    body_infos = [
        BodyInfo(body=text, created_at=datetime.now())
        for text in ("テストです1", "テストです2")
    ]

    # Run the insert.
    ids = self.db_instance.insert_body_infos(body_infos=body_infos)
    self.assertEqual(len(body_infos), len(ids))

    # Confirm each stored record matches the object that was inserted.
    # Lengths were asserted equal above, so pairing with zip covers every id.
    for inserted_id, inserted in zip(ids, body_infos):
        fetched = self.db_instance.fetch_body_info_by_id(
            body_info_id=inserted_id)
        self.assertIsNotNone(fetched)
        self.assertEqual(inserted.body, fetched.body)
        self.assertEqual(inserted.created_at, fetched.created_at)
def test_insert_body_infos(self):
    """Insert two BodyInfo objects and verify them via ``fetch_body_infos``.

    Each expected object is matched against the fetched list by ``id`` and
    compared both for equality and for its body text.
    """
    # Insert the BodyInfo fixtures.
    expected_body_texts = ["てすとです", "てすとです2"]
    expected_body_infos = [
        BodyInfo(body=text, created_at=datetime.now())
        for text in expected_body_texts
    ]
    self.db_instance.insert_body_infos(body_infos=expected_body_infos)

    # Fetch the stored BodyInfo objects and check the data is correct.
    # NOTE(review): matching on ``expected.id`` presumably relies on the
    # insert call assigning ids to the passed objects - confirm.
    fetched_infos = self.db_instance.fetch_body_infos()
    for expected, text in zip(expected_body_infos, expected_body_texts):
        matches = [
            candidate for candidate in fetched_infos
            if candidate.id == expected.id
        ]
        found = matches[0]
        self.assertEqual(expected, found)
        self.assertEqual(text, found.get_body())
def test_insert_summarizer_results(self):
    """Insert two SummarizeResult objects and verify them via a full fetch.

    Two BodyInfo objects are inserted first so each result can reference a
    body id. Every expected result is then located in the output of
    ``fetch_summarize_results`` by ``id`` and compared for equality,
    predicted text and label text.
    """
    # Insert the body records that the summarize results will refer to.
    now = datetime.now
    bodies = [
        BodyInfo(body=text, created_at=now())
        for text in ("てすとです", "てすとです2")
    ]
    ids = self.db_instance.insert_body_infos(body_infos=bodies)

    # Insert one summarize result per body; label text mirrors the prediction.
    expected_predicted_texts = ["予測です", "予測です2"]
    expected_summarizer_results = [
        SummarizeResult(
            body_id=ids[position],
            inference_status=InferenceStatus.complete.value,
            predicted_text=text,
            label_text=text,
        )
        for position, text in enumerate(expected_predicted_texts)
    ]
    self.db_instance.insert_summarize_results(
        result_infos=expected_summarizer_results)

    # Fetch the stored results and check the data is correct.
    # NOTE(review): matching on ``expected.id`` presumably relies on the
    # insert call assigning ids to the passed objects - confirm.
    fetched_results = self.db_instance.fetch_summarize_results()
    for expected, text in zip(expected_summarizer_results,
                              expected_predicted_texts):
        matches = [
            candidate for candidate in fetched_results
            if candidate.id == expected.id
        ]
        found = matches[0]
        self.assertEqual(expected, found)
        self.assertEqual(text, found.get_predicted_text())
        self.assertEqual(text, found.get_label_text())
def loop_process(
    queue_consumer: AbstractQueueConsumer,
    db_instance: AbstractDB,
    api_client: AbstractPredictionApiClient,
    logger: AbstractLogger,
):
    """Run one consume -> summarize -> persist pass.

    The pass consumes messages from the queue, stores each message body in
    the DB, sends all bodies to the summarizer API in a single call, stores
    the predicted texts, and finally records a job entry linking each
    message id to its stored result.

    Args:
        queue_consumer: Source of messages; each message is indexed with the
            keys ``"id"`` and ``"body"``.
        db_instance: Storage used for body infos, summarize results and job
            infos.
        api_client: Client whose ``post_summarize_body`` returns one
            predicted text per input body.
        logger: Receives info and error messages.

    Returns:
        A ``SummarizerProcessResult`` member:
        ``error_of_queue`` if consuming raised ``QueueError``;
        ``queue_is_empty`` if there were no messages;
        ``error_of_summarizer`` if the API call raised ``ApiClientError`` or
        returned a different number of results than bodies sent;
        ``complete`` otherwise.
    """
    # Fetch messages from the queue; stop here if the queue itself fails.
    try:
        messages = queue_consumer.consume()
    except QueueError as e:
        logger.error(e)
        return SummarizerProcessResult.error_of_queue

    # Nothing to summarize: report it and finish without touching DB or API.
    if not messages:
        logger.info(
            "summarize request is not found. summarize process is not called.")
        return SummarizerProcessResult.queue_is_empty

    # Build one BodyInfo per message, keyed by message id. Messages sharing
    # an id collapse to the last one because the dict is keyed by id.
    body_infos_with_id = {}
    for message in messages:
        message_id = message["id"]
        body = message["body"]
        logger.info("input text is {}".format(message))
        # TODO: it may be better to store the time the API request was made
        # (include it in the JSON message) instead of the current time.
        body_infos_with_id[message_id] = BodyInfo(
            body=body, created_at=datetime.now())

    # Register the body texts in the DB.
    body_infos = list(body_infos_with_id.values())
    db_instance.insert_body_infos(body_infos=body_infos)

    # Run inference on all bodies in one request.
    input_texts = [body_info.get_body() for body_info in body_infos]
    try:
        results = api_client.post_summarize_body(body_texts=input_texts)
    except ApiClientError as e:
        logger.error(e)
        return SummarizerProcessResult.error_of_summarizer

    # Results are paired with bodies by position, so the counts must match.
    if len(body_infos) != len(results):
        logger.error(
            "summarize result count(= {}) is not equal to request body count(= {}). summarize result is not saved into DB."
            .format(len(results), len(body_infos)))
        return SummarizerProcessResult.error_of_summarizer

    # Save the inference results in the DB.
    # NOTE(review): ``body_info.id`` is presumably assigned by
    # ``insert_body_infos`` above - confirm against the DB implementation.
    summarize_results = [
        SummarizeResult(
            body_id=body_info.id,
            inference_status=InferenceStatus.complete.value,
            predicted_text=predicted_text,
            label_text=None,
        )
        for body_info, predicted_text in zip(body_infos, results)
    ]
    db_instance.insert_summarize_results(result_infos=summarize_results)

    # Record the job entries linking each message id to its stored result.
    job_infos = [
        SummarizeJobInfo(job_id=message_id, result_id=result.id)
        for message_id, result in zip(body_infos_with_id, summarize_results)
    ]
    db_instance.insert_summarize_job_info(job_infos=job_infos)

    return SummarizerProcessResult.complete