Пример #1
0
def calc_ROC(targets, soft_predictions, plot_ROC=False, outfile='./roc.png'):
    """
    determine the ROC curve of a binary classifier (e.g. the SVM model)

    :param targets: numpy vector of m true (binary) targets
    :param soft_predictions: per-sample scores for the positive class, e.g.
                             class probabilities
    :param plot_ROC: if true, generate and save a plot of the ROC curve
    :param outfile: file to output the ROC plot to if plot_ROC is true
    :return: list of [false positive rates, true positive rates], i.e. the
             first two arrays returned by sklearn.metrics.roc_curve
    """
    from sklearn.metrics import roc_curve

    # roc_curve also returns the decision thresholds; they are not used here
    fpr, tpr, _ = roc_curve(targets, soft_predictions)

    if plot_ROC:
        msg = '[calc_ROC] Saving ROC curve to: %s' % outfile
        logging.info(msg)
        print(msg)

        from matplotlib import pyplot as plt
        from accessory import create_dir

        plt.plot(fpr, tpr)
        # dashed diagonal: ROC curve of a chance-level classifier
        plt.plot([0, 1], [0, 1], "r--", alpha=.5)
        # pad the axes slightly so the curve is not clipped at 0 and 1
        plt.axis((-0.01, 1.01, -0.01, 1.01))
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.title('ROC for SVM Model on Test Set')

        # make sure the parent directory of outfile exists before saving
        create_dir(outfile)
        plt.savefig(outfile)
        plt.clf()

    return [fpr, tpr]
Пример #2
0
def test_create_dir():
    """create_dir(filepath) must create the parent directory of filepath."""
    import os
    import tempfile

    from accessory import create_dir

    # Work inside a fresh temporary directory. This way the test neither
    # pollutes the current working directory nor passes vacuously because a
    # folder of the same name already existed. TemporaryDirectory also
    # removes everything afterwards, even when the assertion fails.
    with tempfile.TemporaryDirectory() as tmp_root:
        filedir = os.path.join(tmp_root, 'test_folder')
        filepath = os.path.join(filedir, 'file.png')
        assert not os.path.exists(filedir)
        create_dir(filepath)
        assert os.path.isdir(filedir)
def calc_pct_yellow(rgb, verb=False, outfile='./yellow.png'):
    """
    calculate percentage of yellow pixels (excluding NaN and black pixels)

    :param rgb: N x M x 3 RGB pixel array with values between 0 and 255
                (NaN pixels are allowed)
    :param verb: verbose mode to save the image with yellow pixels highlighted
    :param outfile: file to output highlighted image if verbose
    :return: percent of color in image
    """
    import numpy as np

    rgb = np.asarray(rgb)

    if rgb.ndim != 3 or rgb.shape[-1] != 3:
        msg = 'ERROR [calc_pct_yellow] Input array dimensions ' + \
              str(rgb.shape) + ' incompatible with expected ' \
                               'N x M x 3 RGB input.'
        print(msg)
        logging.error(msg)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    # NaN-aware extrema: NaN pixels are permitted, but with np.max/np.min a
    # single NaN turns both comparisons False and hides out-of-range values
    pix_min, pix_max = np.nanmin(rgb), np.nanmax(rgb)
    if pix_max > 255 or pix_min < 0:
        msg = 'ERROR [calc_pct_yellow] Input RGB array must contain element ' \
              'values between 0 and 255. Actual range: [%.1f, %.1f]' % \
              (pix_min, pix_max)
        print(msg)
        logging.error(msg)
        sys.exit(1)

    # project helpers are only needed once the input is known to be valid
    from accessory import color_nans, percent_color
    from preprocess import nan_yellow_pixels

    # yellow pixels are repainted in this colour (pure green) before counting
    y_label = [0, 255, 0]
    # work on a copy so the caller's array is left untouched
    recolored_rgb = np.array(rgb)

    # assign all image NaNs to black pixels
    recolored_rgb = color_nans(recolored_rgb, [0, 0, 0])

    # assign NaN to all yellow pixels and recolor based on desired label
    recolored_rgb = nan_yellow_pixels(recolored_rgb)
    recolored_rgb = color_nans(recolored_rgb, color=y_label)

    if verb:
        from accessory import save_rgb
        from accessory import create_dir
        msg = '[calc_pct_yellow] Saving yellow labeled image: %s' % outfile
        logging.info(msg)
        print(msg)

        create_dir(outfile)
        save_rgb(recolored_rgb, outfile)

    # calculate the percentage of labeled yellow pixels
    return percent_color(recolored_rgb, y_label)
