def pred_test_based_on_valid_prod(self, h_validate, bsk_label_valid,
                                      r_validate, h_test, bsk_label_test,
                                      r_test, ks):
        """
        Choose ``k`` from ``ks`` on the validation set, then score the test set.

        For every candidate ``k`` the model is refitted on the stored training
        data (``self.h_train``, ``self.bsk_label_train``, ``self.r_train``
        with ``step=self.step``). The flattened ``'obs_prod'`` and
        ``'pred_prob_prod'`` entries returned by ``self.predict`` feed a
        precision-recall curve and a ROC curve. The candidate selected by
        ``argmax(best_f, best_auc)`` is refitted and the threshold that was
        optimal for it on validation is applied to the test curve.

        Side effect: ``self.k`` is left at the selected value and the model
        remains fitted with it.

        :param h_validate: validation input forwarded to ``self.predict``
        :param bsk_label_valid: not read by this method (observations come
            from the prediction result)
        :param r_validate: second validation input forwarded to
            ``self.predict``
        :param h_test: test input forwarded to ``self.predict``
        :param bsk_label_test: not read by this method
        :param r_test: second test input forwarded to ``self.predict``
        :param ks: indexable sequence of candidate values for ``self.k``
        :return: tuple ``(prec, rec, f, auc, fpr, tpr, thr, k)``: test
            precision, recall and ``f_point_5`` score at the selected
            operating point, test ROC AUC, the test ROC ``fpr``/``tpr``
            arrays, the test threshold actually used, and the selected ``k``
        """

        best_f = []
        best_auc = []
        thr_opt = []
        for k in ks:
            self.k = k
            self.fit(self.h_train,
                     self.bsk_label_train,
                     self.r_train,
                     step=self.step)
            pred_val = self.predict(h_validate, r_validate)

            # predict() returns nested sequences; flatten them to 1-D arrays.
            obs_prod = np.array(
                [item for row in pred_val['obs_prod'] for item in row])
            pred_prod = np.array(
                [item for row in pred_val['pred_prob_prod'] for item in row])
            prec, rec, thr = metrics.precision_recall_curve(obs_prod,
                                                            pred_prod,
                                                            pos_label=True)

            f = f_point_5(prec, rec)
            # A NaN score would be returned by np.max and picked by
            # np.argmax, poisoning the comparison between candidates;
            # count it as 0 instead.
            f[np.isnan(f)] = 0
            thr_opt.append(thr[np.argmax(f)])
            best_f.append(np.max(f))
            fpr, tpr, _ = metrics.roc_curve(obs_prod,
                                            pred_prod,
                                            pos_label=True)
            best_auc.append(metrics.auc(fpr, tpr))

        # NOTE(review): `argmax` is defined elsewhere; presumably it ranks by
        # F-score with AUC as tie-breaker - confirm against its definition.
        idx_k = argmax(best_f, best_auc)[0]
        self.k = ks[idx_k]
        self.fit(self.h_train,
                 self.bsk_label_train,
                 self.r_train,
                 step=self.step)
        pred_test = self.predict(h_test, r_test)

        obs_prod = np.array(
            [item for row in pred_test['obs_prod'] for item in row])
        pred_prod = np.array(
            [item for row in pred_test['pred_prob_prod'] for item in row])
        prec, rec, thr = metrics.precision_recall_curve(obs_prod,
                                                        pred_prod,
                                                        pos_label=True)
        f = f_point_5(prec, rec)

        # Use the last test threshold that does not exceed the validation
        # optimum. When every test threshold is larger the selection is
        # empty, so fall back to the first point of the test curve rather
        # than failing with IndexError.
        not_above = np.where(thr <= thr_opt[idx_k])[0]
        idx_f = not_above[-1] if not_above.size else 0

        # pos_label is passed explicitly to match the validation ROC above.
        fpr, tpr, _ = metrics.roc_curve(obs_prod, pred_prod, pos_label=True)
        auc = metrics.auc(fpr, tpr)

        return prec[idx_f], rec[idx_f], f[idx_f], auc, fpr, tpr, thr[
            idx_f], ks[idx_k]
    def pred_test_based_on_valid(self, h_validate, bsk_label_valid,
                                 order_no_validate, r_validate, h_test,
                                 bsk_label_test, order_no_test, r_test, bs,
                                 phis):
        """
        Grid-search ``(b, phi)`` on the validation set, then score the test set.

        Every pair from ``bs`` x ``phis`` is passed to ``self.predict``
        together with the validation data; no refit happens in this method.
        The ``'obs'`` and ``'pred_prob'`` entries of the prediction result
        feed a ROC curve and a precision-recall curve. The pair selected by
        ``argmax(best_f, best_auc)`` is then used to predict on the test data,
        and the threshold that was optimal for it on validation is applied to
        the test precision-recall curve.

        :param h_validate: validation input forwarded to ``self.predict``
        :param bsk_label_valid: validation labels forwarded to
            ``self.predict``
        :param order_no_validate: forwarded to ``self.predict``
        :param r_validate: forwarded to ``self.predict``
        :param h_test: test input forwarded to ``self.predict``
        :param bsk_label_test: test labels forwarded to ``self.predict``
        :param order_no_test: forwarded to ``self.predict``
        :param r_test: forwarded to ``self.predict``
        :param bs: sized, indexable sequence of candidate ``b`` values
        :param phis: sized, indexable sequence of candidate ``phi`` values
        :return: tuple ``(prec, rec, f, auc, fpr, tpr, thr, b, phi)``: test
            precision, recall and ``f_point_5`` score at the selected
            operating point, test ROC AUC, the test ROC ``fpr``/``tpr``
            arrays, the test threshold actually used, and the selected ``b``
            and ``phi``
        """
        grid_shape = (len(bs), len(phis))
        best_f = np.zeros(grid_shape)
        best_auc = np.zeros(grid_shape)
        thr_opt = np.zeros(grid_shape)

        for i, b in enumerate(bs):
            for j, phi in enumerate(phis):
                pred_rst = self.predict(h_validate, bsk_label_valid,
                                        order_no_validate, r_validate, b, phi)
                obs = pred_rst['obs']
                prob = pred_rst['pred_prob']

                fpr, tpr, _ = metrics.roc_curve(obs, prob, pos_label=True)
                best_auc[i, j] = metrics.auc(fpr, tpr)

                prec, rec, thr = metrics.precision_recall_curve(
                    obs, prob, pos_label=True)
                f1 = f_point_5(prec, rec)
                # NaN scores must not win np.max / np.argmax.
                f1[np.isnan(f1)] = 0
                best_f[i, j] = np.max(f1)
                thr_opt[i, j] = thr[np.argmax(f1)]

        # NOTE(review): `argmax` is defined elsewhere; it is assumed to return
        # a 2-item (row, column) index into the grids - confirm.
        idx_k = argmax(best_f, best_auc)
        i_best, j_best = idx_k[0], idx_k[1]
        pred_rst = self.predict(h_test, bsk_label_test, order_no_test, r_test,
                                bs[i_best], phis[j_best])

        prec, rec, thr = metrics.precision_recall_curve(pred_rst['obs'],
                                                        pred_rst['pred_prob'],
                                                        pos_label=True)
        f = f_point_5(prec, rec)

        # Index the grid element-wise so the result is a scalar whether
        # `idx_k` is a tuple, a list or an array. Then take the last test
        # threshold that does not exceed it; if the selection is empty, fall
        # back to the first point of the test curve instead of raising
        # IndexError.
        not_above = np.where(thr <= thr_opt[i_best, j_best])[0]
        idx_f = not_above[-1] if not_above.size else 0

        fpr, tpr, _ = metrics.roc_curve(pred_rst['obs'],
                                        pred_rst['pred_prob'],
                                        pos_label=True)
        auc = metrics.auc(fpr, tpr)

        return prec[idx_f], rec[idx_f], f[idx_f], auc, fpr, tpr, thr[
            idx_f], bs[i_best], phis[j_best]
