Exemplo n.º 1
0
def heatmap(x_labels, y_labels, matrix_values, show_text=False):
    """
    Generates a heatmap.

    Arguments:
     matrix_values (arr): 2D dataset indexed as matrix_values[row][column],
                          with one row per entry of y_labels and one column
                          per entry of x_labels. Coerced into an ndarray.
     x_labels  (list): Named labels for the x axis (one per matrix column).
     y_labels  (list): Named labels for the y axis (one per matrix row).
     show_text (bool): Show text values in heatmap cells.

    Returns:
     The object returned by ``wandb.visualize`` for the heatmap preset, or
     None when the inputs fail validation. Log it with ``wandb.log``; to see
     the plot, go to your W&B run page and expand the 'media' tab under
     'auto visualizations'.

    Example:
     wandb.log({'heatmap': wandb.plots.HeatMap(x_labels, y_labels,
                matrix_values)})
    """
    np = util.get_module(
        "numpy",
        required=
        "heatmap requires the numpy library, install with `pip install numpy`")

    if not (test_missing(x_labels=x_labels, y_labels=y_labels,
                         matrix_values=matrix_values)
            and test_types(x_labels=x_labels, y_labels=y_labels,
                           matrix_values=matrix_values)):
        return None

    matrix_values = np.array(matrix_values)
    wandb.termlog('Visualizing heatmap.')

    # One cell per (x, y) label pair. x_labels select matrix columns and
    # y_labels select matrix rows, hence matrix_values[j][i].
    cells = ((x, y, matrix_values[j][i])
             for i, x in enumerate(x_labels)
             for j, y in enumerate(y_labels))

    data = []
    for x, y, value in cells:
        # Stop the whole traversal (not just the current column) once the
        # limit is reached, and warn only when a point is actually dropped.
        if len(data) >= chart_limit:
            wandb.termwarn(
                "wandb uses only the first %d datapoints to create the plots."
                % wandb.Table.MAX_ROWS)
            break
        data.append([x, y, round(value, 2)])

    heatmap_key = ('wandb/heatmap/v1'
                   if show_text else 'wandb/heatmap_no_text/v1')
    return wandb.visualize(
        heatmap_key,
        wandb.Table(columns=['x_axis', 'y_axis', 'values'], data=data))
def precision_recall(y_true=None,
                     y_probas=None,
                     labels=None,
                     plot_micro=True,
                     classes_to_plot=None):
    """
    Computes the tradeoff between precision and recall for different thresholds.
        A high area under the curve represents both high recall and high precision,
        where high precision relates to a low false positive rate, and high recall
        relates to a low false negative rate. High scores for both show that the
        classifier is returning accurate results (high precision), as well as
        returning a majority of all positive results (high recall).
        PR curve is useful when the classes are very imbalanced.

    Arguments:
    y_true (arr): Test set labels.
    y_probas (arr): Test set predicted probabilities, 2D, with column i holding
      the scores for the i-th class in sorted order of ``np.unique(y_true)``.
    labels (list): Named labels for target variable (y). Makes plots easier to
      read by replacing target values with corresponding index.
      For example labels= ['dog', 'cat', 'owl'] all 0s are
      replaced by 'dog', 1s by 'cat'. Only applied to integer classes.
    plot_micro (bool): Currently unused; kept for backward compatibility.
    classes_to_plot (list): Subset of classes to draw. Defaults to every
      class present in y_true.

    Returns:
    The object returned by ``wandb.visualize`` for the 'wandb/pr_curve/v1'
    preset, or None when the inputs fail validation. Log it with
    ``wandb.log``; to see the plot, go to your W&B run page and expand the
    'media' tab under 'auto visualizations'.

    Example:
    wandb.log({'pr': wandb.plots.precision_recall(y_true, y_probas, labels)})
    """
    np = util.get_module(
        "numpy",
        required=
        "precision_recall requires the numpy library, install with `pip install numpy`")
    util.get_module(
        "sklearn",
        required=
        "precision_recall requires the scikit library, install with `pip install scikit-learn`"
    )
    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.metrics`` is loaded.
    from sklearn.metrics import precision_recall_curve

    # Validate the raw arguments *before* coercing them: np.array(None) is a
    # 0-d object array rather than None, so converting first would hide a
    # missing argument from the check.
    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    y_true = np.array(y_true)
    probas = np.array(y_probas)
    classes = np.unique(y_true)

    if classes_to_plot is None:
        classes_to_plot = classes

    # Each curve is down-sampled to this many evenly spaced points.
    samples = 20
    pr_curves = {}
    for i, to_plot in enumerate(np.isin(classes, classes_to_plot)):
        if not to_plot:
            continue
        precision, recall, _ = precision_recall_curve(
            y_true, probas[:, i], pos_label=classes[i])
        sample_precision = [
            precision[int(len(precision) * k / samples)]
            for k in range(samples)
        ]
        sample_recall = [
            recall[int(len(recall) * k / samples)] for k in range(samples)
        ]
        pr_curves[classes[i]] = (sample_precision, sample_recall)

    def display_name(class_name):
        # Integer classes are mapped through `labels` when it is given; any
        # other class value (string, float, date, ...) is shown as-is.
        if labels is not None and isinstance(class_name, (int, np.integer)):
            return labels[class_name]
        return class_name

    # Resolve each class's display name once, from the original class value.
    names = {class_name: display_name(class_name) for class_name in pr_curves}
    rows = ([names[class_name], round(p, 3), round(r, 3)]
            for class_name, (precision, recall) in pr_curves.items()
            for p, r in zip(precision, recall))

    data = []
    for row in rows:
        # Stop the whole traversal once the limit is reached, warning once.
        if len(data) >= chart_limit:
            wandb.termwarn(
                "wandb uses only the first %d datapoints to create the plots."
                % wandb.Table.MAX_ROWS)
            break
        data.append(row)

    return wandb.visualize(
        'wandb/pr_curve/v1',
        wandb.Table(columns=['class', 'precision', 'recall'], data=data))