Пример #4
0
def plot_confusion_matrix(cm,
                          classes,
                          normalize=False,
                          title='Confusion Matrix',
                          outfile='./cm.png'):
    """
    plots the confusion matrix (adapted from the sklearn example)

    :param cm: square confusion matrix; rows are true classes, columns are
               predicted classes
    :param classes: classification target names, one per row of cm
    :param normalize: enable to normalize each row of the confusion matrix
                      to sum to 1 (rows without samples are shown as 0)
    :param title: plot title, default='Confusion Matrix'
    :param outfile: file to output the confusion matrix figure to
    """
    import numpy as np

    cm = np.asarray(cm)

    if len(classes) != cm.shape[0]:
        msg = 'ERROR [plot_confusion_matrix] Mismatch between number of ' \
              'specified classes and number of total targets. ' \
              '(%d classes and %d targets)' % (len(classes), cm.shape[0])
        print(msg)
        logging.error(msg)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    import itertools
    import matplotlib.pyplot as plt
    from accessory import create_dir

    if normalize:
        # Normalize BEFORE drawing, so the colour map, the printed values and
        # the text-colour threshold all describe the same matrix. Rows with
        # no samples become 0 instead of NaN.
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # annotate every cell; white text on dark (high-value) cells for contrast
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell = cm[i, j]
        plt.text(j,
                 i,
                 '%.2f' % cell if normalize else cell,
                 horizontalalignment="center",
                 color="white" if cell > thresh else "black")

    plt.axis('tight')
    plt.ylabel('True')
    plt.xlabel('Predicted')

    # make sure the parent directory of outfile exists before saving
    create_dir(outfile)
    plt.savefig(outfile)
    plt.clf()
def plot_features(features, targets, labels, outfile='features.png'):
    """
    visualize 2D feature space for data set with known binary targets

    :param features: N x 2 array of features from N data sets
    :param targets: target labels corresponding to each set of features;
                    exactly 2 distinct values are required
    :param labels: axis labels for the two features, [x label, y label]
    :param outfile: save location of output feature plot
    """
    import numpy as np

    features = np.asarray(features)
    # flatten so an N x 1 column of labels (or a plain list) works as a mask
    targets = np.ravel(targets)
    target_types = np.unique(targets)

    if len(target_types) != 2:
        msg = 'ERROR [plot_features] Function only compatible with 2 targets.'
        logging.error(msg)
        print(msg)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    if features.ndim != 2 or features.shape[1] != 2:
        msg = 'ERROR [plot_features] Function only compatible with 2 features.'
        logging.error(msg)
        print(msg)
        sys.exit(1)

    if features.shape[0] != len(targets):
        msg = 'ERROR [plot_features] Mismatch between number of target ' \
              'labels and feature sets.'
        logging.error(msg)
        print(msg)
        sys.exit(1)

    import matplotlib.pyplot as plt
    from accessory import create_dir

    # one scatter series per target class: first class red, second blue
    handles = []
    for target, color in zip(target_types, ('red', 'blue')):
        class_features = features[targets == target, :]
        handles.append(plt.scatter(class_features[:, 0],
                                   class_features[:, 1],
                                   marker='o',
                                   c=color,
                                   label=target))
    plt.xlabel(labels[0])
    plt.ylabel(labels[1])
    plt.legend(handles=handles, loc=4)  # loc=4 == 'lower right'
    plt.grid(True)
    plt.axis('tight')

    # make sure the parent directory of outfile exists before saving
    create_dir(outfile)
    plt.savefig(outfile)
    # clear the figure so these points do not leak into the next plot
    plt.clf()

    msg = '[plot_features] Feature space plot saved: %s' % outfile
    logging.info(msg)
    print(msg)