# Example no. 3
# 0
    def pred_test_based_on_valid(self, h_validate, bsk_label_valid, h_test,
                                 bsk_label_test, ks):
        """
        Choose ``k`` from ``ks`` on the validation set, then score the test set.

        For every candidate ``k`` the model is refitted on the stored training
        data (``self.h_train``, ``self.bsk_label_train``) and the output of
        ``self.predict(h_validate)`` is scored against ``bsk_label_valid``
        with a precision-recall curve and a ROC curve. The candidate selected
        by ``argmax(best_f, best_auc)`` is refitted and the threshold that was
        optimal for it on validation is applied to the test curve.

        Side effect: ``self.k`` is left at the selected value and the model
        remains fitted with it.

        :param h_validate: validation input forwarded to ``self.predict``
        :param bsk_label_valid: validation labels the predictions are scored
            against
        :param h_test: test input forwarded to ``self.predict``
        :param bsk_label_test: test labels the predictions are scored against
        :param ks: indexable sequence of candidate values for ``self.k``
        :return: tuple ``(prec, rec, f, auc, fpr, tpr, thr, k)``: test
            precision, recall and ``f_point_5`` score at the selected
            operating point, test ROC AUC, the test ROC ``fpr``/``tpr``
            arrays, the precision-recall threshold actually used, and the
            selected ``k``
        """
        best_f = []
        best_auc = []
        thr_opt = []
        for k in ks:
            self.k = k
            self.fit(self.h_train, self.bsk_label_train)
            pred_val = self.predict(h_validate)

            prec, rec, thr = metrics.precision_recall_curve(
                bsk_label_valid, pred_val)
            f = f_point_5(prec, rec)
            # A NaN score would be returned by np.max and picked by
            # np.argmax, poisoning the comparison between candidates;
            # count it as 0 instead.
            f[np.isnan(f)] = 0
            fpr, tpr, _ = metrics.roc_curve(bsk_label_valid, pred_val)

            thr_opt.append(thr[np.argmax(f)])
            best_f.append(np.max(f))
            best_auc.append(metrics.auc(fpr, tpr))

        # NOTE(review): `argmax` is defined elsewhere; presumably it ranks by
        # F-score with AUC as tie-breaker - confirm against its definition.
        idx_k = argmax(best_f, best_auc)[0]
        self.k = ks[idx_k]
        self.fit(self.h_train, self.bsk_label_train)
        pred_test = self.predict(h_test)

        prec, rec, thr = metrics.precision_recall_curve(
            bsk_label_test, pred_test)
        f = f_point_5(prec, rec)

        # Use the last test threshold that does not exceed the validation
        # optimum. When every test threshold is larger the selection is
        # empty, so fall back to the first point of the test curve rather
        # than failing with IndexError.
        not_above = np.where(thr <= thr_opt[idx_k])[0]
        idx_f = not_above[-1] if not_above.size else 0

        # Discard the ROC thresholds: binding them to `thr` would make the
        # returned threshold come from the ROC curve while `idx_f` indexes
        # the precision-recall curve.
        fpr, tpr, _ = metrics.roc_curve(bsk_label_test, pred_test)
        auc = metrics.auc(fpr, tpr)

        return prec[idx_f], rec[idx_f], f[idx_f], auc, fpr, tpr, thr[
            idx_f], ks[idx_k]