Code Example #1
File: test_metrics.py Project: xhlulu/scikit-plot
def test_cmap(self):
    np.random.seed(0)
    clf = LogisticRegression()
    clf.fit(self.X, self.y)
    probas = clf.predict_proba(self.X)
    # Both a colormap name and a matplotlib colormap object are accepted.
    plot_precision_recall_curve(self.y, probas, cmap='nipy_spectral')
    plot_precision_recall_curve(self.y, probas, cmap=plt.cm.nipy_spectral)
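For reference, a minimal self-contained sketch of the same call, assuming the deprecated scikitplot.metrics.plot_precision_recall_curve entry point that these tests exercise (newer scikit-plot releases expose plot_precision_recall instead); the iris data and classifier settings are illustrative only:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from scikitplot.metrics import plot_precision_recall_curve

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000).fit(X, y)
probas = clf.predict_proba(X)

# cmap accepts either a colormap name or a matplotlib colormap object.
plot_precision_recall_curve(y, probas, cmap='nipy_spectral')
plot_precision_recall_curve(y, probas, cmap=plt.cm.nipy_spectral)
plt.show()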
Code Example #2
def draw_pr_graph(state):
    classified_data = [
        state.classified_data_normal_case, state.classified_data_resampled_case
    ]
    f, (ax1, ax2) = plt.subplots(1, 2)
    for x in range(2):
        probas = np.concatenate(classified_data[x]['preds_list'], axis=0)
        y_true = np.concatenate(classified_data[x]['trues_list'])
        ax = ax1 if x == 0 else ax2
        if x == 0:
            title = "Normal dataset with average precision = {0:0.2f}".format(
                classified_data[x]['average_precision'])
        else:
            title = "Resampled dataset with {0} and\n average precision ={1:0.2f}".format(
                state.sampling_algorithm.value[0],
                classified_data[x]['average_precision'])
        plot_precision_recall_curve(
            y_true,
            probas,
            title=title,
            curves=('micro', 'each_class'),
            ax=ax,
            figsize=None,
            cmap='nipy_spectral',
            title_fontsize="large",
            text_fontsize="medium")
        ax.get_figure().set_size_inches(12, 9)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_ylim([0.0, 1.05])
        ax.set_xlim([0.0, 1.0])

    plt.show()
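A simpler, self-contained sketch of the side-by-side layout above, with synthetic data standing in for the state object; the two models, the make_classification settings, and the titles are illustrative, while the curves/ax/cmap arguments follow the older scikit-plot signature used in this example:

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from scikitplot.metrics import plot_precision_recall_curve

# Two models stand in for the "normal" and "resampled" cases above.
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)
models = {
    'Normal dataset': LogisticRegression(max_iter=1000),
    'Class-weighted dataset': LogisticRegression(max_iter=1000,
                                                 class_weight='balanced'),
}

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, (title, model) in zip(axes, models.items()):
    probas = model.fit(X, y).predict_proba(X)
    # Passing ax= draws onto an existing axes instead of opening a new figure.
    plot_precision_recall_curve(y, probas, title=title,
                                curves=('micro', 'each_class'),
                                ax=ax, cmap='nipy_spectral')
plt.show()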
Code Example #3
File: test_metrics.py Project: xhlulu/scikit-plot
def test_curve_diffs(self):
    np.random.seed(0)
    clf = LogisticRegression()
    clf.fit(self.X, self.y)
    probas = clf.predict_proba(self.X)
    # Different `curves` settings should draw onto distinct axes.
    ax_micro = plot_precision_recall_curve(self.y, probas, curves='micro')
    ax_class = plot_precision_recall_curve(self.y, probas,
                                           curves='each_class')
    self.assertNotEqual(ax_micro, ax_class)
Code Example #4
File: test_metrics.py Project: xhlulu/scikit-plot
def test_ax(self):
    np.random.seed(0)
    clf = LogisticRegression()
    clf.fit(self.X, self.y)
    probas = clf.predict_proba(self.X)
    fig, ax = plt.subplots(1, 1)
    # Without ax= the function creates its own axes; with ax= it reuses ours.
    out_ax = plot_precision_recall_curve(self.y, probas)
    assert ax is not out_ax
    out_ax = plot_precision_recall_curve(self.y, probas, ax=ax)
    assert ax is out_ax
Code Example #5
def plot_analysis(combine,
                  test_name,
                  y_true,
                  y_pred,
                  y_proba,
                  labels,
                  verbose,
                  library,
                  save=True,
                  show=True,
                  sessionid="testing",
                  prefix=""):

    # met_index counts the charts drawn so far; when combine is True each
    # chart is placed into a slot of a 2x4 subplot grid.
    met_index = 0
    plt.rcParams.update({'font.size': 14})
    # TODO: Find a way to do this better
    pltmetrics.plot_confusion_matrix(y_true, y_pred)
    if not combine:
        #plt.gcf().set_size_inches(3.65,3.65)
        save_show(plt, library + "/" + prefix, sessionid, "confusion_matrix",
                  show, save, False, True, True, False)
    else:
        plt.subplot(2, 4, met_index + 1)
    met_index += 1

    plt.rcParams.update({'font.size': 12})
    pltmetrics.plot_roc_curve(y_true, y_proba)
    # Shorten scikit-plot's default legend labels so they fit in the panel.
    for text in plt.gca().legend_.get_texts():
        text.set_text(text.get_text().replace("ROC curve of class", "class"))
        text.set_text(text.get_text().replace("area =", "AUC: "))
        text.set_text(text.get_text().replace("micro-average ROC curve",
                                              "micro-avg"))
        text.set_text(text.get_text().replace("macro-average ROC curve",
                                              "macro-avg"))
    if not combine:
        #plt.gcf().set_size_inches(3.65,3.65)
        save_show(plt, library + "/" + prefix, sessionid, "roc_curves", show,
                  save, False, True, True, False)
    else:
        plt.subplot(2, 4, met_index + 1)
    met_index += 1

    # The KS statistic, cumulative gain, and lift plots only apply to binary
    # problems, so they are skipped when there are three or more classes.
    if len(labels) < 3:
        pltmetrics.plot_ks_statistic(y_true, y_proba)
        if not combine:
            #plt.gcf().set_size_inches(3.65,3.65)
            save_show(plt, library + "/" + prefix, sessionid, "ks_statistics",
                      show, save, False, True, True, False)
        else:
            plt.subplot(2, 4, met_index + 1)
        met_index += 1

    pltmetrics.plot_precision_recall_curve(y_true, y_proba)
    for text in plt.gca().legend_.get_texts():
        text.set_text(text.get_text().replace(
            "Precision-recall curve of class", "class"))
        text.set_text(text.get_text().replace("area =", "AUC: "))
        text.set_text(text.get_text().replace(
            "micro-average Precision-recall curve", "micro-avg"))
        text.set_text(text.get_text().replace("macro-average Precision-recall",
                                              "macro-avg"))
    if not combine:
        #plt.gcf().set_size_inches(3.65,3.65)
        save_show(plt, library + "/" + prefix, sessionid,
                  "precision_recall_curve", show, save, False, True, True,
                  False)
    else:
        plt.subplot(2, 4, met_index + 1)
    met_index += 1

    if len(labels) < 3:
        pltmetrics.plot_cumulative_gain(y_true, y_proba)
        if not combine:
            #plt.gcf().set_size_inches(3.65,3.65)
            save_show(plt, library + "/" + prefix, sessionid,
                      "cumulative_gain", show, save, False, True, True, False)
        else:
            plt.subplot(2, 4, met_index + 1)
        met_index += 1

    if len(labels) < 3:
        pltmetrics.plot_lift_curve(y_true, y_proba)
        if not combine:
            #plt.gcf().set_size_inches(3.65,3.65)
            save_show(plt, library + "/" + prefix, sessionid, "lift_curve",
                      show, save, False, True, True, False)
        else:
            plt.subplot(2, 4, met_index + 1)
        met_index += 1

    if combine:
        plt.suptitle(test_name)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        # NOTE: `figname` is not defined in this excerpt; it is assumed to be
        # provided elsewhere in the original project.
        save_show(plt,
                  library,
                  sessionid,
                  figname,
                  show,
                  save,
                  True,
                  analysis=True)
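A stripped-down, self-contained sketch of the combined layout above, assuming pltmetrics aliases scikitplot.metrics (as the function names suggest) and replacing the project-specific save_show helper with a plain plt.show(); the dataset and the 1x3 grid are illustrative only:

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import scikitplot.metrics as pltmetrics

X, y = make_classification(n_samples=500, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)
y_pred = clf.predict(X)
y_proba = clf.predict_proba(X)

# Draw each chart onto its own slot of a shared figure via the ax argument.
fig = plt.figure(figsize=(15, 5))
pltmetrics.plot_confusion_matrix(y, y_pred, ax=fig.add_subplot(1, 3, 1))
pltmetrics.plot_roc_curve(y, y_proba, ax=fig.add_subplot(1, 3, 2))
pltmetrics.plot_precision_recall_curve(y, y_proba, ax=fig.add_subplot(1, 3, 3))
plt.suptitle("Combined analysis")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()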
Code Example #6
File: test_metrics.py Project: xhlulu/scikit-plot
def test_array_like(self):
    # Plain Python lists and string class labels are accepted as inputs.
    plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])
    plot_precision_recall_curve([0, 'a'], [[0.8, 0.2], [0.2, 0.8]])
    plot_precision_recall_curve(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]])
Code Example #7
File: test_metrics.py Project: xhlulu/scikit-plot
def test_string_classes(self):
    np.random.seed(0)
    clf = LogisticRegression()
    # String class labels should be handled the same way as numeric ones.
    clf.fit(self.X, convert_labels_into_string(self.y))
    probas = clf.predict_proba(self.X)
    plot_precision_recall_curve(convert_labels_into_string(self.y), probas)