Ejemplo n.º 1
0
def calculate_features(im_ms, cloud_mask, im_bool):
    """
    Builds the feature matrix that is used for the supervised classification.

    For the pixels selected by im_bool the columns are, in order: every band
    of im_ms, five normalized-difference indices (NIR-G, SWIR-G, NIR-R,
    SWIR-NIR, B-R), the output of SDS_tools.image_std for every band and the
    output of SDS_tools.image_std for each of the five indices.

    KV WRL 2018

    Arguments:
    -----------
        im_ms: np.array
            RGB + downsampled NIR and SWIR
        cloud_mask: np.array
            2D cloud mask with True where cloud pixels are
        im_bool: np.array
            2D array of boolean indicating where on the image to calculate the features

    Returns:
    -----------
        features: np.array
            matrix containing each feature (columns) calculated for all
            the pixels (rows) indicated in im_bool
    """

    n_bands = im_ms.shape[2]

    # (first band, second band) handed to SDS_tools.nd_index for each index
    band_pairs = (
        (3, 1),  # NIR-G
        (4, 1),  # SWIR-G
        (3, 2),  # NIR-R
        (4, 3),  # SWIR-NIR
        (0, 2),  # B-R
    )
    index_images = [
        SDS_tools.nd_index(im_ms[:, :, first], im_ms[:, :, second], cloud_mask)
        for first, second in band_pairs
    ]

    # one 1D vector per feature, kept in the final column order
    columns = [im_ms[im_bool, band] for band in range(n_bands)]
    columns.extend(im_index[im_bool] for im_index in index_images)
    # standard deviation of the individual bands
    columns.extend(
        SDS_tools.image_std(im_ms[:, :, band], 1)[im_bool]
        for band in range(n_bands))
    # standard deviation of the spectral indices
    columns.extend(
        SDS_tools.image_std(im_index, 1)[im_bool]
        for im_index in index_images)

    # assemble the (pixels x features) matrix in a single step
    return np.stack(columns, axis=-1)
