def test_rank_profile_inherits(self):
    """A rank profile built with ``inherits`` keeps that value and round-trips.

    Checks the stored ``name``, ``first_phase`` and ``inherits`` values, then
    checks that serialising with ``to_dict`` and rebuilding with
    ``RankProfile.from_dict`` gives an equal object.
    """
    rank_profile = RankProfile(
        name="bm25",
        first_phase="bm25(title) + bm25(body)",
        inherits="default",
    )
    self.assertEqual(rank_profile.name, "bm25")
    self.assertEqual(rank_profile.first_phase, "bm25(title) + bm25(body)")
    # The test is about inheritance, so pin the inherited profile name too.
    self.assertEqual(rank_profile.inherits, "default")
    # `to_dict` is used without a call here, presumably a property -- confirm.
    self.assertEqual(rank_profile, RankProfile.from_dict(rank_profile.to_dict))
def test_rank_profile_bert_second_phase(self):
    """A BERT-style rank profile with second-phase re-ranking round-trips.

    The profile combines a bm25 first phase, a second phase that re-ranks
    the top hits with an ONNX expression, constants, several helper
    functions and summary features. The test checks the stored values,
    including the second-phase settings, and checks that a
    ``to_dict``/``from_dict`` round trip gives an equal object.
    """
    rank_profile = RankProfile(
        name="bert",
        first_phase="bm25(title) + bm25(body)",
        second_phase=SecondPhaseRanking(
            rerank_count=10,
            expression="sum(onnx(bert_tiny).logits{d0:0,d1:0})",
        ),
        inherits="default",
        constants={"TOKEN_NONE": 0, "TOKEN_CLS": 101, "TOKEN_SEP": 102},
        functions=[
            Function(
                name="question_length",
                expression="sum(map(query(query_token_ids), f(a)(a > 0)))",
            ),
            Function(
                name="doc_length",
                expression="sum(map(attribute(doc_token_ids), f(a)(a > 0)))",
            ),
            Function(
                name="input_ids",
                # [CLS] question [SEP] document [SEP], padded with TOKEN_NONE.
                expression="tensor<float>(d0[1],d1[128])(\n"
                " if (d1 == 0,\n"
                " TOKEN_CLS,\n"
                " if (d1 < question_length + 1,\n"
                " query(query_token_ids){d0:(d1-1)},\n"
                " if (d1 == question_length + 1,\n"
                " TOKEN_SEP,\n"
                " if (d1 < question_length + doc_length + 2,\n"
                " attribute(doc_token_ids){d0:(d1-question_length-2)},\n"
                " if (d1 == question_length + doc_length + 2,\n"
                " TOKEN_SEP,\n"
                " TOKEN_NONE\n"
                " ))))))",
            ),
            Function(
                name="attention_mask",
                expression="map(input_ids, f(a)(a > 0))",
            ),
            Function(
                name="token_type_ids",
                expression="tensor<float>(d0[1],d1[128])(\n"
                " if (d1 < question_length,\n"
                " 0,\n"
                " if (d1 < question_length + doc_length,\n"
                " 1,\n"
                " TOKEN_NONE\n"
                " )))",
            ),
        ],
        # NOTE(review): "onnx(bert)" here vs "onnx(bert_tiny)" in the second
        # phase -- test data only, kept as-is; confirm which model is meant.
        summary_features=[
            "onnx(bert).logits",
            "input_ids",
            "attention_mask",
            "token_type_ids",
        ],
    )
    self.assertEqual(rank_profile.name, "bert")
    self.assertEqual(rank_profile.first_phase, "bm25(title) + bm25(body)")
    self.assertEqual(rank_profile.inherits, "default")
    # The test is named for the second phase, so pin its configuration.
    self.assertEqual(rank_profile.second_phase.rerank_count, 10)
    self.assertEqual(
        rank_profile.second_phase.expression,
        "sum(onnx(bert_tiny).logits{d0:0,d1:0})",
    )
    self.assertDictEqual(
        rank_profile.constants,
        {"TOKEN_NONE": 0, "TOKEN_CLS": 101, "TOKEN_SEP": 102},
    )
    self.assertEqual(
        [function.name for function in rank_profile.functions],
        [
            "question_length",
            "doc_length",
            "input_ids",
            "attention_mask",
            "token_type_ids",
        ],
    )
    self.assertEqual(
        rank_profile.summary_features,
        ["onnx(bert).logits", "input_ids", "attention_mask", "token_type_ids"],
    )
    # `to_dict` is used without a call here, presumably a property -- confirm.
    self.assertEqual(rank_profile, RankProfile.from_dict(rank_profile.to_dict))