def otsu_threshold(img, omit=None, verb=False, outfile='./otsu_img.png'):
    """
    calculate the global otsu's threshold for image

    :param img: 2D array of pixel values (integer or float)
    :param omit: pixel values to omit from calculation; omitted pixels are set
                 to 0 first. When at least one value is omitted, NaN pixels of
                 a float image are set to 0 as well. default: omit nothing
    :param verb: verbose mode to save the thresholded image
    :param outfile: file to output otsu image if verbose
    :return: threshold (float)
    """
    import numpy as np

    if np.ndim(img) > 2:
        msg = 'ERROR [otsu_threshold] Input image must be 2D grayscale (not ' \
              'RGB).'
        logging.error(msg)
        print(msg)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    from skimage.filters import threshold_otsu
    from accessory import get_iterable

    # None instead of a shared mutable [] default
    if omit is None:
        omit = []

    src = np.asarray(img)
    otsu_img = np.array(img)  # copy: the caller's image is never modified
    # Integer images cannot store NaN. Writing NaN into them raises, so the
    # omitted pixels are zeroed through a boolean mask instead.
    nan_capable = np.issubdtype(otsu_img.dtype, np.floating)

    # set omitted pixel values (and, for float images, NaN pixels) as 0
    for omit_val in get_iterable(omit):
        mask = src == omit_val
        if nan_capable:
            # Zeroing NaNs as well lets omit=[np.nan] drop NaN pixels.
            mask = np.logical_or(mask, np.isnan(otsu_img))
        otsu_img[mask] = 0

    threshold_global_otsu = threshold_otsu(otsu_img)

    if verb:
        import matplotlib.pyplot as plt
        from accessory import create_dir
        global_otsu = otsu_img >= threshold_global_otsu
        plt.imshow(global_otsu, cmap=plt.cm.gray)

        msg = '[otsu_threshold] Saving Otsu threshold image: %s' % outfile
        logging.info(msg)
        print(msg)

        # make sure the parent directory of outfile exists before saving
        create_dir(outfile)
        plt.savefig(outfile)
        plt.clf()

    return threshold_global_otsu
def rgb_histogram(rgb, verb=False, process=True, omit=None,
                  outfile='./hist.png'):
    """
    generate histograms for each color channel of input RGB pixel array

    :param rgb: N x M x 3 RGB pixel array with values between 0 and 255
    :param verb: verbose mode to save a figure of the histograms, default False
    :param process: normalize histograms and omit pixel bins, default True
    :param omit: pixel value bins to omit, default None (omit nothing)
    :param outfile: file to output histogram figure if verbose
    :return: histogram values for 0-255 for each color channel (np.arrays)
    """
    import numpy as np

    # None instead of a shared mutable [] default
    if omit is None:
        omit = []

    rgb = np.asarray(rgb)

    if rgb.ndim != 3 or rgb.shape[-1] != 3:
        msg = 'ERROR [rgb_histogram] Input array dimensions ' + \
              str(rgb.shape) + \
              ' incompatible with expected N x M x 3 RGB input.'
        print(msg)
        logging.error(msg)
        # non-zero status: a bare sys.exit() would report success
        sys.exit(1)

    # NaN-aware extrema: with np.max/np.min a single NaN makes both
    # comparisons False and hides out-of-range values
    pix_min, pix_max = np.nanmin(rgb), np.nanmax(rgb)
    if pix_max > 255 or pix_min < 0:
        msg = 'ERROR [rgb_histogram] Input RGB array must contain ' \
              'element values between 0 and 255. Actual range: ' \
              '[%.1f, %.1f]' % (pix_min, pix_max)
        print(msg)
        logging.error(msg)
        sys.exit(1)

    # compute histograms for each color channel (R, G, B)
    hists = [extract_hist(rgb[:, :, channel]) for channel in range(3)]

    # omit bins and normalize histograms
    if process:
        hists = [process_rgb_histogram(hist, omit) for hist in hists]

    rh, gh, bh = hists

    msg = '[rgb_histogram] Extracting RGB histograms from pixel array.'
    logging.debug(msg)
    if verb:
        print(msg)
        import matplotlib.pyplot as plt
        from accessory import create_dir
        bins = np.arange(256)  # x axis: pixel values 0-255

        # plot RGB histograms in subplots with shared x axis
        fig, axarr = plt.subplots(3, sharex=True)
        for ax, hist, channel in zip(axarr, (rh, gh, bh), 'RGB'):
            ax.plot(bins, hist)
            ax.axis('tight')
            ax.set_xlabel('%s Pixel Value' % channel)
            ax.set_ylabel('Frequency')

        msg = '[rgb_histogram] Saving RGB histogram figure: %s' % outfile
        logging.info(msg)
        print(msg)

        # make sure the parent directory of outfile exists before saving
        create_dir(outfile)
        fig.savefig(outfile)
        # plt.subplots opens a new figure on every call. clf() alone would
        # keep it registered with pyplot (one leaked figure per call), so
        # close it instead.
        plt.close(fig)

    return rh, gh, bh