Ejemplo n.º 2
0
def extract_shorelines(metadata, settings):
    """
    Extracts shorelines from satellite images.

    For every satellite in metadata, each image is preprocessed, classified
    and contoured; the accepted shorelines are merged into a single output
    structure, which is also written to '<sitename>_output.pkl' and
    '<sitename>_output.geojson' under <filepath>/<sitename>.

    KV WRL 2018

    Arguments:
    -----------
        metadata: dict
            contains all the information about the satellite images that were downloaded
            (per satellite, the keys 'filenames', 'epsg', 'dates' and 'acc_georef' are read)

        settings: dict
            contains the following fields:
        inputs: dict
            input parameters; 'sitename' and 'filepath' are read here and the whole
            dict is passed to SDS_tools.get_filepath
        cloud_mask_issue: boolean
            True if there is an issue with the cloud mask and sand pixels are being masked on the images
        buffer_size: int
            size of the buffer (m) around the sandy beach over which the pixels are considered in the
            thresholding algorithm
        min_beach_area: int
            minimum allowable object area (in metres^2) for the class 'sand'
        cloud_thresh: float
            value between 0 and 1 defining the maximum percentage of cloud cover allowed in the images
        output_epsg: int
            output spatial reference system as EPSG code
        check_detection: boolean
            True to show each invidual detection and let the user validate the mapped shoreline
        save_figure: boolean
            True to save a figure for each mapped shoreline
        dark_sand: boolean
            True to load the 'dark' Landsat classifier instead of the standard one

    Returns:
    -----------
        output: dict
            contains the extracted shorelines and corresponding dates.

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    # initialise output structure
    output = {}
    # create a subfolder to store the .jpg images showing the detection
    filepath_jpg = os.path.join(filepath_data, sitename, 'jpg_files',
                                'detection')
    if not os.path.exists(filepath_jpg):
        os.makedirs(filepath_jpg)
    # close all open figures
    plt.close('all')

    print('Mapping shorelines:')

    # classifiers are looked up relative to the current working directory
    classifier_dir = os.path.join(os.getcwd(), 'classifiers')

    # loop through satellite list
    for satname in metadata:

        # get images
        filepath = SDS_tools.get_filepath(settings['inputs'], satname)
        filenames = metadata[satname]['filenames']

        # initialise the output variables
        output_timestamp = []  # datetime at which the image was acquired (UTC time)
        output_shoreline = []  # vector of shoreline points
        output_filename = []  # filename of the images from which the shorelines where derived
        output_cloudcover = []  # cloud cover of the images
        output_geoaccuracy = []  # georeferencing accuracy of the images
        output_idxkeep = []  # index that were kept during the analysis (cloudy images are skipped)

        # load the classifier and the pixel size matching the satellite
        if satname in ['L5', 'L7', 'L8']:
            pixel_size = 15
            if settings['dark_sand']:
                clf = joblib.load(
                    os.path.join(classifier_dir,
                                 'NN_4classes_Landsat_dark.pkl'))
            else:
                clf = joblib.load(
                    os.path.join(classifier_dir, 'NN_4classes_Landsat.pkl'))

        elif satname == 'S2':
            pixel_size = 10
            clf = joblib.load(
                os.path.join(classifier_dir, 'NN_4classes_S2.pkl'))

        # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
        buffer_size_pixels = np.ceil(settings['buffer_size'] / pixel_size)
        min_beach_area_pixels = np.ceil(settings['min_beach_area'] /
                                        pixel_size**2)

        # loop through the images
        n_images = len(filenames)
        for i, filename in enumerate(filenames):

            print('\r%s:   %d%%' % (satname, int(((i + 1) / n_images) * 100)),
                  end='')

            # get image filename
            fn = SDS_tools.get_filenames(filename, filepath, satname)
            # preprocess image (cloud mask + pansharpening/downsampling)
            im_ms, georef, cloud_mask, im_extra, imQA = SDS_preprocess.preprocess_single(
                fn, satname, settings['cloud_mask_issue'])
            # get image spatial reference system (epsg code) from metadata dict
            image_epsg = metadata[satname]['epsg'][i]
            # calculate cloud cover (fraction of masked pixels in the image)
            cloud_cover = np.divide(
                np.sum(cloud_mask.astype(int)),
                (cloud_mask.shape[0] * cloud_mask.shape[1]))
            # skip image if cloud cover is above threshold
            if cloud_cover > settings['cloud_thresh']:
                continue

            # classify image in 4 classes (sand, whitewater, water, other) with NN classifier
            im_classif, im_labels = classify_image_NN(im_ms, im_extra,
                                                      cloud_mask,
                                                      min_beach_area_pixels,
                                                      clf)

            # calculate a buffer around the reference shoreline (if any has been digitised)
            im_ref_buffer = create_shoreline_buffer(cloud_mask.shape, georef,
                                                    image_epsg, pixel_size,
                                                    settings)

            # there are two options to map the contours:
            # if there are pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
            # otherwise use find_wl_contours1 (traditional)
            try:  # use try/except structure for long runs
                if np.sum(im_labels[:, :, 0]) == 0:
                    # compute MNDWI image (SWIR-G)
                    im_mndwi = SDS_tools.nd_index(im_ms[:, :, 4],
                                                  im_ms[:, :, 1], cloud_mask)
                    # find water contours on MNDWI grayscale image
                    contours_mwi = find_wl_contours1(im_mndwi, cloud_mask,
                                                     im_ref_buffer)
                else:
                    # use classification to refine threshold and extract the sand/water interface
                    contours_wi, contours_mwi = find_wl_contours2(
                        im_ms, im_labels, cloud_mask, buffer_size_pixels,
                        im_ref_buffer)
            except Exception:
                # only ordinary errors are skipped, so that Ctrl-C
                # (KeyboardInterrupt) can still stop a long run
                print('Could not map shoreline for this image: ' + filename)
                continue

            # process water contours into shorelines
            shoreline = process_shoreline(contours_mwi, georef, image_epsg,
                                          settings)

            # visualise the mapped shorelines, there are two options:
            # if settings['check_detection'] = True, shows the detection to the user for accept/reject
            # if settings['save_figure'] = True, saves a figure for each mapped shoreline
            if settings['check_detection'] or settings['save_figure']:
                date = filename[:19]
                skip_image = show_detection(im_ms, cloud_mask, im_labels,
                                            shoreline, image_epsg, georef,
                                            settings, date, satname)
                # if the user decides to skip the image, continue and do not save the mapped shoreline
                if skip_image:
                    continue

            # append to output variables
            output_timestamp.append(metadata[satname]['dates'][i])
            output_shoreline.append(shoreline)
            output_filename.append(filename)
            output_cloudcover.append(cloud_cover)
            output_geoaccuracy.append(metadata[satname]['acc_georef'][i])
            output_idxkeep.append(i)

        # create dictionary of output
        output[satname] = {
            'dates': output_timestamp,
            'shorelines': output_shoreline,
            'filename': output_filename,
            'cloud_cover': output_cloudcover,
            'geoaccuracy': output_geoaccuracy,
            'idx': output_idxkeep
        }
        print('')

    # Close figure window if still open
    if plt.get_fignums():
        plt.close()

    # change the format to have one list sorted by date with all the shorelines (easier to use)
    output = SDS_tools.merge_output(output)

    # save output structure as output.pkl
    filepath = os.path.join(filepath_data, sitename)
    with open(os.path.join(filepath, sitename + '_output.pkl'), 'wb') as f:
        pickle.dump(output, f)

    # save output into a gdb.GeoDataFrame
    gdf = SDS_tools.output_to_gdf(output)
    # set projection
    # NOTE(review): the {'init': ...} form may be deprecated in recent
    # geopandas/pyproj releases - confirm against the pinned versions
    gdf.crs = {'init': 'epsg:' + str(settings['output_epsg'])}
    # save as geojson
    gdf.to_file(os.path.join(filepath, sitename + '_output.geojson'),
                driver='GeoJSON',
                encoding='utf-8')

    return output
Ejemplo n.º 3
0
def show_detection(im_ms, cloud_mask, im_labels, shoreline, image_epsg, georef,
                   settings, date, satname):
    """
    Shows the detected shoreline to the user for visual quality control. The user can select "keep"
    if the shoreline detection is correct or "skip" if it is incorrect.

    The figure window is reused between calls: its axes are cleared before
    returning instead of closing the window.

    KV WRL 2018

    Arguments:
    -----------
        im_ms: np.array
            RGB + downsampled NIR and SWIR
        cloud_mask: np.array
            2D cloud mask with True where cloud pixels are
        im_labels: np.array
            3D image containing a boolean image for each class in the order (sand, swash, water)
        shoreline: np.array
            array of points with the X and Y coordinates of the shoreline
        image_epsg: int
            spatial reference system of the image from which the contours were extracted
        georef: np.array
            vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]
        settings: dict
            contains the following fields:
        inputs: dict
            'sitename' and 'filepath' are read to build the output folder
        output_epsg: int
            spatial reference system of the shoreline coordinates
        check_detection: boolean
            True to let the user accept/reject the detection with the keyboard
        save_figure: boolean
            True to save the figure as .jpg under jpg_files/detection
        date: string
            date at which the image was taken
        satname: string
            indicates the satname (L5,L7,L8 or S2)

    Returns:
    -----------
        skip_image: boolean
            True if the user wants to skip the image, False otherwise.

    Raises:
    -----------
        StopIteration
            if the user presses 'escape' while checking the detection

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    # subfolder where the .jpg file is stored if the user accepts the shoreline detection
    filepath = os.path.join(filepath_data, sitename, 'jpg_files', 'detection')

    im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:, :, [2, 1, 0]],
                                                    cloud_mask, 99.9)

    # compute classified image
    im_class = np.copy(im_RGB)
    # NOTE(review): cm.get_cmap may be unavailable in recent matplotlib
    # releases - confirm against the pinned version
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204 / 255, 1, 1, 1])
    colours[2, :] = np.array([0, 91 / 255, 1, 1])
    for k in range(0, im_labels.shape[2]):
        im_class[im_labels[:, :, k], 0] = colours[k, 0]
        im_class[im_labels[:, :, k], 1] = colours[k, 1]
        im_class[im_labels[:, :, k], 2] = colours[k, 2]

    # compute MNDWI grayscale image
    im_mwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)

    # transform world coordinates of shoreline into pixel coordinates
    # use try/except in case there are no coordinates to be transformed (shoreline = [])
    try:
        sl_pix = SDS_tools.convert_world2pix(
            SDS_tools.convert_epsg(shoreline, settings['output_epsg'],
                                   image_epsg)[:, [0, 1]], georef)
    except Exception:
        # if try fails, just add nan into the shoreline vector so the next parts can still run
        # (KeyboardInterrupt/SystemExit are deliberately not caught)
        sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])

    if plt.get_fignums():
        # get open figure if it exists
        fig = plt.gcf()
        ax1 = fig.axes[0]
        ax2 = fig.axes[1]
        ax3 = fig.axes[2]
    else:
        # else create a new figure
        fig = plt.figure()
        fig.set_size_inches([12.53, 9.3])
        mng = plt.get_current_fig_manager()
        mng.window.showMaximized()

        # according to the image shape, decide whether it is better to have the images
        # in vertical subplots or horizontal subplots
        if im_RGB.shape[1] > 2 * im_RGB.shape[0]:
            # vertical subplots
            gs = gridspec.GridSpec(3, 1)
            gs.update(bottom=0.03, top=0.97, left=0.03, right=0.97)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[1, 0])
            ax3 = fig.add_subplot(gs[2, 0])
        else:
            # horizontal subplots
            gs = gridspec.GridSpec(1, 3)
            gs.update(bottom=0.05, top=0.95, left=0.05, right=0.95)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])
            ax3 = fig.add_subplot(gs[0, 2])

    # change the color of nans to either black (0.0) or white (1.0) or somewhere in between
    nan_color = 1.0
    im_RGB = np.where(np.isnan(im_RGB), nan_color, im_RGB)
    im_class = np.where(np.isnan(im_class), nan_color, im_class)

    # create image 1 (RGB)
    ax1.imshow(im_RGB)
    ax1.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax1.axis('off')
    ax1.set_title(sitename, fontweight='bold', fontsize=16)

    # create image 2 (classification)
    ax2.imshow(im_class)
    ax2.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax2.axis('off')
    orange_patch = mpatches.Patch(color=colours[0, :], label='sand')
    white_patch = mpatches.Patch(color=colours[1, :], label='whitewater')
    blue_patch = mpatches.Patch(color=colours[2, :], label='water')
    black_line = mlines.Line2D([], [],
                               color='k',
                               linestyle='-',
                               label='shoreline')
    ax2.legend(handles=[orange_patch, white_patch, blue_patch, black_line],
               bbox_to_anchor=(1, 0.5),
               fontsize=10)
    ax2.set_title(date, fontweight='bold', fontsize=16)

    # create image 3 (MNDWI)
    ax3.imshow(im_mwi, cmap='bwr')
    ax3.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax3.axis('off')
    ax3.set_title(satname, fontweight='bold', fontsize=16)

    # if check_detection is True, let user manually accept/reject the images
    skip_image = False
    if settings['check_detection']:

        # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
        # a mutable container is needed so the callback can store the key
        # and it can be read back after the keypress event
        key_event = {}

        def press(event):
            # store what key was pressed in the dictionary
            key_event['pressed'] = event.key

        # connect the handler once and always disconnect it afterwards: the
        # figure is reused for the next image, so handlers left connected
        # would pile up on its canvas
        cid = fig.canvas.mpl_connect('key_press_event', press)
        try:
            # let the user press a key, right arrow to keep the image, left arrow to skip it
            # to break the loop the user can press 'escape'
            while True:
                btn_keep = plt.text(1.1,
                                    0.9,
                                    'keep ⇨',
                                    size=12,
                                    ha="right",
                                    va="top",
                                    transform=ax1.transAxes,
                                    bbox=dict(boxstyle="square",
                                              ec='k',
                                              fc='w'))
                btn_skip = plt.text(-0.1,
                                    0.9,
                                    '⇦ skip',
                                    size=12,
                                    ha="left",
                                    va="top",
                                    transform=ax1.transAxes,
                                    bbox=dict(boxstyle="square",
                                              ec='k',
                                              fc='w'))
                btn_esc = plt.text(0.5,
                                   0,
                                   '<esc> to quit',
                                   size=12,
                                   ha="center",
                                   va="top",
                                   transform=ax1.transAxes,
                                   bbox=dict(boxstyle="square",
                                             ec='k',
                                             fc='w'))
                plt.draw()
                plt.waitforbuttonpress()
                # after button is pressed, remove the buttons
                btn_skip.remove()
                btn_keep.remove()
                btn_esc.remove()

                # keep/skip image according to the pressed key, 'escape' to break the loop
                pressed = key_event.get('pressed')
                if pressed == 'right':
                    skip_image = False
                    break
                elif pressed == 'left':
                    skip_image = True
                    break
                elif pressed == 'escape':
                    plt.close()
                    raise StopIteration(
                        'User cancelled checking shoreline detection')
                else:
                    plt.waitforbuttonpress()
        finally:
            fig.canvas.mpl_disconnect(cid)

    # if save_figure is True, save a .jpg under /jpg_files/detection
    if settings['save_figure'] and not skip_image:
        fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'),
                    dpi=200)

    # Don't close the figure window, but remove all axes and settings, ready for next plot
    for ax in fig.axes:
        ax.clear()

    return skip_image
