Example #1
    def validate_epoch(self, val_model, epoch_cm):
        """
        Computes the batch validation confusion matrix
        and then updates the epoch confusion matrix.
        """
        # Loop through validation set
        for n in range(self.validation_steps):

            # Grab next batch
            X, y_true, _ = next(self.validation_data)

            # Make prediction with model
            y_pred = val_model([X])[0]

            # Convert one-hot / probability vectors to class indices
            y_true = np.argmax(y_true, axis=-1)
            y_pred = np.argmax(y_pred, axis=-1)

            # Flatten batch into single array
            y_true = y_true.flatten()
            y_pred = y_pred.flatten()

            # Create batch CM
            batch_cm = ConfusionMatrix(y_true, y_pred)

            # Get all classes in batch
            all_classes = list(batch_cm.classes)

            batch_cm = batch_cm.to_array()

            # Accumulate batch counts into the epoch confusion matrix
            # (epoch_cm is indexed by class label; batch_cm only covers the
            # classes present in this batch)
            for i in all_classes:
                for j in all_classes:
                    epoch_cm[i, j] += batch_cm[all_classes.index(i), all_classes.index(j)]
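
The accumulation step above can be exercised on its own. The sketch below is a hypothetical, self-contained illustration; the batches, the class count, and the pandas_ml import are assumptions (the listing does not show the surrounding module), but the scatter-by-class-label pattern matches the loop in validate_epoch.

import numpy as np
from pandas_ml import ConfusionMatrix  # assumption: the ConfusionMatrix used above is pandas_ml's

NUM_CLASSES = 3
epoch_cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)

# Two hypothetical "batches" of already-argmaxed, flattened labels
batches = [
    (np.array([0, 1, 2, 1]), np.array([0, 1, 1, 1])),
    (np.array([2, 2, 0, 1]), np.array([2, 0, 0, 1])),
]

for y_true, y_pred in batches:
    batch_cm = ConfusionMatrix(y_true, y_pred)
    all_classes = list(batch_cm.classes)  # classes present in this batch only
    batch_arr = batch_cm.to_array()
    # Scatter the batch counts into the epoch matrix by class label
    for i in all_classes:
        for j in all_classes:
            epoch_cm[i, j] += batch_arr[all_classes.index(i), all_classes.index(j)]

print(epoch_cm)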
Example #2
# Imports assumed by this snippet; ConfusionMatrix appears to be pandas_ml's
import matplotlib.pyplot as plt
from pandas_ml import ConfusionMatrix


def plot_confusion_matrix_metrics(true_labels=None,
                                  predicted_labels=None,
                                  normalized=False,
                                  verbose=True):
    """
    Plot a confusion matrix given the known labels of the data (true_labels) and their corresponding predictions (predicted_labels).
    If normalized=True, the confusion matrix will bound its values in an interval between 0 and 1.
    
    Doesn't require plt.show(), just call this function at the end of a cell.
    
    :param true_labels: true values for labels
    :type true_labels: np.array
    :param predicted_labels: predicted label values
    :type predicted_labels: np.array
    :param normalized: bound the analysis in the interval [0, 1]
    :type normalized: boolean (default=False)
    """
    cm = ConfusionMatrix(true_labels, predicted_labels)
    cm.plot(cmap='GnBu', normalized=normalized)
    ax = plt.gca()
    label_dict = {"True": 1, "False": 0}
    str_labels = [
        'Digit {}'.format(label_dict.get(i.get_text(), i.get_text()))
        for i in ax.get_xticklabels()
    ]
    ax.set_xticklabels(str_labels, rotation=0, horizontalalignment='center')
    ax.set_yticklabels(str_labels)
    cm_array = cm.to_array()
    width, height = cm_array.shape
    for x in range(width):
        for y in range(height):
            plt.annotate(str(cm_array[x][y]),
                         xy=(y, x),
                         horizontalalignment='center',
                         verticalalignment='center')
    plt.show()
    print(cm)
    if verbose:
        print("===================================================")
        print("Evaluation metrics:")
        cm.print_stats()
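
A hypothetical usage, assuming the binary digit / non-digit setting implied by the tick relabelling above (label_dict maps the "True"/"False" tick texts to 1/0); the label arrays are made up for illustration:

import numpy as np

# Hypothetical boolean labels for a binary "is this a digit?" task
true_labels = np.array([True, False, True, True, False, True, False, False])
predicted_labels = np.array([True, False, False, True, False, True, True, False])

plot_confusion_matrix_metrics(true_labels=true_labels,
                              predicted_labels=predicted_labels,
                              normalized=False,
                              verbose=True)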
Example #3
# Imports assumed by this snippet; ConfusionMatrix appears to be pandas_ml's,
# and plot_confusion_matrix is a project helper defined elsewhere.
import numpy as np
from pandas_ml import ConfusionMatrix


def plotconfusion(truth, predictions, path, label_dict, classes):
    """
    This function plots the confusion matrix and
    also prints useful statistics.

    :param truth: true labels
    :type truth: np.array
    :param predictions: model predictions
    :type predictions: np.array
    :param path: path to save image
    :type path: str
    :param label_dict: dict to transform int to str
    :type label_dict: dict
    :param classes: number of classes
    :type classes: int
    """
    # Accuracy: fraction of predictions that match the ground truth
    acc = np.array(truth) == np.array(predictions)
    size = float(acc.shape[0])
    acc = np.sum(acc.astype("int32")) / size
    truth = [label_dict[i] for i in truth]
    predictions = [label_dict[i] for i in predictions]
    cm = ConfusionMatrix(truth, predictions)
    cm_array = cm.to_array()
    cm_diag = np.diag(cm_array)
    # Per-class recall: diagonal count divided by the total number of
    # examples whose true label is that class (row sum)
    sizes_per_cat = []
    for n in range(cm_array.shape[0]):
        sizes_per_cat.append(np.sum(cm_array[n]))
    sizes_per_cat = np.array(sizes_per_cat)
    sizes_per_cat = sizes_per_cat.astype(np.float32)**-1
    recall = np.multiply(cm_diag, sizes_per_cat)
    print("\nRecall:{}".format(recall))
    print("\nRecall stats: mean = {0:.6f}, std = {1:.6f}\n".format(
        np.mean(recall),  # noqa
        np.std(recall)))  # noqa
    title = "Confusion matrix of {0} examples\n accuracy = {1:.6f}".format(
        int(size),  # noqa
        acc)
    plot_confusion_matrix(cm_array, classes, title=title, path=path)
    cm.print_stats()
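
For reference, the loop-based recall computation above (diagonal counts times inverse row sums) is equivalent to the vectorized form below; the 3x3 count matrix is hypothetical:

import numpy as np

# Hypothetical confusion-matrix counts (rows = true class, columns = predicted class)
cm_array = np.array([[50,  2,  3],
                     [ 4, 40,  6],
                     [ 1,  5, 44]])

# recall_k = cm[k, k] / sum(cm[k, :]), matching the loop in plotconfusion
recall = np.diag(cm_array) / cm_array.sum(axis=1).astype(np.float32)
print("Recall: {}".format(recall))
print("Recall stats: mean = {0:.6f}, std = {1:.6f}".format(np.mean(recall), np.std(recall)))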