Example #1
0
    def test_rank_candidates(self):
        """Check ``rank_candidates`` on the stored fixture and on hand-made probabilities.

        First ranks the candidates from the CSV fixture using the stored probabilities,
        then ranks ``self.custom_candidates`` with an explicit probability vector, both
        with ``return_all`` left at its default and with ``return_all=False``.
        """
        candidates = pandas.read_csv(join(TEST_DATA_PATH, "test_data_candidates.csv.xz"),
                                     index_col=0).infer_objects()
        # The ``.pkl`` fixture is presumably pickled (per its extension). Since NumPy
        # 1.16.3, numpy.load raises ValueError on pickled data unless allow_pickle=True;
        # the flag is harmless if the file turns out to be plain .npy data.
        proba = numpy.load(join(TEST_DATA_PATH, "test_data_candidates_proba.pkl"),
                           allow_pickle=True)
        self.assertDictEqual(rank_candidates(candidates, proba, n_candidates=3), self.suggestions)

        proba = numpy.array([1.0, 0.9, 0.05, 0.01, 0.3, 0.98], dtype=float)
        self.assertDictEqual(rank_candidates(self.custom_candidates, proba, n_candidates=3),
                             self.custom_suggestions)
        self.assertDictEqual(rank_candidates(self.custom_candidates, proba, n_candidates=2,
                                             return_all=False),
                             self.custom_filtered_suggestions)
Example #2
0
    def test_rank_candidates(self):
        """Check ``rank_candidates`` on the stored fixture and on hand-made scores.

        The fixture CSV is parsed with ``keep_default_na=False`` so pandas does
        not turn strings such as "nan" or "null" into missing values.
        """
        csv_path = join(TEST_DATA_PATH, "test_data_candidates.csv.xz")
        proba_path = join(TEST_DATA_PATH,
                          "test_data_candidates_proba.pickle")

        candidates = pandas.read_csv(csv_path, index_col=0,
                                     keep_default_na=False)
        with open(proba_path, "rb") as proba_file:
            fixture_proba = pickle.load(proba_file)
        ranked = rank_candidates(candidates, fixture_proba, n_candidates=3)
        self.assertDictEqual(ranked, self.suggestions)

        custom_proba = numpy.array([1.0, 0.9, 0.05, 0.01, 0.3, 0.98],
                                   dtype=float)
        top_three = rank_candidates(self.custom_candidates, custom_proba,
                                    n_candidates=3)
        self.assertDictEqual(top_three, self.custom_suggestions)

        # With return_all=False only a filtered subset is expected back.
        filtered = rank_candidates(self.custom_candidates, custom_proba,
                                   n_candidates=2, return_all=False)
        self.assertDictEqual(filtered, self.custom_filtered_suggestions)
Example #3
0
    def rank(self, candidates: pandas.DataFrame, features: numpy.ndarray, n_candidates: int = 3,
             return_all: bool = True) -> Dict[int, List[Tuple[str, float]]]:
        """
        Score every correction candidate with the trained booster and rank them.

        :param candidates: DataFrame describing the correction candidates; presumably \
                           aligned row-by-row with `features` (confirm with callers).
        :param features: Feature matrix fed to the booster, one row per scored candidate.
        :param n_candidates: How many of the most probably correct candidates to keep \
                             for each typo.
        :param return_all: When False, only typos whose first candidate is a correction \
                           are returned.
        :return: Dictionary `{id : [(candidate, correctness_proba), ...]}`; each list is \
                 ordered by correctness probability, highest first.
        """
        # NOTE(review): `ntree_limit` / `best_ntree_limit` are deprecated since XGBoost 1.4
        # and removed in 2.0 (replaced by `iteration_range`) — confirm the pinned version.
        booster = self.bst
        correctness_proba = booster.predict(xgb.DMatrix(features),
                                            ntree_limit=booster.best_ntree_limit)
        return rank_candidates(candidates, correctness_proba,
                               n_candidates=n_candidates, return_all=return_all)