Ejemplo n.º 4
0
def find_wl_contours2(im_ms, im_labels, cloud_mask, buffer_size,
                      im_ref_buffer):
    """
    New robust method for extracting shorelines. Incorporates the classification component to
    refine the treshold and make it specific to the sand/water interface.

    Sand and water pixels located within a buffer around the sand class are
    balanced (random subsampling of the larger class) and an Otsu threshold
    is computed for each index; the contours are then extracted at these
    thresholds inside the reference shoreline buffer.

    KV WRL 2018

    Arguments:
    -----------
        im_ms: np.array
            RGB + downsampled NIR and SWIR
        im_labels: np.array
            3D image containing a boolean image for each class in the order (sand, swash, water)
        cloud_mask: np.array
            2D cloud mask with True where cloud pixels are
        buffer_size: int
            size of the buffer around the sandy beach over which the pixels are considered in the
            thresholding algorithm.
        im_ref_buffer: np.array
            Binary image containing a buffer around the reference shoreline

    Returns:
    -----------
        contours_wi: list of np.arrays
            contains the (row,column) coordinates of the contour lines extracted from the
            NDWI (Normalized Difference Water Index) image
        contours_mwi: list of np.arrays
            contains the (row,column) coordinates of the contour lines extracted from the
            MNDWI (Modified Normalized Difference Water Index) image

    """

    def remove_nan_points(contours):
        # drop contour points that are NaNs (around clouds); a contour that
        # contained NaNs is only kept if more than one point remains
        contours_nonans = []
        for contour in contours:
            rows_with_nan = np.isnan(contour).any(axis=1)
            if rows_with_nan.any():
                contour_temp = contour[~rows_with_nan]
                if len(contour_temp) > 1:
                    contours_nonans.append(contour_temp)
            else:
                contours_nonans.append(contour)
        return contours_nonans

    nrows = cloud_mask.shape[0]
    ncols = cloud_mask.shape[1]

    # calculate Modified Normalized Difference Water Index (SWIR - G)
    im_mwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)
    # calculate Normalized Difference Water Index (NIR - G)
    im_wi = SDS_tools.nd_index(im_ms[:, :, 3], im_ms[:, :, 1], cloud_mask)
    # stack indices together: column 0 holds the NDWI, column 1 the MNDWI
    im_ind = np.stack((im_wi, im_mwi), axis=-1)
    vec_ind = im_ind.reshape(nrows * ncols, 2)

    # reshape labels into vectors
    vec_sand = im_labels[:, :, 0].reshape(ncols * nrows)
    vec_water = im_labels[:, :, 2].reshape(ncols * nrows)

    # create a buffer around the sandy beach
    se = morphology.disk(buffer_size)
    im_buffer = morphology.binary_dilation(im_labels[:, :, 0], se)
    vec_buffer = im_buffer.reshape(nrows * ncols)

    # select water/sand pixels that are within the buffer
    int_water = vec_ind[np.logical_and(vec_buffer, vec_water), :]
    int_sand = vec_ind[np.logical_and(vec_buffer, vec_sand), :]

    # make sure both classes have the same number of pixels before thresholding
    n_sand = int_sand.shape[0]
    n_water = int_water.shape[0]
    if n_water > 0 and n_sand > 0:
        if n_water < n_sand:
            # fewer water pixels: randomly subsample the sand pixels
            int_sand = int_sand[np.random.choice(
                n_sand, n_water, replace=False), :]
        else:
            # otherwise randomly subsample the water pixels
            int_water = int_water[np.random.choice(
                n_water, n_sand, replace=False), :]

    # threshold the sand/water intensities
    # each threshold is computed from the column of its own index (the
    # columns follow the stacking order above: 0 = NDWI, 1 = MNDWI)
    int_all = np.append(int_water, int_sand, axis=0)
    t_wi = filters.threshold_otsu(int_all[:, 0])
    t_mwi = filters.threshold_otsu(int_all[:, 1])

    # find contour with MS algorithm, only inside the reference shoreline buffer
    im_wi_buffer = np.copy(im_wi)
    im_wi_buffer[~im_ref_buffer] = np.nan
    im_mwi_buffer = np.copy(im_mwi)
    im_mwi_buffer[~im_ref_buffer] = np.nan
    contours_wi = measure.find_contours(im_wi_buffer, t_wi)
    contours_mwi = measure.find_contours(im_mwi_buffer, t_mwi)

    # remove contour points that are NaNs (around clouds)
    contours_wi = remove_nan_points(contours_wi)
    contours_mwi = remove_nan_points(contours_mwi)

    return contours_wi, contours_mwi
