예제 #1
0
 def test_rank_profile_inherits(self):
     """Check that an inheriting rank profile keeps its fields and round-trips.

     A profile named "bm25" that inherits from "default" is created, its name
     and first-phase expression are verified, and the profile is compared with
     the one rebuilt from its dictionary form.
     """
     expected_name = "bm25"
     expected_first_phase = "bm25(title) + bm25(body)"

     profile = RankProfile(
         name=expected_name,
         first_phase=expected_first_phase,
         inherits="default",
     )

     self.assertEqual(profile.name, expected_name)
     self.assertEqual(profile.first_phase, expected_first_phase)

     # to_dict is read as an attribute (no call), exactly as this test has
     # always done -- presumably a property on RankProfile; confirm there.
     rebuilt_profile = RankProfile.from_dict(profile.to_dict)
     self.assertEqual(profile, rebuilt_profile)
예제 #2
0
 def test_rank_profile_bert_second_phase(self):
     """Build a "bert" rank profile with a second phase and verify its fields.

     The profile is created with token constants, helper functions and summary
     features. The test checks the name, first-phase expression, constants and
     summary features, and that the profile equals the one rebuilt from its
     dictionary form.
     """
     first_phase_expression = "bm25(title) + bm25(body)"
     expected_constants = {
         "TOKEN_NONE": 0,
         "TOKEN_CLS": 101,
         "TOKEN_SEP": 102,
     }
     # NOTE(review): the second phase below references onnx(bert_tiny) while
     # this summary feature names onnx(bert) -- kept as-is; confirm intended.
     expected_summary_features = [
         "onnx(bert).logits",
         "input_ids",
         "attention_mask",
         "token_type_ids",
     ]

     reranking = SecondPhaseRanking(
         rerank_count=10,
         expression="sum(onnx(bert_tiny).logits{d0:0,d1:0})",
     )

     # Multi-line tensor expressions, assembled from adjacent string literals.
     input_ids_expression = (
         "tensor<float>(d0[1],d1[128])(\n"
         "    if (d1 == 0,\n"
         "        TOKEN_CLS,\n"
         "    if (d1 < question_length + 1,\n"
         "        query(query_token_ids){d0:(d1-1)},\n"
         "    if (d1 == question_length + 1,\n"
         "        TOKEN_SEP,\n"
         "    if (d1 < question_length + doc_length + 2,\n"
         "        attribute(doc_token_ids){d0:(d1-question_length-2)},\n"
         "    if (d1 == question_length + doc_length + 2,\n"
         "        TOKEN_SEP,\n"
         "        TOKEN_NONE\n"
         "    ))))))"
     )
     token_type_ids_expression = (
         "tensor<float>(d0[1],d1[128])(\n"
         "    if (d1 < question_length,\n"
         "        0,\n"
         "    if (d1 < question_length + doc_length,\n"
         "        1,\n"
         "        TOKEN_NONE\n"
         "    )))"
     )

     # (name, expression) pairs, in the order the functions are declared.
     function_specs = [
         ("question_length", "sum(map(query(query_token_ids), f(a)(a > 0)))"),
         ("doc_length", "sum(map(attribute(doc_token_ids), f(a)(a > 0)))"),
         ("input_ids", input_ids_expression),
         ("attention_mask", "map(input_ids, f(a)(a > 0))"),
         ("token_type_ids", token_type_ids_expression),
     ]
     helper_functions = [
         Function(name=fn_name, expression=fn_expression)
         for fn_name, fn_expression in function_specs
     ]

     # Copies are handed to the constructor so the assertions below compare
     # the stored values against independent expected objects.
     profile = RankProfile(
         name="bert",
         first_phase=first_phase_expression,
         second_phase=reranking,
         inherits="default",
         constants=dict(expected_constants),
         functions=helper_functions,
         summary_features=list(expected_summary_features),
     )

     self.assertEqual(profile.name, "bert")
     self.assertEqual(profile.first_phase, first_phase_expression)
     self.assertDictEqual(profile.constants, expected_constants)
     self.assertEqual(profile.summary_features, expected_summary_features)

     # to_dict is read as an attribute (no call), as in the original test --
     # presumably a property on RankProfile; confirm against its definition.
     rebuilt_profile = RankProfile.from_dict(profile.to_dict)
     self.assertEqual(profile, rebuilt_profile)