Exemplo n.º 3
0
def roc_curve(y_true=None, y_probas=None, labels=None, classes_to_plot=None):
    """
    Calculates receiver operating characteristic scores and visualizes them as the
        ROC curve.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities, 2D, with column i
                        holding the scores for the i-th class in sorted order
                        of ``np.unique(y_true)``.
        labels (list): Named labels for target variable (y). Makes plots easier to
                        read by replacing target values with corresponding index.
                        For example labels= ['dog', 'cat', 'owl'] all 0s are
                        replaced by 'dog', 1s by 'cat'. Only applied to
                        integer classes.
        classes_to_plot (list): Subset of classes to draw. Defaults to every
                        class present in y_true.

    Returns:
        The object returned by ``wandb.plot_table`` for the
            'wandb/area-under-curve/v0' preset, or None when the inputs fail
            validation. Log it with ``wandb.log``; to see the plot, go to your
            W&B run page and expand the 'media' tab under 'auto visualizations'.

    Example:
        wandb.log({'roc-curve': wandb.plot.roc_curve(y_true, y_probas, labels)})
    """
    np = util.get_module(
        "numpy",
        required=
        "roc requires the numpy library, install with `pip install numpy`")
    util.get_module(
        "sklearn",
        required=
        "roc requires the scikit library, install with `pip install scikit-learn`"
    )
    # Aliased so the local import does not shadow this function's own name.
    from sklearn.metrics import roc_curve as sk_roc_curve

    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    y_true = np.array(y_true)
    probas = np.array(y_probas)
    classes = np.unique(y_true)

    if classes_to_plot is None:
        classes_to_plot = classes

    def iter_points():
        # Yields [class, fpr, tpr] rows lazily, one ROC curve per plotted
        # class, so curves past the row limit are never computed.
        for i, to_plot in enumerate(np.isin(classes, classes_to_plot)):
            if not to_plot:
                continue
            fpr, tpr, _ = sk_roc_curve(y_true,
                                       probas[:, i],
                                       pos_label=classes[i])
            class_value = classes[i]
            if labels is not None and isinstance(class_value,
                                                 (int, np.integer)):
                class_name = labels[class_value]
            else:
                class_name = class_value
            for f, t in zip(fpr, tpr):
                yield [class_name, round(f, 3), round(t, 3)]

    data = []
    for row in iter_points():
        # Stop the whole traversal once the limit is reached, warning once.
        if len(data) >= chart_limit:
            wandb.termwarn(
                "wandb uses only the first %d datapoints to create the plots."
                % wandb.Table.MAX_ROWS)
            break
        data.append(row)

    table = wandb.Table(columns=['class', 'fpr', 'tpr'], data=data)
    return wandb.plot_table('wandb/area-under-curve/v0', table, {
        'x': 'fpr',
        'y': 'tpr',
        'class': 'class'
    }, {
        'title': 'ROC',
        'x-axis-title': 'False positive rate',
        'y-axis-title': 'True positive rate'
    })