Ejemplo n.º 5
0
def label_images(metadata,settings):
    """
    Load satellite images and interactively label different classes (hard-coded)

    KV WRL 2019

    Arguments:
    -----------
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in 
            the cropped image that is accepted    
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'labels': dict
            list of label names (key) and label numbers (value) for each class
        'flood_fill': boolean
            True to use the flood_fill functionality when labelling sand pixels
        'tolerance': float
            tolerance value for flood fill when labelling the sand pixels
        'filepath_train': str
            directory in which to save the labelled data
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
                
    Returns:
    -----------
    Stores the labelled data in the specified directory

    """
    
    filepath_train = settings['filepath_train']
    # initialize figure
    fig,ax = plt.subplots(1,1,figsize=[17,10], tight_layout=True,sharex=True,
                          sharey=True)
    mng = plt.get_current_fig_manager()                                         
    mng.window.showMaximized()

    # loop through satellites
    for satname in metadata.keys():
        filepath = SDS_tools.get_filepath(settings['inputs'],satname)
        filenames = metadata[satname]['filenames']
        # loop through images
        for i in range(len(filenames)):
            # image filename
            fn = SDS_tools.get_filenames(filenames[i],filepath, satname)
            # read and preprocess image
            im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(fn, satname, settings['cloud_mask_issue'])

            # compute cloud_cover percentage (with no data pixels)
            cloud_cover_combined = np.divide(sum(sum(cloud_mask.astype(int))),
                                    (cloud_mask.shape[0]*cloud_mask.shape[1]))
            if cloud_cover_combined > 0.99: # if 99% of cloudy pixels in image skip
                continue

            # remove no data pixels from the cloud mask (for example L7 bands of no data should not be accounted for)
            cloud_mask_adv = np.logical_xor(cloud_mask, im_nodata)
            # compute updated cloud cover percentage (without no data pixels)
            cloud_cover = np.divide(sum(sum(cloud_mask_adv.astype(int))),
                                    (sum(sum((~im_nodata).astype(int)))))
            # skip image if cloud cover is above threshold
            if cloud_cover > settings['cloud_thresh'] or cloud_cover == 1:
                continue
            # get individual RGB image
            im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9)
            im_NDVI = SDS_tools.nd_index(im_ms[:,:,3], im_ms[:,:,2], cloud_mask)
            im_NDWI = SDS_tools.nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask)
            # initialise labels
            im_viz = im_RGB.copy()
            im_labels = np.zeros([im_RGB.shape[0],im_RGB.shape[1]])
            # show RGB image
            ax.axis('off')  
            ax.imshow(im_RGB)
            implot = ax.imshow(im_viz, alpha=0.6)            
            filename = filenames[i][:filenames[i].find('.')][:-4] 
            ax.set_title(filename)
           
            ##############################################################
            # select image to label
            ##############################################################           
            # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
            # this variable needs to be immuatable so we can access it after the keypress event
            key_event = {}
            def press(event):
                # store what key was pressed in the dictionary
                key_event['pressed'] = event.key
            # let the user press a key, right arrow to keep the image, left arrow to skip it
            # to break the loop the user can press 'escape'
            while True:
                btn_keep = ax.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
                                    transform=ax.transAxes,
                                    bbox=dict(boxstyle="square", ec='k',fc='w'))
                btn_skip = ax.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
                                    transform=ax.transAxes,
                                    bbox=dict(boxstyle="square", ec='k',fc='w'))
                btn_esc = ax.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
                                    transform=ax.transAxes,
                                    bbox=dict(boxstyle="square", ec='k',fc='w'))
                fig.canvas.draw_idle()                         
                fig.canvas.mpl_connect('key_press_event', press)
                plt.waitforbuttonpress()
                # after button is pressed, remove the buttons
                btn_skip.remove()
                btn_keep.remove()
                btn_esc.remove()
                
                # keep/skip image according to the pressed key, 'escape' to break the loop
                if key_event.get('pressed') == 'right':
                    skip_image = False
                    break
                elif key_event.get('pressed') == 'left':
                    skip_image = True
                    break
                elif key_event.get('pressed') == 'escape':
                    plt.close()
                    raise StopIteration('User cancelled labelling images')
                else:
                    plt.waitforbuttonpress()
                    
            # if user decided to skip show the next image
            if skip_image:
                ax.clear()
                continue
            # otherwise label this image
            else:
                ##############################################################
                # digitize sandy pixels
                ##############################################################
                ax.set_title('Click on SAND pixels (flood fill activated, tolerance = %.2f)\nwhen finished press <Enter>'%settings['tolerance'])
                # create erase button, if you click there it delets the last selection
                btn_erase = ax.text(im_ms.shape[1], 0, 'Erase', size=20, ha='right', va='top',
                                    bbox=dict(boxstyle="square", ec='k',fc='w'))                
                fig.canvas.draw_idle()
                color_sand = settings['colors']['sand']
                sand_pixels = []
                while 1:
                    seed = ginput(n=1, timeout=0, show_clicks=True)
                    # if empty break the loop and go to next label
                    if len(seed) == 0:
                        break
                    else:
                        # round to pixel location
                        seed = np.round(seed[0]).astype(int)     
                    # if user clicks on erase, delete the last selection
                    if seed[0] > 0.95*im_ms.shape[1] and seed[1] < 0.05*im_ms.shape[0]:
                        if len(sand_pixels) > 0:
                            im_labels[sand_pixels[-1]] = 0
                            for k in range(im_viz.shape[2]):                              
                                im_viz[sand_pixels[-1],k] = im_RGB[sand_pixels[-1],k]
                            implot.set_data(im_viz)
                            fig.canvas.draw_idle() 
                            del sand_pixels[-1]
                            
                    # otherwise label the selected sand pixels
                    else:
                        # flood fill the NDVI and the NDWI
                        fill_NDVI = flood(im_NDVI, (seed[1],seed[0]), tolerance=settings['tolerance'])
                        fill_NDWI = flood(im_NDWI, (seed[1],seed[0]), tolerance=settings['tolerance'])
                        # compute the intersection of the two masks
                        fill_sand = np.logical_and(fill_NDVI, fill_NDWI)
                        im_labels[fill_sand] = settings['labels']['sand'] 
                        sand_pixels.append(fill_sand)
                        # show the labelled pixels
                        for k in range(im_viz.shape[2]):                              
                            im_viz[im_labels==settings['labels']['sand'],k] = color_sand[k]
                        implot.set_data(im_viz)
                        fig.canvas.draw_idle() 
                
                ##############################################################
                # digitize white-water pixels
                ##############################################################
                color_ww = settings['colors']['white-water']
                ax.set_title('Click on individual WHITE-WATER pixels (no flood fill)\nwhen finished press <Enter>')
                fig.canvas.draw_idle() 
                ww_pixels = []                        
                while 1:
                    seed = ginput(n=1, timeout=0, show_clicks=True)
                    # if empty break the loop and go to next label
                    if len(seed) == 0:
                        break
                    else:
                        # round to pixel location
                        seed = np.round(seed[0]).astype(int)     
                    # if user clicks on erase, delete the last labelled pixels
                    if seed[0] > 0.95*im_ms.shape[1] and seed[1] < 0.05*im_ms.shape[0]:
                        if len(ww_pixels) > 0:
                            im_labels[ww_pixels[-1][1],ww_pixels[-1][0]] = 0
                            for k in range(im_viz.shape[2]):
                                im_viz[ww_pixels[-1][1],ww_pixels[-1][0],k] = im_RGB[ww_pixels[-1][1],ww_pixels[-1][0],k]
                            implot.set_data(im_viz)
                            fig.canvas.draw_idle()
                            del ww_pixels[-1]
                    else:
                        im_labels[seed[1],seed[0]] = settings['labels']['white-water']  
                        for k in range(im_viz.shape[2]):                              
                            im_viz[seed[1],seed[0],k] = color_ww[k]
                        implot.set_data(im_viz)
                        fig.canvas.draw_idle()
                        ww_pixels.append(seed)
                        
                im_sand_ww = im_viz.copy()
                btn_erase.set(text='<Esc> to Erase', fontsize=12)
                
                ##############################################################
                # digitize water pixels (with lassos)
                ##############################################################
                color_water = settings['colors']['water']
                ax.set_title('Click and hold to draw lassos and select WATER pixels\nwhen finished press <Enter>')
                fig.canvas.draw_idle() 
                selector_water = SelectFromImage(ax, implot, color_water)
                key_event = {}
                while True:
                    fig.canvas.draw_idle()                         
                    fig.canvas.mpl_connect('key_press_event', press)
                    plt.waitforbuttonpress()
                    if key_event.get('pressed') == 'enter':
                        selector_water.disconnect()
                        break
                    elif key_event.get('pressed') == 'escape':
                        selector_water.array = im_sand_ww
                        implot.set_data(selector_water.array)
                        fig.canvas.draw_idle()                         
                        selector_water.implot = implot
                        selector_water.im_bool = np.zeros((selector_water.array.shape[0], selector_water.array.shape[1])) 
                        selector_water.ind=[]          
                # update im_viz and im_labels
                im_viz = selector_water.array
                selector_water.im_bool = selector_water.im_bool.astype(bool)
                im_labels[selector_water.im_bool] = settings['labels']['water']
                
                im_sand_ww_water = im_viz.copy()
                
                ##############################################################
                # digitize land pixels (with lassos)
                ##############################################################
                color_land = settings['colors']['other land features']
                ax.set_title('Click and hold to draw lassos and select OTHER LAND pixels\nwhen finished press <Enter>')
                fig.canvas.draw_idle() 
                selector_land = SelectFromImage(ax, implot, color_land)
                key_event = {}
                while True:
                    fig.canvas.draw_idle()                         
                    fig.canvas.mpl_connect('key_press_event', press)
                    plt.waitforbuttonpress()
                    if key_event.get('pressed') == 'enter':
                        selector_land.disconnect()
                        break
                    elif key_event.get('pressed') == 'escape':
                        selector_land.array = im_sand_ww_water
                        implot.set_data(selector_land.array)
                        fig.canvas.draw_idle()                         
                        selector_land.implot = implot
                        selector_land.im_bool = np.zeros((selector_land.array.shape[0], selector_land.array.shape[1])) 
                        selector_land.ind=[]
                # update im_viz and im_labels
                im_viz = selector_land.array
                selector_land.im_bool = selector_land.im_bool.astype(bool)
                im_labels[selector_land.im_bool] = settings['labels']['other land features']  
                
                # save labelled image
                ax.set_title(filename)
                fig.canvas.draw_idle()                         
                fp = os.path.join(filepath_train,settings['inputs']['sitename'])
                if not os.path.exists(fp):
                    os.makedirs(fp)
                fig.savefig(os.path.join(fp,filename+'.jpg'), dpi=150)
                ax.clear()
                # save labels and features
                features = dict([])
                for key in settings['labels'].keys():
                    im_bool = im_labels == settings['labels'][key]
                    features[key] = SDS_shoreline.calculate_features(im_ms, cloud_mask, im_bool)
                training_data = {'labels':im_labels, 'features':features, 'label_ids':settings['labels']}
                with open(os.path.join(fp, filename + '.pkl'), 'wb') as f:
                    pickle.dump(training_data,f)
                    
    # close figure when finished
    plt.close(fig)
