def test_rank_candidates(self):
    """
    Check `rank_candidates` against the expected suggestion dictionaries.

    Three cases are covered: the candidates/probabilities fixtures stored under \
    `TEST_DATA_PATH`, the custom candidates with the default `return_all`, and the custom \
    candidates with `return_all=False`.
    """
    candidates = pandas.read_csv(
        join(TEST_DATA_PATH, "test_data_candidates.csv.xz"), index_col=0,
    ).infer_objects()
    # The ".pkl" fixture is presumably a pickled array (TODO confirm). NumPy >= 1.16.3
    # defaults to allow_pickle=False and refuses to load pickled data, so opt in
    # explicitly. The flag is harmless for a plain .npy file, and the fixture is a
    # trusted file shipped with the tests.
    proba = numpy.load(
        join(TEST_DATA_PATH, "test_data_candidates_proba.pkl"), allow_pickle=True)
    self.assertDictEqual(
        rank_candidates(candidates, proba, n_candidates=3), self.suggestions)

    # One hand-written probability per row of the custom candidates.
    proba = numpy.array([1.0, 0.9, 0.05, 0.01, 0.3, 0.98], dtype=float)
    self.assertDictEqual(
        rank_candidates(self.custom_candidates, proba, n_candidates=3),
        self.custom_suggestions)
    self.assertDictEqual(
        rank_candidates(self.custom_candidates, proba, n_candidates=2, return_all=False),
        self.custom_filtered_suggestions)
def test_rank_candidates(self):
    """
    Verify that `rank_candidates` returns the expected suggestion dictionaries.

    The first check uses the candidates and probabilities fixtures stored under \
    `TEST_DATA_PATH`; the other two use the custom candidates with hand-written \
    probabilities, once with the default `return_all` and once with `return_all=False`.
    """
    candidates_path = join(TEST_DATA_PATH, "test_data_candidates.csv.xz")
    proba_path = join(TEST_DATA_PATH, "test_data_candidates_proba.pickle")

    # keep_default_na=False stops pandas from parsing its default NA markers as NaN.
    candidates = pandas.read_csv(candidates_path, index_col=0, keep_default_na=False)
    with open(proba_path, "rb") as fin:
        fixture_proba = pickle.load(fin)

    fixture_ranking = rank_candidates(candidates, fixture_proba, n_candidates=3)
    self.assertDictEqual(fixture_ranking, self.suggestions)

    custom_proba = numpy.array([1.0, 0.9, 0.05, 0.01, 0.3, 0.98], dtype=float)

    full_ranking = rank_candidates(self.custom_candidates, custom_proba, n_candidates=3)
    self.assertDictEqual(full_ranking, self.custom_suggestions)

    filtered_ranking = rank_candidates(
        self.custom_candidates, custom_proba, n_candidates=2, return_all=False)
    self.assertDictEqual(filtered_ranking, self.custom_filtered_suggestions)
def rank(self, candidates: pandas.DataFrame, features: numpy.ndarray, n_candidates: int = 3,
         return_all: bool = True) -> Dict[int, List[Tuple[str, float]]]:
    """
    Score every correction candidate with the booster and rank them per typo.

    :param candidates: DataFrame containing information about candidates for correction.
    :param features: Matrix of features for candidates.
    :param n_candidates: Number of most probably correct candidates to return for each typo.
    :param return_all: False to return corrections only for typos corrected in the \
                       first candidate.
    :return: Dictionary `{id : [(candidate, correctness_proba), ...]}`, candidates are sorted \
             by correctness probability in a descending order.
    """
    booster = self.bst
    # NOTE(review): `ntree_limit` / `best_ntree_limit` are deprecated in recent XGBoost
    # releases in favour of `iteration_range` - confirm against the pinned version.
    correctness_proba = booster.predict(
        xgb.DMatrix(features), ntree_limit=booster.best_ntree_limit)
    return rank_candidates(candidates, correctness_proba, n_candidates, return_all)