Ejemplo n.º 6
0
def evaluate_classifier(classifier, metadata, settings):
    """
    Apply the image classifier to all the images and save the classified images.

    For every image listed in `metadata` the function preprocesses the image,
    skips it when it is too cloudy, classifies it, extracts a shoreline and
    saves a two-panel figure (RGB image / RGB image with the classification
    overlaid, both with the shoreline drawn on top).

    KV WRL 2019

    Arguments:
    -----------
    classifier: joblib object
        classifier model to be used for image classification
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in 
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'output_epsg': int
            output spatial reference system as EPSG code
        'buffer_size': int
            size of the buffer (m) around the sandy pixels over which the pixels 
            are considered in the thresholding algorithm
        'min_beach_area': int
            minimum allowable object area (in metres^2) for the class 'sand',
            the area is converted to number of connected pixels
        'min_length_sl': int
            minimum length (in metres) of shoreline contour to be valid

    Returns:
    -----------
    Nothing is returned. Saves .jpg images with the output of the classification
    in the folder ./evaluation (created in the current working directory).

    """

    # create folder called evaluation
    fp = os.path.join(os.getcwd(), 'evaluation')
    if not os.path.exists(fp):
        os.makedirs(fp)

    # initialize figure (not interactive)
    plt.ioff()
    fig, ax = plt.subplots(1, 2, figsize=[17, 10], sharex=True, sharey=True,
                           constrained_layout=True)

    # create colormap for labels (one RGBA row per class: sand, whitewater, water)
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204/255, 1, 1, 1])
    colours[2, :] = np.array([0, 91/255, 1, 1])

    # loop through satellites
    for satname in metadata.keys():
        filepath = SDS_tools.get_filepath(settings['inputs'], satname)
        filenames = metadata[satname]['filenames']

        # pixel size (in metres) used to convert the settings below to pixels
        # NOTE(review): any satname other than L5/L7/L8/S2 leaves pixel_size
        # undefined (or carrying the previous satellite's value) - confirm the
        # set of satellites that can appear in metadata.
        if satname in ['L5', 'L7', 'L8']:
            pixel_size = 15
        elif satname == 'S2':
            pixel_size = 10
        # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
        buffer_size_pixels = np.ceil(settings['buffer_size']/pixel_size)
        min_beach_area_pixels = np.ceil(settings['min_beach_area']/pixel_size**2)

        # loop through images
        for i, image_name in enumerate(filenames):
            # image filename
            fn = SDS_tools.get_filenames(image_name, filepath, satname)
            # read and preprocess image
            im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(
                fn, satname, settings['cloud_mask_issue'])
            image_epsg = metadata[satname]['epsg'][i]

            # compute cloud_cover percentage (with no data pixels)
            cloud_cover_combined = np.divide(np.sum(cloud_mask.astype(int)),
                                             (cloud_mask.shape[0]*cloud_mask.shape[1]))
            if cloud_cover_combined > 0.99:  # if 99% of cloudy pixels in image skip
                continue

            # remove no data pixels from the cloud mask (for example L7 bands of no data should not be accounted for)
            cloud_mask_adv = np.logical_xor(cloud_mask, im_nodata)
            # compute updated cloud cover percentage (without no data pixels)
            cloud_cover = np.divide(np.sum(cloud_mask_adv.astype(int)),
                                    np.sum((~im_nodata).astype(int)))
            # skip image if cloud cover is above threshold
            if cloud_cover > settings['cloud_thresh']:
                continue

            # calculate a buffer around the reference shoreline (if any has been digitised)
            im_ref_buffer = SDS_shoreline.create_shoreline_buffer(cloud_mask.shape, georef, image_epsg,
                                                                  pixel_size, settings)
            # classify image in 4 classes (sand, whitewater, water, other) with NN classifier
            im_classif, im_labels = SDS_shoreline.classify_image_NN(im_ms, im_extra, cloud_mask,
                                                                    min_beach_area_pixels, classifier)

            # there are two options to map the contours:
            # if there are enough pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
            # otherwise use find_wl_contours1 (traditional)
            # A failure here only skips the image; `Exception` (rather than a bare
            # except) keeps Ctrl-C / SystemExit working during long runs.
            try:
                if np.sum(im_labels[:, :, 0]) < 10:
                    # compute MNDWI image (SWIR-G)
                    im_mndwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)
                    # find water contours on MNDWI grayscale image
                    contours_mwi, t_mndwi = SDS_shoreline.find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
                else:
                    # use classification to refine threshold and extract the sand/water interface
                    contours_mwi, t_mndwi = SDS_shoreline.find_wl_contours2(im_ms, im_labels,
                                                cloud_mask, buffer_size_pixels, im_ref_buffer)
            except Exception:
                print('Could not map shoreline for this image: ' + image_name)
                continue

            # process the water contours into a shoreline
            shoreline = SDS_shoreline.process_shoreline(contours_mwi, cloud_mask, georef, image_epsg, settings)
            try:
                sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline,
                                                                            settings['output_epsg'],
                                                                            image_epsg)[:, [0, 1]], georef)
            except Exception:
                # if the conversion fails, just add nan into the shoreline vector so the next parts can still run
                sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])

            # make a plot
            im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)
            # create classified image: paint the R, G, B channels of each class
            im_class = np.copy(im_RGB)
            for k in range(0, im_labels.shape[2]):
                for channel in range(3):
                    im_class[im_labels[:, :, k], channel] = colours[k, channel]

            # show images
            ax[0].imshow(im_RGB)
            ax[1].imshow(im_RGB)
            ax[1].imshow(im_class, alpha=0.5)
            ax[0].axis('off')
            ax[1].axis('off')
            # strip the extension (from the first '.') and the 4 characters before it
            filename = image_name[:image_name.find('.')][:-4]
            ax[0].set_title(filename)
            ax[0].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            ax[1].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            # save figure
            fig.savefig(os.path.join(fp, settings['inputs']['sitename'] + filename[:19] + '.jpg'), dpi=150)
            # clear axes so the same figure can be reused for the next image
            for cax in fig.axes:
                cax.clear()

    # close the figure created above (not whichever figure happens to be current)
    plt.close(fig)
Ejemplo n.º 7
0
def _drop_nan_points(contours):
    """Remove NaN vertices from each contour; a contour that had NaNs is kept only if >1 point remains."""
    cleaned = []
    for contour in contours:
        if np.any(np.isnan(contour)):
            index_nan = np.where(np.isnan(contour))[0]
            contour_temp = np.delete(contour, index_nan, axis=0)
            if len(contour_temp) > 1:
                cleaned.append(contour_temp)
        else:
            cleaned.append(contour)
    return cleaned


def find_wl_contours2(im_ms, im_labels, cloud_mask, buffer_size, im_ref_buffer,
                      dirty_hack_flag=False):
    """
    New robust method for extracting shorelines. Incorporates the classification
    component to refine the threshold and make it specific to the sand/water interface.

    KV WRL 2018

    Arguments:
    -----------
    im_ms: np.array
        RGB + downsampled NIR and SWIR
    im_labels: np.array
        3D image containing a boolean image for each class in the order (sand, swash, water).
        When `dirty_hack_flag` is True this array is modified in place.
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are
    buffer_size: int
        size of the buffer around the sandy beach over which the pixels are considered in the
        thresholding algorithm.
    im_ref_buffer: np.array
        binary image containing a buffer around the reference shoreline
    dirty_hack_flag: bool
        if True, relabel the classes before thresholding: every swash (white-water)
        pixel is also flagged as water, and every pixel that has no label at all
        is flagged as sand. Defaults to False (labels are used as given).

    Returns:    
    -----------
    contours_wi: list of np.arrays
        contains the coordinates of the contour lines extracted from the
        NDWI (Normalized Difference Water Index) image
    contours_mwi: list of np.arrays
        contains the coordinates of the contour lines extracted from the
        MNDWI (Modified Normalized Difference Water Index) image

    """

    nrows = cloud_mask.shape[0]
    ncols = cloud_mask.shape[1]

    # calculate Normalized Difference Modified Water Index (SWIR - G)
    im_mwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)
    # calculate Normalized Difference Water Index (NIR - G)
    im_wi = SDS_tools.nd_index(im_ms[:, :, 3], im_ms[:, :, 1], cloud_mask)
    # stack indices together: column 0 holds NDWI, column 1 holds MNDWI
    im_ind = np.stack((im_wi, im_mwi), axis=-1)
    vec_ind = im_ind.reshape(nrows * ncols, 2)

    if dirty_hack_flag:
        print("I am doing the dirty hack!")
        # Both masks are computed from the labels as they were on entry, which
        # matches the former pixel-by-pixel loop (each pixel only ever affected
        # itself).
        is_set = np.equal(im_labels, True)
        has_label = np.any(is_set, axis=2)
        # white-water pixels are also counted as water
        im_labels[is_set[:, :, 1], 2] = True
        # pixels without any label ("other") are counted as sand
        im_labels[~has_label, 0] = True

    # reshape labels into vectors
    vec_sand = im_labels[:, :, 0].reshape(ncols * nrows)
    vec_water = im_labels[:, :, 2].reshape(ncols * nrows)

    # create a buffer around the sandy beach
    se = morphology.disk(buffer_size)
    im_buffer = morphology.binary_dilation(im_labels[:, :, 0], se)
    vec_buffer = im_buffer.reshape(nrows * ncols)

    # select water/sand pixels that are within the buffer
    int_water = vec_ind[np.logical_and(vec_buffer, vec_water), :]
    int_sand = vec_ind[np.logical_and(vec_buffer, vec_sand), :]

    # make sure both classes have the same number of pixels before thresholding
    # (randomly subsample the larger class down to the size of the smaller one)
    if len(int_water) > 0 and len(int_sand) > 0:
        if np.argmin([int_sand.shape[0], int_water.shape[0]]) == 1:
            int_sand = int_sand[np.random.choice(
                int_sand.shape[0], int_water.shape[0], replace=False), :]
        else:
            int_water = int_water[np.random.choice(
                int_water.shape[0], int_sand.shape[0], replace=False), :]

    # threshold the sand/water intensities
    # Each threshold is computed from the column of its own index (see the
    # stacking order above); previously the two columns were swapped, so each
    # image was contoured at the other index's threshold.
    int_all = np.append(int_water, int_sand, axis=0)
    t_wi = filters.threshold_otsu(int_all[:, 0])
    t_mwi = filters.threshold_otsu(int_all[:, 1])

    # find contour with MS algorithm, restricted to the reference shoreline buffer
    im_wi_buffer = np.copy(im_wi)
    im_wi_buffer[~im_ref_buffer] = np.nan
    im_mwi_buffer = np.copy(im_mwi)
    im_mwi_buffer[~im_ref_buffer] = np.nan
    contours_wi = measure.find_contours(im_wi_buffer, t_wi)
    contours_mwi = measure.find_contours(im_mwi_buffer, t_mwi)

    # remove contour points that are NaNs (around clouds)
    contours_wi = _drop_nan_points(contours_wi)
    contours_mwi = _drop_nan_points(contours_mwi)

    return contours_wi, contours_mwi
Ejemplo n.º 8
0
def adjust_detection(im_ms, cloud_mask, im_labels, im_ref_buffer, image_epsg, georef,
                       settings, date, satname, buffer_size_pixels):
    """
    Advanced version of show detection where the user can adjust the detected 
    shorelines with a slide bar.

    The function plots the RGB image, the classified image, the MNDWI image and
    a histogram of MNDWI values per class, maps an initial shoreline, lets the
    user click on the histogram to move the MNDWI threshold (the shoreline is
    re-mapped after every click) and finally asks the user to keep or skip the
    image with the arrow keys.

    KV WRL 2020

    Arguments:
    -----------
    im_ms: np.array
        RGB + downsampled NIR and SWIR
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are
    im_labels: np.array
        3D image containing a boolean image for each class in the order (sand, swash, water)
    im_ref_buffer: np.array
        Binary image containing a buffer around the reference shoreline
    image_epsg: int
        spatial reference system of the image from which the contours were extracted
    georef: np.array
        vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]
    date: string
        date at which the image was taken, formatted as '%Y-%m-%d-%H-%M-%S'
    satname: string
        indicates the satname (L5,L7,L8 or S2)
    buffer_size_pixels: int
        buffer_size converted to number of pixels
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'output_epsg': int
            output spatial reference system as EPSG code
        'save_figure': bool
            if True, saves a -jpg file for each mapped shoreline

    Returns:
    -----------
    skip_image: boolean
        True if the user wants to skip the image, False otherwise
    shoreline: np.array
        array of points with the X and Y coordinates of the shoreline 
    t_mndwi: float
        value of the MNDWI threshold used to map the shoreline

    If the initial shoreline cannot be mapped, (True, [], []) is returned.

    Raises:
    -----------
    StopIteration
        if the user presses <esc> at the keep/skip prompt

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    # subfolder where the .jpg file is stored if the user accepts the shoreline detection
    filepath = os.path.join(filepath_data, sitename, 'jpg_files', 'detection')
    # format date
    date_str = datetime.strptime(date,'%Y-%m-%d-%H-%M-%S').strftime('%Y-%m-%d  %H:%M:%S')
    im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9)

    # compute classified image (one RGBA row per class: sand, whitewater, water)
    im_class = np.copy(im_RGB)
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0,13,1))
    colours = np.zeros((3,4))
    colours[0,:] = colorpalette[5]
    colours[1,:] = np.array([204/255,1,1,1])
    colours[2,:] = np.array([0,91/255,1,1])
    for k in range(0,im_labels.shape[2]):
        for channel in range(3):
            im_class[im_labels[:,:,k],channel] = colours[k,channel]

    # compute MNDWI grayscale image
    im_mndwi = SDS_tools.nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask)
    # buffer MNDWI using reference shoreline
    im_mndwi_buffer = np.copy(im_mndwi)
    im_mndwi_buffer[~im_ref_buffer] = np.nan

    # get MNDWI pixel intensity in each class (for histogram plot)
    int_sand = im_mndwi[im_labels[:,:,0]]
    int_ww = im_mndwi[im_labels[:,:,1]]
    int_water = im_mndwi[im_labels[:,:,2]]
    labels_other = np.logical_and(np.logical_and(~im_labels[:,:,0],~im_labels[:,:,1]),~im_labels[:,:,2])
    int_other = im_mndwi[labels_other]

    # create figure
    if plt.get_fignums():
        # if a figure exists, reuse it (its first four axes are assumed to be
        # the ones created by the else-branch below on an earlier call)
        fig = plt.gcf()
        ax1 = fig.axes[0]
        ax2 = fig.axes[1]
        ax3 = fig.axes[2]
        ax4 = fig.axes[3]
    else:
        # else create a new figure
        fig = plt.figure()
        fig.set_size_inches([18, 9])
        mng = plt.get_current_fig_manager()
        mng.window.showMaximized()
        gs = gridspec.GridSpec(2, 3, height_ratios=[4,1])
        gs.update(bottom=0.05, top=0.95, left=0.03, right=0.97)
        ax1 = fig.add_subplot(gs[0,0])
        ax2 = fig.add_subplot(gs[0,1], sharex=ax1, sharey=ax1)
        ax3 = fig.add_subplot(gs[0,2], sharex=ax1, sharey=ax1)
        ax4 = fig.add_subplot(gs[1,:])
    ##########################################################################
    # to do: rotate image if too wide
    ##########################################################################

    # change the color of nans to either black (0.0) or white (1.0) or somewhere in between
    nan_color = 1.0
    im_RGB = np.where(np.isnan(im_RGB), nan_color, im_RGB)
    im_class = np.where(np.isnan(im_class), nan_color, im_class)

    # plot image 1 (RGB)
    ax1.imshow(im_RGB)
    ax1.axis('off')
    ax1.set_title('%s - %s'%(sitename, satname), fontsize=12)

    # plot image 2 (classification)
    ax2.imshow(im_class)
    ax2.axis('off')
    orange_patch = mpatches.Patch(color=colours[0,:], label='sand')
    white_patch = mpatches.Patch(color=colours[1,:], label='whitewater')
    blue_patch = mpatches.Patch(color=colours[2,:], label='water')
    black_line = mlines.Line2D([],[],color='k',linestyle='-', label='shoreline')
    ax2.legend(handles=[orange_patch,white_patch,blue_patch, black_line],
               bbox_to_anchor=(1.1, 0.5), fontsize=10)
    ax2.set_title(date_str, fontsize=12)

    # plot image 3 (MNDWI)
    ax3.imshow(im_mndwi, cmap='bwr')
    ax3.axis('off')
    ax3.set_title('MNDWI', fontsize=12)

    # plot histogram of MNDWI values, one histogram per class
    binwidth = 0.01
    ax4.set_facecolor('0.75')
    ax4.yaxis.grid(color='w', linestyle='--', linewidth=0.5)
    ax4.set(ylabel='PDF',yticklabels=[], xlim=[-1,1])
    # (values, colour, legend label, alpha); alpha=None is matplotlib's default
    histograms = [
        (int_sand, colours[0,:], 'sand', None),
        (int_ww, colours[1,:], 'whitewater', 0.75),
        (int_water, colours[2,:], 'water', 0.75),
        (int_other, 'C4', 'other', 0.5),
    ]
    for values, colour, label, alpha in histograms:
        # skip classes that are empty or contain only NaNs
        if len(values) > 0 and np.any(~np.isnan(values)):
            bins = np.arange(np.nanmin(values), np.nanmax(values) + binwidth, binwidth)
            ax4.hist(values, bins=bins, density=True, color=colour, label=label, alpha=alpha)

    # automatically map the shoreline based on the classifier if enough sand pixels
    # `Exception` (rather than a bare except) keeps Ctrl-C / SystemExit working
    try:
        if np.sum(im_labels[:,:,0]) > 10:
            # use classification to refine threshold and extract the sand/water interface
            contours_mndwi, t_mndwi = find_wl_contours2(im_ms, im_labels, cloud_mask,
                                                        buffer_size_pixels, im_ref_buffer)
        else:
            # find water contours on MNDWI grayscale image
            contours_mndwi, t_mndwi = find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
    except Exception:
        print('Could not map shoreline so image was skipped')
        # clear axes and return skip_image=True, so that image is skipped above
        for ax in fig.axes:
            ax.clear()
        return True,[],[]

    # process the water contours into a shoreline
    shoreline = process_shoreline(contours_mndwi, cloud_mask, georef, image_epsg, settings)
    # convert shoreline to pixels
    if len(shoreline) > 0:
        sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline,
                                                                    settings['output_epsg'],
                                                                    image_epsg)[:,[0,1]], georef)
    else:
        # placeholder so that the plotting calls below still work
        sl_pix = np.array([[np.nan, np.nan],[np.nan, np.nan]])
    # plot the shoreline on the images
    sl_plot1 = ax1.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    sl_plot2 = ax2.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    sl_plot3 = ax3.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    t_line = ax4.axvline(x=t_mndwi,ls='--', c='k', lw=1.5, label='threshold')
    ax4.legend(loc=1)
    plt.draw() # to update the plot
    # adjust the threshold manually by letting the user change the threshold
    ax4.set_title('Click on the plot below to change the location of the threhsold and adjust the shoreline detection. When finished, press <Enter>')
    while True:
        # let the user click on the threshold plot
        pt = ginput(n=1, show_clicks=True, timeout=-1)
        # no point returned: the user has finished adjusting the threshold
        if len(pt) == 0:
            ax4.set_title('MNDWI pixel intensities and threshold')
            break
        # if user clicked somewhere wrong and value is not between -1 and 1
        if np.abs(pt[0][0]) >= 1:
            continue
        # update the threshold value
        t_mndwi = pt[0][0]
        # update the plot
        t_line.set_xdata([t_mndwi,t_mndwi])
        # map contours with new threshold
        contours = measure.find_contours(im_mndwi_buffer, t_mndwi)
        # remove contours that contain NaNs (due to cloud pixels in the contour)
        contours = process_contours(contours)
        # process the water contours into a shoreline
        shoreline = process_shoreline(contours, cloud_mask, georef, image_epsg, settings)
        # convert shoreline to pixels
        if len(shoreline) > 0:
            sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline,
                                                                        settings['output_epsg'],
                                                                        image_epsg)[:,[0,1]], georef)
        else:
            sl_pix = np.array([[np.nan, np.nan],[np.nan, np.nan]])
        # update the plotted shorelines
        sl_plot1[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        sl_plot2[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        sl_plot3[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        fig.canvas.draw_idle()

    # let user manually accept/reject the image
    skip_image = False
    # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
    # this variable needs to be mutable so the callback can store the key in it
    key_event = {}
    def press(event):
        # store what key was pressed in the dictionary
        key_event['pressed'] = event.key
    # Register the handler once and remove it when done: the figure is reused
    # between calls, so a handler left connected would stay attached to it.
    cid = fig.canvas.mpl_connect('key_press_event', press)
    try:
        # let the user press a key, right arrow to keep the image, left arrow to skip it
        # to break the loop the user can press 'escape'
        while True:
            btn_keep = plt.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
                                transform=ax1.transAxes,
                                bbox=dict(boxstyle="square", ec='k',fc='w'))
            btn_skip = plt.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
                                transform=ax1.transAxes,
                                bbox=dict(boxstyle="square", ec='k',fc='w'))
            btn_esc = plt.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
                                transform=ax1.transAxes,
                                bbox=dict(boxstyle="square", ec='k',fc='w'))
            plt.draw()
            plt.waitforbuttonpress()
            # after button is pressed, remove the buttons
            btn_skip.remove()
            btn_keep.remove()
            btn_esc.remove()

            # keep/skip image according to the pressed key, 'escape' to break the loop
            pressed = key_event.get('pressed')
            if pressed == 'right':
                skip_image = False
                break
            elif pressed == 'left':
                skip_image = True
                break
            elif pressed == 'escape':
                plt.close()
                raise StopIteration('User cancelled checking shoreline detection')
            else:
                plt.waitforbuttonpress()
    finally:
        fig.canvas.mpl_disconnect(cid)

    # if save_figure is True, save a .jpg under /jpg_files/detection
    if settings['save_figure'] and not skip_image:
        fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'), dpi=150)

    # don't close the figure window, but remove all axes and settings, ready for next plot
    for ax in fig.axes:
        ax.clear()

    return skip_image, shoreline, t_mndwi
Ejemplo n.º 9
0
def extract_shorelines(metadata, settings):
    """
    Main function to extract shorelines from satellite images

    KV WRL 2018

    For every satellite in `metadata` and every image of that satellite, the
    image is preprocessed, skipped if too cloudy, classified with the neural
    network classifier and a shoreline is mapped (either automatically or with
    manual adjustment by the user). The merged result is also pickled to
    <filepath>/<sitename>/<sitename>_output.pkl.

    Arguments:
    -----------
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'buffer_size': int
            size of the buffer (m) around the sandy pixels over which the pixels
            are considered in the thresholding algorithm
        'min_beach_area': int
            minimum allowable object area (in metres^2) for the class 'sand',
            the area is converted to number of connected pixels
        'min_length_sl': int
            minimum length (in metres) of shoreline contour to be valid
        'sand_color': str
            'default', 'dark' (for grey/black sand beaches) or 'bright' (for white sand beaches)
        'output_epsg': int
            output spatial reference system as EPSG code
        'check_detection': bool
            if True, lets user manually accept/reject the mapped shorelines
        'save_figure': bool
            if True, saves a .jpg file for each mapped shoreline
        'adjust_detection': bool
            if True, allows user to manually adjust the detected shoreline

    Returns:
    -----------
    output: dict
        contains the extracted shorelines and corresponding dates + metadata

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    filepath_models = os.path.join(os.getcwd(), 'classification', 'models')
    # initialise output structure
    output = {}
    # create a subfolder to store the .jpg images showing the detection
    filepath_jpg = os.path.join(filepath_data, sitename, 'jpg_files', 'detection')
    if not os.path.exists(filepath_jpg):
        os.makedirs(filepath_jpg)
    # close all open figures
    plt.close('all')

    print('Mapping shorelines:')

    # the '_new' classifier files are loaded for any sklearn version other than 0.20.x
    str_new = '' if sklearn.__version__[:4] == '0.20' else '_new'

    # loop through satellite list
    for satname in metadata:

        # get images
        filepath = SDS_tools.get_filepath(settings['inputs'], satname)
        filenames = metadata[satname]['filenames']
        n_images = len(filenames)

        # initialise the output variables
        output_timestamp = []   # datetime at which the image was acquired (UTC time)
        output_shoreline = []   # vector of shoreline points
        output_filename = []    # filename of the images from which the shorelines were derived
        output_cloudcover = []  # cloud cover of the images
        output_geoaccuracy = [] # georeferencing accuracy of the images
        output_idxkeep = []     # index that were kept during the analysis (cloudy images are skipped)
        output_t_mndwi = []     # MNDWI threshold used to map the shoreline

        # load the classifier and set the pixel size (m) matching the satellite
        if satname in ['L5', 'L7', 'L8']:
            pixel_size = 15
            if settings['sand_color'] == 'dark':
                clf = joblib.load(os.path.join(filepath_models, 'NN_4classes_Landsat_dark%s.pkl'%str_new))
            elif settings['sand_color'] == 'bright':
                clf = joblib.load(os.path.join(filepath_models, 'NN_4classes_Landsat_bright%s.pkl'%str_new))
            else:
                clf = joblib.load(os.path.join(filepath_models, 'NN_4classes_Landsat%s.pkl'%str_new))

        elif satname == 'S2':
            pixel_size = 10
            clf = joblib.load(os.path.join(filepath_models, 'NN_4classes_S2%s.pkl'%str_new))

        # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
        buffer_size_pixels = np.ceil(settings['buffer_size']/pixel_size)
        min_beach_area_pixels = np.ceil(settings['min_beach_area']/pixel_size**2)

        # loop through the images
        for i, filename in enumerate(filenames):

            # progress indicator, rewritten in place on the same console line
            print('\r%s:   %d%%' % (satname, int(((i+1)/n_images)*100)), end='')

            # get image filename
            fn = SDS_tools.get_filenames(filename, filepath, satname)
            # preprocess image (cloud mask + pansharpening/downsampling)
            im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(fn, satname, settings['cloud_mask_issue'])
            # get image spatial reference system (epsg code) from metadata dict
            image_epsg = metadata[satname]['epsg'][i]

            n_pixels = cloud_mask.shape[0]*cloud_mask.shape[1]
            # compute cloud_cover percentage (with no data pixels)
            cloud_cover_combined = np.sum(cloud_mask) / n_pixels
            if cloud_cover_combined > 0.99: # if 99% of cloudy pixels in image skip
                continue
            # remove no data pixels from the cloud mask
            # (for example L7 bands of no data should not be accounted for)
            cloud_mask_adv = np.logical_xor(cloud_mask, im_nodata)
            # compute updated cloud cover percentage (without no data pixels)
            cloud_cover = np.sum(cloud_mask_adv) / n_pixels
            # skip image if cloud cover is above user-defined threshold
            if cloud_cover > settings['cloud_thresh']:
                continue

            # calculate a buffer around the reference shoreline (if any has been digitised)
            im_ref_buffer = create_shoreline_buffer(cloud_mask.shape, georef, image_epsg,
                                                    pixel_size, settings)

            # classify image in 4 classes (sand, whitewater, water, other) with NN classifier
            im_classif, im_labels = classify_image_NN(im_ms, im_extra, cloud_mask,
                                                      min_beach_area_pixels, clf)

            # if adjust_detection is True, let the user adjust the detected shoreline
            if settings['adjust_detection']:
                date = filename[:19]
                skip_image, shoreline, t_mndwi = adjust_detection(im_ms, cloud_mask, im_labels,
                                                                  im_ref_buffer, image_epsg, georef,
                                                                  settings, date, satname, buffer_size_pixels)
                # if the user decides to skip the image, continue and do not save the mapped shoreline
                if skip_image:
                    continue

            # otherwise map the contours automatically with one of the two following functions:
            # if there are pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
            # otherwise use find_wl_contours1 (traditional)
            else:
                # use try/except structure for long runs; only Exception is caught so
                # that Ctrl-C (KeyboardInterrupt) can still stop the processing
                try:
                    if np.sum(im_labels[:,:,0]) < 10: # minimum number of sand pixels
                        # compute MNDWI image (SWIR-G)
                        im_mndwi = SDS_tools.nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask)
                        # find water contours on MNDWI grayscale image
                        contours_mwi, t_mndwi = find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
                    else:
                        # use classification to refine threshold and extract the sand/water interface
                        contours_mwi, t_mndwi = find_wl_contours2(im_ms, im_labels, cloud_mask,
                                                                  buffer_size_pixels, im_ref_buffer)
                except Exception:
                    print('Could not map shoreline for this image: ' + filename)
                    continue

                # process the water contours into a shoreline
                shoreline = process_shoreline(contours_mwi, cloud_mask, georef, image_epsg, settings)

                # visualise the mapped shorelines, there are two options:
                # if settings['check_detection'] = True, shows the detection to the user for accept/reject
                # if settings['save_figure'] = True, saves a figure for each mapped shoreline
                if settings['check_detection'] or settings['save_figure']:
                    date = filename[:19]
                    if not settings['check_detection']:
                        plt.ioff() # turning interactive plotting off
                    skip_image = show_detection(im_ms, cloud_mask, im_labels, shoreline,
                                                image_epsg, georef, settings, date, satname)
                    # if the user decides to skip the image, continue and do not save the mapped shoreline
                    if skip_image:
                        continue

            # append to output variables
            output_timestamp.append(metadata[satname]['dates'][i])
            output_shoreline.append(shoreline)
            output_filename.append(filename)
            output_cloudcover.append(cloud_cover)
            output_geoaccuracy.append(metadata[satname]['acc_georef'][i])
            output_idxkeep.append(i)
            output_t_mndwi.append(t_mndwi)

        # create dictionary of output
        output[satname] = {
                'dates': output_timestamp,
                'shorelines': output_shoreline,
                'filename': output_filename,
                'cloud_cover': output_cloudcover,
                'geoaccuracy': output_geoaccuracy,
                'idx': output_idxkeep,
                'MNDWI_threshold': output_t_mndwi,
                }
        # terminate the in-place progress line
        print('')

    # close figure window if still open
    if plt.get_fignums():
        plt.close()

    # change the format to have one list sorted by date with all the shorelines (easier to use)
    output = SDS_tools.merge_output(output)

    # save output structure as output.pkl
    filepath = os.path.join(filepath_data, sitename)
    with open(os.path.join(filepath, sitename + '_output.pkl'), 'wb') as f:
        pickle.dump(output, f)

    return output
Ejemplo n.º 10
0
def evaluate_classifier(classifier, metadata, settings):
    """
    Interactively visualise the new classifier.

    For each image that passes the cloud threshold, the image is classified,
    a shoreline is mapped and a two-panel figure (RGB image / RGB image with
    the classification overlaid) is saved as a .jpg in ./evaluation.

    KV WRL 2018

    Arguments:
    -----------
        classifier: joblib object
            Multilayer Perceptron to be used for image classification
        metadata: dict
            contains all the information about the satellite images that were downloaded
        settings: dict
            the following fields are read here (the dict is also passed on to
            the SDS_shoreline helpers, which may read more):
        inputs: dict
            input parameters, 'sitename' is used to name the saved figures
        cloud_thresh: float
            value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted
        cloud_mask_issue: boolean
            True if there is an issue with the cloud mask and sand pixels are being masked on the images
        buffer_size: int
            buffer size in metres, converted to pixels
        min_beach_area: int
            minimum sand area in metres^2, converted to pixels
        output_epsg: int
            EPSG code of the mapped shoreline coordinates

    Returns:
    -----------
        None, the figures are written to disk

    """
    # create folder
    fp = os.path.join(os.getcwd(), 'evaluation')
    if not os.path.exists(fp):
        os.makedirs(fp)

    # initialize figure
    fig, ax = plt.subplots(1,
                           2,
                           figsize=[17, 10],
                           sharex=True,
                           sharey=True,
                           constrained_layout=True)
    mng = plt.get_current_fig_manager()
    mng.window.showMaximized()

    # create colormap for labels (one RGBA row per class layer of im_labels)
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204 / 255, 1, 1, 1])
    colours[2, :] = np.array([0, 91 / 255, 1, 1])
    # loop through satellites
    for satname in metadata:
        filepath = SDS_tools.get_filepath(settings['inputs'], satname)
        filenames = metadata[satname]['filenames']

        # set the pixel size (m) matching the satellite
        if satname in ['L5', 'L7', 'L8']:
            pixel_size = 15
        elif satname == 'S2':
            pixel_size = 10
        # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
        buffer_size_pixels = np.ceil(settings['buffer_size'] / pixel_size)
        min_beach_area_pixels = np.ceil(settings['min_beach_area'] /
                                        pixel_size**2)

        # loop through images
        for i, image_name in enumerate(filenames):
            # image filename
            fn = SDS_tools.get_filenames(image_name, filepath, satname)
            # read and preprocess image
            im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(
                fn, satname, settings['cloud_mask_issue'])
            image_epsg = metadata[satname]['epsg'][i]
            # calculate cloud cover
            cloud_cover = np.sum(cloud_mask) / (cloud_mask.shape[0] *
                                                cloud_mask.shape[1])
            # skip image if cloud cover is above threshold
            if cloud_cover > settings['cloud_thresh']:
                continue
            # calculate a buffer around the reference shoreline (if any has been digitised)
            im_ref_buffer = SDS_shoreline.create_shoreline_buffer(
                cloud_mask.shape, georef, image_epsg, pixel_size, settings)
            # classify image in 4 classes (sand, whitewater, water, other) with NN classifier
            im_classif, im_labels = SDS_shoreline.classify_image_NN(
                im_ms, im_extra, cloud_mask, min_beach_area_pixels, classifier)
            # there are two options to map the contours:
            # if there are pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
            # otherwise use find_wl_contours1 (traditional)
            # NOTE(review): this assumes find_wl_contours1 returns only the contours and
            # find_wl_contours2 returns (contours_wi, contours_mwi) -- confirm against the
            # SDS_shoreline version in use
            # use try/except structure for long runs; only Exception is caught so
            # that Ctrl-C (KeyboardInterrupt) can still stop the processing
            try:
                if np.sum(im_labels[:, :, 0]) < 10:
                    # compute MNDWI image (SWIR-G)
                    im_mndwi = SDS_tools.nd_index(im_ms[:, :, 4],
                                                  im_ms[:, :, 1], cloud_mask)
                    # find water contours on MNDWI grayscale image
                    contours_mwi = SDS_shoreline.find_wl_contours1(
                        im_mndwi, cloud_mask, im_ref_buffer)
                else:
                    # use classification to refine threshold and extract the sand/water interface
                    contours_wi, contours_mwi = SDS_shoreline.find_wl_contours2(
                        im_ms, im_labels, cloud_mask, buffer_size_pixels,
                        im_ref_buffer)
            except Exception:
                print('Could not map shoreline for this image: ' +
                      image_name)
                continue
            # process the water contours into a shoreline
            shoreline = SDS_shoreline.process_shoreline(
                contours_mwi, cloud_mask, georef, image_epsg, settings)
            # convert the shoreline back to pixel coordinates for plotting
            try:
                sl_pix = SDS_tools.convert_world2pix(
                    SDS_tools.convert_epsg(shoreline, settings['output_epsg'],
                                           image_epsg)[:, [0, 1]], georef)
            except Exception:
                # if try fails, just add nan into the shoreline vector so the next parts can still run
                sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])
            # make a plot
            im_RGB = SDS_preprocess.rescale_image_intensity(
                im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)
            # create classified image: paint each class with its RGB colour
            im_class = np.copy(im_RGB)
            for k in range(0, im_labels.shape[2]):
                for channel in range(3):
                    im_class[im_labels[:, :, k], channel] = colours[k, channel]
            # show images
            ax[0].imshow(im_RGB)
            ax[1].imshow(im_RGB)
            ax[1].imshow(im_class, alpha=0.5)
            ax[0].axis('off')
            ax[1].axis('off')
            # strip the extension and the last 4 characters of the stem
            filename = image_name[:image_name.find('.')][:-4]
            ax[0].set_title(filename)
            ax[0].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            ax[1].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            # save figure
            fig.savefig(os.path.join(
                fp, settings['inputs']['sitename'] + filename[:19] + '.jpg'),
                        dpi=150)
            # clear axes, ready for the next image
            for cax in fig.axes:
                cax.clear()

    # close the figure at the end
    plt.close()