示例#1
0
def create_shoreline_buffer(im_shape, georef, image_epsg, pixel_size, settings):
    """
    Creates a buffer around the reference shoreline. The size of the buffer is
    given by settings['max_dist_ref'].

    KV WRL 2018

    Arguments:
    -----------
    im_shape: np.array
        size of the image (rows,columns)
    georef: np.array
        vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]
    image_epsg: int
        spatial reference system of the image from which the contours were extracted
    pixel_size: int
        size of the pixel in metres (15 for Landsat, 10 for Sentinel-2)
    settings: dict with the following keys
        'output_epsg': int
            output spatial reference system
        'reference_shoreline': np.array
            coordinates of the reference shoreline (optional: when this key
            is absent the whole image is returned as buffer)
        'max_dist_ref': int
            maximum distance from the reference shoreline in metres

    Returns:
    -----------
    im_buffer: np.array
        binary image, True where the buffer is, False otherwise

    """
    # without a reference shoreline every pixel belongs to the buffer
    if 'reference_shoreline' not in settings:
        return np.ones(im_shape).astype(bool)

    # convert reference shoreline to pixel coordinates
    # (the last column returned by convert_epsg is dropped before the conversion)
    ref_sl = settings['reference_shoreline']
    ref_sl_conv = SDS_tools.convert_epsg(ref_sl, settings['output_epsg'], image_epsg)[:, :-1]
    ref_sl_pix = SDS_tools.convert_world2pix(ref_sl_conv, georef)
    ref_sl_pix_rounded = np.round(ref_sl_pix).astype(int)

    # column 0 is used as the image column index, column 1 as the image row index
    cols = ref_sl_pix_rounded[:, 0]
    rows = ref_sl_pix_rounded[:, 1]

    # keep only the points that fall inside the image; index 0 is a valid
    # row/column, so the lower bound is inclusive
    inside = (cols >= 0) & (cols < im_shape[1]) & (rows >= 0) & (rows < im_shape[0])

    # create binary image of the reference shoreline (True where the shoreline is)
    im_binary = np.zeros(im_shape, dtype=bool)
    im_binary[rows[inside], cols[inside]] = True

    # dilate the binary image to create a buffer around the reference shoreline
    max_dist_ref_pixels = int(np.ceil(settings['max_dist_ref'] / pixel_size))
    se = morphology.disk(max_dist_ref_pixels)
    im_buffer = morphology.binary_dilation(im_binary, se)

    return im_buffer
示例#2
0
def show_detection(im_ms, cloud_mask, im_labels, shoreline, image_epsg, georef,
                   settings, date, satname):
    """
    Shows the detected shoreline to the user for visual quality control. The user can select "keep"
    if the shoreline detection is correct or "skip" if it is incorrect.

    KV WRL 2018

    Arguments:
    -----------
        im_ms: np.array
            RGB + downsampled NIR and SWIR
        cloud_mask: np.array
            2D cloud mask with True where cloud pixels are
        im_labels: np.array
            3D image containing a boolean image for each class in the order (sand, swash, water)
        shoreline: np.array
            array of points with the X and Y coordinates of the shoreline
        image_epsg: int
            spatial reference system of the image from which the contours were extracted
        georef: np.array
            vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]
        settings: dict
            contains the following fields:
            'inputs': dict
                must provide 'sitename' and 'filepath' (used to build the output folder)
            'output_epsg': int
                spatial reference system of the shoreline coordinates
            'check_detection': boolean
                True to let the user accept/reject the detection with the keyboard
            'save_figure': boolean
                True to save a .jpg of the figure under /jpg_files/detection
        date: string
            date at which the image was taken
        satname: string
            indicates the satname (L5,L7,L8 or S2)

    Returns:
    -----------
        skip_image: boolean
            True if the user wants to skip the image, False otherwise.

    Raises:
    -----------
        StopIteration
            if the user presses <esc> while checking the detection

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    # subfolder where the .jpg file is stored if the user accepts the shoreline detection
    filepath = os.path.join(filepath_data, sitename, 'jpg_files', 'detection')

    im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:, :, [2, 1, 0]],
                                                    cloud_mask, 99.9)

    # compute classified image (each class is painted with its own colour)
    im_class = np.copy(im_RGB)
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204 / 255, 1, 1, 1])
    colours[2, :] = np.array([0, 91 / 255, 1, 1])
    for k in range(0, im_labels.shape[2]):
        im_class[im_labels[:, :, k], 0] = colours[k, 0]
        im_class[im_labels[:, :, k], 1] = colours[k, 1]
        im_class[im_labels[:, :, k], 2] = colours[k, 2]

    # compute MNDWI grayscale image
    im_mwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)

    # transform world coordinates of shoreline into pixel coordinates
    # use try/except in case there are no coordinates to be transformed (shoreline = [])
    try:
        sl_pix = SDS_tools.convert_world2pix(
            SDS_tools.convert_epsg(shoreline, settings['output_epsg'],
                                   image_epsg)[:, [0, 1]], georef)
    except Exception:
        # if try fails, just add nan into the shoreline vector so the next parts can still run
        # (Exception rather than a bare except so that Ctrl-C is not swallowed)
        sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])

    if plt.get_fignums():
        # get open figure if it exists
        fig = plt.gcf()
        ax1 = fig.axes[0]
        ax2 = fig.axes[1]
        ax3 = fig.axes[2]
    else:
        # else create a new figure
        fig = plt.figure()
        fig.set_size_inches([12.53, 9.3])
        mng = plt.get_current_fig_manager()
        mng.window.showMaximized()

        # according to the image shape, decide whether it is better to have the images
        # in vertical subplots or horizontal subplots
        if im_RGB.shape[1] > 2 * im_RGB.shape[0]:
            # vertical subplots
            gs = gridspec.GridSpec(3, 1)
            gs.update(bottom=0.03, top=0.97, left=0.03, right=0.97)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[1, 0])
            ax3 = fig.add_subplot(gs[2, 0])
        else:
            # horizontal subplots
            gs = gridspec.GridSpec(1, 3)
            gs.update(bottom=0.05, top=0.95, left=0.05, right=0.95)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])
            ax3 = fig.add_subplot(gs[0, 2])

    # change the color of nans to either black (0.0) or white (1.0) or somewhere in between
    nan_color = 1.0
    im_RGB = np.where(np.isnan(im_RGB), nan_color, im_RGB)
    im_class = np.where(np.isnan(im_class), 1.0, im_class)

    # create image 1 (RGB)
    ax1.imshow(im_RGB)
    ax1.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax1.axis('off')
    ax1.set_title(sitename, fontweight='bold', fontsize=16)

    # create image 2 (classification)
    ax2.imshow(im_class)
    ax2.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax2.axis('off')
    orange_patch = mpatches.Patch(color=colours[0, :], label='sand')
    white_patch = mpatches.Patch(color=colours[1, :], label='whitewater')
    blue_patch = mpatches.Patch(color=colours[2, :], label='water')
    black_line = mlines.Line2D([], [],
                               color='k',
                               linestyle='-',
                               label='shoreline')
    ax2.legend(handles=[orange_patch, white_patch, blue_patch, black_line],
               bbox_to_anchor=(1, 0.5),
               fontsize=10)
    ax2.set_title(date, fontweight='bold', fontsize=16)

    # create image 3 (MNDWI)
    ax3.imshow(im_mwi, cmap='bwr')
    ax3.plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
    ax3.axis('off')
    ax3.set_title(satname, fontweight='bold', fontsize=16)

    # if check_detection is True, let user manually accept/reject the images
    skip_image = False
    if settings['check_detection']:

        # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
        # this variable needs to be mutable so the callback can store the key in it
        key_event = {}

        def press(event):
            # store what key was pressed in the dictionary
            key_event['pressed'] = event.key

        # register the callback once and always disconnect it afterwards: the
        # figure is reused between calls, so callbacks left connected would pile up
        cid = fig.canvas.mpl_connect('key_press_event', press)
        try:
            # let the user press a key, right arrow to keep the image, left arrow to skip it
            # to break the loop the user can press 'escape'
            while True:
                btn_keep = plt.text(1.1,
                                    0.9,
                                    'keep ⇨',
                                    size=12,
                                    ha="right",
                                    va="top",
                                    transform=ax1.transAxes,
                                    bbox=dict(boxstyle="square", ec='k', fc='w'))
                btn_skip = plt.text(-0.1,
                                    0.9,
                                    '⇦ skip',
                                    size=12,
                                    ha="left",
                                    va="top",
                                    transform=ax1.transAxes,
                                    bbox=dict(boxstyle="square", ec='k', fc='w'))
                btn_esc = plt.text(0.5,
                                   0,
                                   '<esc> to quit',
                                   size=12,
                                   ha="center",
                                   va="top",
                                   transform=ax1.transAxes,
                                   bbox=dict(boxstyle="square", ec='k', fc='w'))
                plt.draw()
                plt.waitforbuttonpress()
                # after button is pressed, remove the buttons
                btn_skip.remove()
                btn_keep.remove()
                btn_esc.remove()

                # keep/skip image according to the pressed key, 'escape' to break the loop
                pressed = key_event.get('pressed')
                if pressed == 'right':
                    skip_image = False
                    break
                elif pressed == 'left':
                    skip_image = True
                    break
                elif pressed == 'escape':
                    plt.close()
                    raise StopIteration(
                        'User cancelled checking shoreline detection')
                # any other key (or a mouse click): show the buttons again and
                # wait for the next press, so that no key press is discarded
        finally:
            fig.canvas.mpl_disconnect(cid)

    # if save_figure is True, save a .jpg under /jpg_files/detection
    if settings['save_figure'] and not skip_image:
        fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'),
                    dpi=200)

    # Don't close the figure window, but remove all axes and settings, ready for next plot
    for ax in fig.axes:
        ax.clear()

    return skip_image
示例#3
0
def evaluate_classifier(classifier, metadata, settings):
    """
    Apply the image classifier to all the images and save the classified images.

    KV WRL 2019

    Arguments:
    -----------
    classifier: joblib object
        classifier model to be used for image classification
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'output_epsg': int
            output spatial reference system as EPSG code
        'buffer_size': int
            size of the buffer (m) around the sandy pixels over which the pixels
            are considered in the thresholding algorithm
        'min_beach_area': int
            minimum allowable object area (in metres^2) for the class 'sand',
            the area is converted to number of connected pixels
        'min_length_sl': int
            minimum length (in metres) of shoreline contour to be valid

    Returns:
    -----------
    Saves .jpg images with the output of the classification in the folder
    ./evaluation (created in the current working directory)

    """

    # create folder called evaluation
    fp = os.path.join(os.getcwd(), 'evaluation')
    os.makedirs(fp, exist_ok=True)

    # create colormap for labels
    cmap = cm.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204/255, 1, 1, 1])
    colours[2, :] = np.array([0, 91/255, 1, 1])

    # initialize figure (not interactive)
    plt.ioff()
    fig, ax = plt.subplots(1, 2, figsize=[17, 10], sharex=True, sharey=True,
                           constrained_layout=True)

    # the try/finally guarantees that this figure (and not whichever figure is
    # current) is closed, even when an image raises an unexpected error
    try:
        # loop through satellites
        for satname in metadata.keys():
            filepath = SDS_tools.get_filepath(settings['inputs'], satname)
            filenames = metadata[satname]['filenames']

            # pixel size of the satellite (in metres)
            # NOTE(review): any other satname leaves pixel_size unset (or keeps the
            # value of the previous satellite) - confirm the allowed satellite names
            if satname in ['L5', 'L7', 'L8']:
                pixel_size = 15
            elif satname == 'S2':
                pixel_size = 10
            # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
            buffer_size_pixels = np.ceil(settings['buffer_size']/pixel_size)
            min_beach_area_pixels = np.ceil(settings['min_beach_area']/pixel_size**2)

            # loop through images
            for i, image_name in enumerate(filenames):
                # image filename
                fn = SDS_tools.get_filenames(image_name, filepath, satname)
                # read and preprocess image
                im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(fn, satname, settings['cloud_mask_issue'])
                image_epsg = metadata[satname]['epsg'][i]

                # compute cloud_cover percentage (with no data pixels)
                cloud_cover_combined = np.divide(np.sum(cloud_mask.astype(int)),
                                                 (cloud_mask.shape[0]*cloud_mask.shape[1]))
                if cloud_cover_combined > 0.99:  # if 99% of cloudy pixels in image skip
                    continue

                # remove no data pixels from the cloud mask (for example L7 bands of no data should not be accounted for)
                cloud_mask_adv = np.logical_xor(cloud_mask, im_nodata)
                # compute updated cloud cover percentage (without no data pixels)
                cloud_cover = np.divide(np.sum(cloud_mask_adv.astype(int)),
                                        np.sum((~im_nodata).astype(int)))
                # skip image if cloud cover is above threshold
                if cloud_cover > settings['cloud_thresh']:
                    continue
                # calculate a buffer around the reference shoreline (if any has been digitised)
                im_ref_buffer = SDS_shoreline.create_shoreline_buffer(cloud_mask.shape, georef, image_epsg,
                                                                      pixel_size, settings)
                # classify image with NN classifier
                im_classif, im_labels = SDS_shoreline.classify_image_NN(im_ms, im_extra, cloud_mask,
                                                                        min_beach_area_pixels, classifier)
                # there are two options to map the contours:
                # if there are pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
                # otherwise use find_wl_contours1 (traditional)
                try:  # use try/except structure for long runs
                    if np.sum(im_labels[:, :, 0]) < 10:
                        # compute MNDWI image (SWIR-G)
                        im_mndwi = SDS_tools.nd_index(im_ms[:, :, 4], im_ms[:, :, 1], cloud_mask)
                        # find water contours on MNDWI grayscale image
                        contours_mwi, t_mndwi = SDS_shoreline.find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
                    else:
                        # use classification to refine threshold and extract the sand/water interface
                        contours_mwi, t_mndwi = SDS_shoreline.find_wl_contours2(im_ms, im_labels,
                                                                                cloud_mask, buffer_size_pixels, im_ref_buffer)
                except Exception:
                    # Exception (not a bare except) so that Ctrl-C still stops a long run
                    print('Could not map shoreline for this image: ' + image_name)
                    continue
                # process the water contours into a shoreline
                shoreline = SDS_shoreline.process_shoreline(contours_mwi, cloud_mask, georef, image_epsg, settings)
                try:
                    sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline,
                                                                                settings['output_epsg'],
                                                                                image_epsg)[:, [0, 1]], georef)
                except Exception:
                    # if try fails, just add nan into the shoreline vector so the next parts can still run
                    sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])
                # make a plot
                im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)
                # create classified image
                im_class = np.copy(im_RGB)
                for k in range(0, im_labels.shape[2]):
                    im_class[im_labels[:, :, k], 0] = colours[k, 0]
                    im_class[im_labels[:, :, k], 1] = colours[k, 1]
                    im_class[im_labels[:, :, k], 2] = colours[k, 2]
                # show images (the classification is overlaid half-transparent on the RGB)
                ax[0].imshow(im_RGB)
                ax[1].imshow(im_RGB)
                ax[1].imshow(im_class, alpha=0.5)
                ax[0].axis('off')
                ax[1].axis('off')
                # title: text before the first '.' of the file name, minus its last 4 characters
                filename = image_name[:image_name.find('.')][:-4]
                ax[0].set_title(filename)
                ax[0].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
                ax[1].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
                # save figure
                fig.savefig(os.path.join(fp, settings['inputs']['sitename'] + filename[:19] + '.jpg'), dpi=150)
                # clear axes
                for cax in fig.axes:
                    cax.clear()
    finally:
        # close the figure at the end
        plt.close(fig)
示例#4
0
def get_reference_sl(metadata, settings):
    """
    Allows the user to manually digitize a reference shoreline that is used to seed
    the shoreline detection algorithm. The reference shoreline helps to detect
    the outliers, making the shoreline detection more robust.

    KV WRL 2018

    Arguments:
    -----------
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'output_epsg': int
            output spatial reference system as EPSG code

    Returns:
    -----------
    reference_shoreline: np.array
        coordinates of the reference shoreline that was manually digitized.
        This is also saved as a .pkl and .geojson file.

    Raises:
    -----------
    StopIteration
        if the user presses <esc> while choosing an image
    Exception
        if only L7 images are available, or if no image was digitised

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    pts_coords = []
    # check if reference shoreline already exists in the corresponding folder
    filepath = os.path.join(filepath_data, sitename)
    filename = sitename + '_reference_shoreline.pkl'
    # if it exists, load it and return it
    if filename in os.listdir(filepath):
        print('Reference shoreline already exists and was loaded')
        # NOTE: unpickling can execute arbitrary code, only load files written by this function
        with open(os.path.join(filepath, filename), 'rb') as f:
            refsl = pickle.load(f)
        return refsl

    # otherwise get the user to manually digitise a shoreline on S2, L8 or L5 images, in that
    # order of preference (S2: 10m res, L8: 15m res in the RGB with pansharpening; no L7
    # because the black diagonal bands make it hard to manually digitize a shoreline)
    for candidate in ('S2', 'L8', 'L5'):
        if candidate in metadata:
            satname = candidate
            break
    else:
        raise Exception('You cannot digitize the shoreline on L7 images (because of gaps in the images), add another L8, S2 or L5 to your dataset.')
    filepath = SDS_tools.get_filepath(settings['inputs'], satname)
    filenames = metadata[satname]['filenames']

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=[18, 9], tight_layout=True)
    mng = plt.get_current_fig_manager()
    mng.window.showMaximized()
    # loop through the images
    for i, image_name in enumerate(filenames):

        # read image
        fn = SDS_tools.get_filenames(image_name, filepath, satname)
        im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = preprocess_single(fn, satname, settings['cloud_mask_issue'])

        # calculate cloud cover
        cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))),
                                (cloud_mask.shape[0]*cloud_mask.shape[1]))

        # skip image if cloud cover is above threshold
        if cloud_cover > settings['cloud_thresh']:
            continue

        # rescale image intensity for display purposes
        im_RGB = rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)

        # plot the image RGB on a figure
        ax.axis('off')
        ax.imshow(im_RGB)

        # decide if the image is good enough for digitizing the shoreline
        ax.set_title('Press <right arrow> if image is clear enough to digitize the shoreline.\n' +
                     'If the image is cloudy press <left arrow> to get another image', fontsize=14)
        # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
        # this variable needs to be mutable so the callback can store the key in it
        skip_image = False
        key_event = {}

        def press(event):
            # store what key was pressed in the dictionary
            key_event['pressed'] = event.key

        # register the callback once per image and always disconnect it again,
        # otherwise one more callback stays attached to the figure on every pass
        cid = fig.canvas.mpl_connect('key_press_event', press)
        try:
            # let the user press a key, right arrow to keep the image, left arrow to skip it
            # to break the loop the user can press 'escape'
            while True:
                btn_keep = plt.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
                                    transform=ax.transAxes,
                                    bbox=dict(boxstyle="square", ec='k', fc='w'))
                btn_skip = plt.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
                                    transform=ax.transAxes,
                                    bbox=dict(boxstyle="square", ec='k', fc='w'))
                btn_esc = plt.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
                                   transform=ax.transAxes,
                                   bbox=dict(boxstyle="square", ec='k', fc='w'))
                plt.draw()
                plt.waitforbuttonpress()
                # after button is pressed, remove the buttons
                btn_skip.remove()
                btn_keep.remove()
                btn_esc.remove()
                # keep/skip image according to the pressed key, 'escape' to break the loop
                pressed = key_event.get('pressed')
                if pressed == 'right':
                    skip_image = False
                    break
                elif pressed == 'left':
                    skip_image = True
                    break
                elif pressed == 'escape':
                    plt.close()
                    raise StopIteration('User cancelled checking shoreline detection')
                # any other key (or a mouse click): show the buttons again and
                # wait for the next press, so that no key press is discarded
        finally:
            fig.canvas.mpl_disconnect(cid)

        if skip_image:
            ax.clear()
            continue

        # create two new buttons
        add_button = plt.text(0, 0.9, 'add', size=16, ha="left", va="top",
                              transform=plt.gca().transAxes,
                              bbox=dict(boxstyle="square", ec='k', fc='w'))
        end_button = plt.text(1, 0.9, 'end', size=16, ha="right", va="top",
                              transform=plt.gca().transAxes,
                              bbox=dict(boxstyle="square", ec='k', fc='w'))
        # add multiple reference shorelines (until user clicks on <end> button)
        digitised_lines = []
        geoms = []
        while True:
            add_button.set_visible(False)
            end_button.set_visible(False)
            # update title (instructions)
            ax.set_title('Click points along the shoreline (enough points to capture the beach curvature).\n' +
                         'Start at one end of the beach.\n' + 'When finished digitizing, click <ENTER>',
                         fontsize=14)
            plt.draw()

            # let user click on the shoreline
            pts = ginput(n=50000, timeout=1e9, show_clicks=True)
            if len(pts) < 2:
                # a line needs at least two points, ask the user to digitise again
                continue
            pts_pix = np.array(pts)
            # convert pixel coordinates to world coordinates (columns are swapped first)
            pts_world = SDS_tools.convert_pix2world(pts_pix[:, [1, 0]], georef)

            # interpolate between points clicked by the user (1m resolution)
            segments = []
            for k in range(len(pts_world)-1):
                # points every 1 unit along the x-axis, from 0 up to the segment length
                pt_dist = np.linalg.norm(pts_world[k, :]-pts_world[k+1, :])
                xvals = np.arange(0, pt_dist)
                pt_coords = np.zeros((len(xvals), 2))
                pt_coords[:, 0] = xvals
                # rotate them onto the segment direction and move them to its start point
                deltax = pts_world[k+1, 0] - pts_world[k, 0]
                deltay = pts_world[k+1, 1] - pts_world[k, 1]
                phi = np.pi/2 - np.arctan2(deltax, deltay)
                tf = transform.EuclideanTransform(rotation=phi, translation=pts_world[k, :])
                segments.append(tf(pt_coords))
            pts_world_interp = np.concatenate(segments, axis=0)

            # save as geometry (to create .geojson file later)
            geoms.append(geometry.LineString(pts_world_interp))

            # convert to pixel coordinates and plot
            pts_pix_interp = SDS_tools.convert_world2pix(pts_world_interp, georef)
            digitised_lines.append(pts_world_interp)
            ax.plot(pts_pix_interp[:, 0], pts_pix_interp[:, 1], 'r--')
            ax.plot(pts_pix_interp[0, 0], pts_pix_interp[0, 1], 'ko')
            ax.plot(pts_pix_interp[-1, 0], pts_pix_interp[-1, 1], 'ko')

            # update title and buttons
            add_button.set_visible(True)
            end_button.set_visible(True)
            ax.set_title('click on <add> to digitize another shoreline or on <end> to finish and save the shoreline(s)',
                         fontsize=14)
            plt.draw()

            # let the user click again (<add> another shoreline or <end>)
            pt_input = ginput(n=1, timeout=1e9, show_clicks=False)
            pt_input = np.array(pt_input)

            # a click in the right half of the image counts as <end>:
            # save the points and break the loop
            if pt_input[0][0] > im_ms.shape[1]/2:
                add_button.set_visible(False)
                end_button.set_visible(False)
                plt.title('Reference shoreline saved as ' + sitename + '_reference_shoreline.pkl and ' + sitename + '_reference_shoreline.geojson')
                plt.draw()
                ginput(n=1, timeout=3, show_clicks=False)
                plt.close()
                break

        pts_sl = np.concatenate(digitised_lines, axis=0)
        # convert world image coordinates to user-defined coordinate system
        image_epsg = metadata[satname]['epsg'][i]
        pts_coords = SDS_tools.convert_epsg(pts_sl, image_epsg, settings['output_epsg'])

        # save the reference shoreline as .pkl
        filepath = os.path.join(filepath_data, sitename)
        with open(os.path.join(filepath, sitename + '_reference_shoreline.pkl'), 'wb') as f:
            pickle.dump(pts_coords, f)

        # also store as .geojson in case user wants to drag-and-drop on GIS for verification
        # (all lines go into one geodataframe, one row per digitised shoreline)
        gdf_all = gpd.GeoDataFrame(geometry=gpd.GeoSeries(geoms))
        gdf_all['name'] = ['reference shoreline ' + str(k+1) for k in range(len(geoms))]
        gdf_all.crs = {'init': 'epsg:'+str(image_epsg)}
        # convert from image_epsg to user-defined coordinate system
        gdf_all = gdf_all.to_crs({'init': 'epsg:'+str(settings['output_epsg'])})
        # save as geojson
        gdf_all.to_file(os.path.join(filepath, sitename + '_reference_shoreline.geojson'),
                        driver='GeoJSON', encoding='utf-8')

        print('Reference shoreline has been saved in ' + filepath)
        break

    # check if a shoreline was digitised
    if len(pts_coords) == 0:
        # no image was accepted, so the figure is still open
        plt.close(fig)
        raise Exception('No cloud free images are available to digitise the reference shoreline,'+
                        'download more images and try again')

    return pts_coords
示例#5
0
def get_reference_sl(metadata, settings):
    """
    Lets the user manually digitize a reference shoreline that is used to seed the shoreline
    detection algorithm. The reference shoreline helps to detect the outliers, making the
    shoreline detection more robust.

    If a file named <sitename>_reference_shoreline.pkl already exists in the site folder it is
    loaded and returned without any user interaction. Otherwise the images of one satellite
    (S2 preferred, then L8, then L5) are shown one by one until the user keeps one and
    digitizes one or more shorelines on it. The result is saved as .pkl and .geojson in the
    site folder.

    KV WRL 2018

    Arguments:
    -----------
        metadata: dict
            contains all the information about the satellite images that were downloaded
            (for the selected satellite the keys 'filenames' and 'epsg' are read)
        settings: dict
            contains the following fields:
        'inputs': dict
            input parameters; 'sitename' and 'filepath' are read here
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted
        'cloud_mask_issue': passed through to preprocess_single
        'output_epsg': int
            epsg code of the desired spatial reference system

    Returns:
    -----------
        reference_shoreline: np.array
            coordinates of the reference shoreline that was manually digitized
            (or whatever object was stored in the existing .pkl file)

    Raises:
    -----------
        Exception
            if metadata contains none of S2, L8, L5, or if no image was kept for digitizing
            (all images above the cloud threshold or skipped by the user)

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    site_dir = os.path.join(filepath_data, sitename)
    pkl_name = sitename + '_reference_shoreline.pkl'
    geojson_name = sitename + '_reference_shoreline.geojson'

    # check if reference shoreline already exists in the corresponding folder
    if pkl_name in os.listdir(site_dir):
        print('Reference shoreline already exists and was loaded')
        # NOTE: unpickling can execute arbitrary code, only load files written by this tool
        with open(os.path.join(site_dir, pkl_name), 'rb') as f:
            return pickle.load(f)

    # choose the satellite by order of preference: S2 (10m res), then L8 (15m res in the RGB
    # with pansharpening), then L5 (L7 images have black diagonal bands making it hard to
    # manually digitize a shoreline)
    for satname in ('S2', 'L8', 'L5'):
        if satname in metadata:
            break
    else:
        raise Exception(
            'You cannot digitize the shoreline on L7 images, add another L8, S2 or L5 to your dataset.'
        )
    image_dir = SDS_tools.get_filepath(settings['inputs'], satname)
    filenames = metadata[satname]['filenames']

    # stays None unless the user keeps an image and digitizes a shoreline on it
    pts_coords = None

    # loop through the images
    for i, image_name in enumerate(filenames):

        # read image
        fn = SDS_tools.get_filenames(image_name, image_dir, satname)
        im_ms, georef, cloud_mask, im_extra, imQA = preprocess_single(
            fn, satname, settings['cloud_mask_issue'])

        # calculate cloud cover (fraction of masked pixels)
        cloud_cover = np.divide(np.sum(cloud_mask.astype(int)),
                                cloud_mask.shape[0] * cloud_mask.shape[1])

        # skip image if cloud cover is above threshold
        if cloud_cover > settings['cloud_thresh']:
            continue

        # rescale image intensity for display purposes
        im_RGB = rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)

        # plot the image RGB on a figure
        fig = plt.figure()
        fig.set_size_inches([18, 9])
        fig.set_tight_layout(True)
        plt.axis('off')
        plt.imshow(im_RGB)

        # decide if the image is good enough for digitizing the shoreline
        plt.title(
            'click <keep> if image is clear enough to digitize the shoreline.\n'
            + 'If not (too cloudy) click on <skip> to get another image',
            fontsize=14)
        button_box = dict(boxstyle="square", ec='k', fc='w')
        keep_button = plt.text(0, 0.9, 'keep', size=16, ha="left", va="top",
                               transform=plt.gca().transAxes, bbox=button_box)
        skip_button = plt.text(1, 0.9, 'skip', size=16, ha="right", va="top",
                               transform=plt.gca().transAxes, bbox=button_box)
        mng = plt.get_current_fig_manager()
        mng.window.showMaximized()

        # let user click on the image once
        pt_input = np.array(ginput(n=1, timeout=1e9, show_clicks=False))

        # a click on the right half of the image counts as <skip>: show another image
        if pt_input[0][0] > im_ms.shape[1] / 2:
            plt.close()
            continue

        # remove keep and skip buttons
        keep_button.set_visible(False)
        skip_button.set_visible(False)
        # create two new buttons
        add_button = plt.text(0, 0.9, 'add', size=16, ha="left", va="top",
                              transform=plt.gca().transAxes, bbox=button_box)
        end_button = plt.text(1, 0.9, 'end', size=16, ha="right", va="top",
                              transform=plt.gca().transAxes, bbox=button_box)

        # add multiple reference shorelines (until user clicks on <end> button)
        shorelines_world = []
        geoms = []
        while True:
            add_button.set_visible(False)
            end_button.set_visible(False)
            # update title (instructions)
            plt.title(
                'Click points along the shoreline (enough points to capture the beach curvature).\n'
                + 'Start at one end of the beach.\n' +
                'When finished digitizing, click <ENTER>',
                fontsize=14)
            plt.draw()

            # let user click on the shoreline
            pts = ginput(n=50000, timeout=1e9, show_clicks=True)
            pts_pix = np.array(pts)
            # convert pixel coordinates to world coordinates
            pts_world = SDS_tools.convert_pix2world(pts_pix[:, [1, 0]], georef)

            # interpolate between points clicked by the user (1m resolution): for each pair
            # of consecutive points, lay out points along the x-axis at unit spacing and
            # rotate/translate them onto the segment
            segments = []
            for start, end in zip(pts_world[:-1], pts_world[1:]):
                pt_dist = np.linalg.norm(start - end)
                xvals = np.arange(0, pt_dist)
                pt_coords = np.column_stack((xvals, np.zeros(len(xvals))))
                deltax = end[0] - start[0]
                deltay = end[1] - start[1]
                # np.arctan2 replaces np.math.atan2 (the np.math alias no longer exists
                # in NumPy 2.0)
                phi = np.pi / 2 - np.arctan2(deltax, deltay)
                tf = transform.EuclideanTransform(rotation=phi, translation=start)
                segments.append(tf(pt_coords))
            if segments:
                pts_world_interp = np.concatenate(segments, axis=0)
            else:
                pts_world_interp = np.empty((0, 2))

            # save as geometry (to create .geojson file later)
            geoms.append(geometry.LineString(pts_world_interp))
            shorelines_world.append(pts_world_interp)

            # convert to pixel coordinates and plot
            pts_pix_interp = SDS_tools.convert_world2pix(pts_world_interp, georef)
            plt.plot(pts_pix_interp[:, 0], pts_pix_interp[:, 1], 'r--')
            plt.plot(pts_pix_interp[0, 0], pts_pix_interp[0, 1], 'ko')
            plt.plot(pts_pix_interp[-1, 0], pts_pix_interp[-1, 1], 'ko')

            # update title and buttons
            add_button.set_visible(True)
            end_button.set_visible(True)
            plt.title(
                'click <add> to digitize another shoreline or <end> to finish and save the shoreline(s)',
                fontsize=14)
            plt.draw()

            # let the user click again (<add> another shoreline or <end>)
            pt_input = np.array(ginput(n=1, timeout=1e9, show_clicks=False))

            # if user clicks on <end> (right half of the image), stop digitizing
            if pt_input[0][0] > im_ms.shape[1] / 2:
                add_button.set_visible(False)
                end_button.set_visible(False)
                plt.title('Reference shoreline saved as ' + sitename +
                          '_reference_shoreline.pkl and ' + sitename +
                          '_reference_shoreline.geojson')
                plt.draw()
                # leave the confirmation on screen for up to 3 seconds (or until a click)
                ginput(n=1, timeout=3, show_clicks=False)
                plt.close()
                break

        pts_sl = np.concatenate(shorelines_world, axis=0)
        # convert world image coordinates to user-defined coordinate system
        image_epsg = metadata[satname]['epsg'][i]
        pts_coords = SDS_tools.convert_epsg(pts_sl, image_epsg, settings['output_epsg'])

        # save the reference shoreline as .pkl
        with open(os.path.join(site_dir, pkl_name), 'wb') as f:
            pickle.dump(pts_coords, f)

        # also store as .geojson in case user wants to drag-and-drop on GIS for verification;
        # built in one call because DataFrame.append was removed in pandas 2.0
        names = ['reference shoreline ' + str(k + 1) for k in range(len(geoms))]
        gdf_all = gpd.GeoDataFrame({'name': names}, geometry=geoms)
        gdf_all.crs = {'init': 'epsg:' + str(image_epsg)}
        # convert from image_epsg to user-defined coordinate system
        gdf_all = gdf_all.to_crs({'init': 'epsg:' + str(settings['output_epsg'])})
        # save as geojson
        gdf_all.to_file(os.path.join(site_dir, geojson_name),
                        driver='GeoJSON',
                        encoding='utf-8')

        print('Reference shoreline has been saved in ' + site_dir)
        break

    # check if a shoreline was digitised (otherwise there is nothing to return)
    if pts_coords is None:
        raise Exception('No cloud free images are available to digitise the reference shoreline, '
                        'download more images and try again')

    return pts_coords
示例#6
0
def adjust_detection(im_ms, cloud_mask, im_labels, im_ref_buffer, image_epsg, georef,
                       settings, date, satname, buffer_size_pixels):
    """
    Advanced version of show detection where the user can adjust the detected 
    shorelines by clicking a new MNDWI threshold on the histogram plot.

    The function reuses the current figure if one is open (it must then already hold the
    four axes created here), otherwise it creates a new one. The figure is left open on
    return with its axes cleared, ready for the next image.

    KV WRL 2020

    Arguments:
    -----------
    im_ms: np.array
        RGB + downsampled NIR and SWIR
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are
    im_labels: np.array
        3D image containing a boolean image for each class in the order (sand, swash, water)
    im_ref_buffer: np.array
        Binary image containing a buffer around the reference shoreline
    image_epsg: int
        spatial reference system of the image from which the contours were extracted
    georef: np.array
        vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]
    date: string
        date at which the image was taken, formatted as '%Y-%m-%d-%H-%M-%S'
    satname: string
        indicates the satname (L5,L7,L8 or S2)
    buffer_size_pixels: int
        buffer_size converted to number of pixels
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'output_epsg': int
            output spatial reference system as EPSG code
        'save_figure': bool
            if True, saves a -jpg file for each mapped shoreline

    Returns:
    -----------
    skip_image: boolean
        True if the user wants to skip the image, False otherwise
    shoreline: np.array
        array of points with the X and Y coordinates of the shoreline 
    t_mndwi: float
        value of the MNDWI threshold used to map the shoreline

    If the automatic detection fails, (True, [], []) is returned.

    Raises:
    -----------
    StopIteration
        if the user presses <esc> to quit checking the detections

    """

    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    # subfolder where the .jpg file is stored if the user accepts the shoreline detection
    filepath = os.path.join(filepath_data, sitename, 'jpg_files', 'detection')
    # format date
    date_str = datetime.strptime(date,'%Y-%m-%d-%H-%M-%S').strftime('%Y-%m-%d  %H:%M:%S')
    im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9)

    def shoreline_to_pixels(sl_world):
        # convert a shoreline from output_epsg world coordinates to image pixel coordinates;
        # an empty shoreline gives a NaN placeholder so that the plotting calls still work
        if len(sl_world) > 0:
            sl_image_crs = SDS_tools.convert_epsg(sl_world, settings['output_epsg'],
                                                  image_epsg)[:,[0,1]]
            return SDS_tools.convert_world2pix(sl_image_crs, georef)
        return np.array([[np.nan, np.nan],[np.nan, np.nan]])

    # compute classified image (one colour per class: sand, whitewater, water)
    im_class = np.copy(im_RGB)
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    cmap = plt.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0,13,1))
    colours = np.zeros((3,4))
    colours[0,:] = colorpalette[5]
    colours[1,:] = np.array([204/255,1,1,1])
    colours[2,:] = np.array([0,91/255,1,1])
    for k in range(0,im_labels.shape[2]):
        for band in range(3):
            im_class[im_labels[:,:,k],band] = colours[k,band]

    # compute MNDWI grayscale image
    im_mndwi = SDS_tools.nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask)
    # buffer MNDWI using reference shoreline
    im_mndwi_buffer = np.copy(im_mndwi)
    im_mndwi_buffer[~im_ref_buffer] = np.nan

    # get MNDWI pixel intensity in each class (for histogram plot)
    int_sand = im_mndwi[im_labels[:,:,0]]
    int_ww = im_mndwi[im_labels[:,:,1]]
    int_water = im_mndwi[im_labels[:,:,2]]
    labels_other = np.logical_and(np.logical_and(~im_labels[:,:,0],~im_labels[:,:,1]),~im_labels[:,:,2])
    int_other = im_mndwi[labels_other]

    # create figure
    if plt.get_fignums():
        # if it exists, open the figure 
        fig = plt.gcf()
        ax1 = fig.axes[0]
        ax2 = fig.axes[1]
        ax3 = fig.axes[2]
        ax4 = fig.axes[3]
    else:
        # else create a new figure
        fig = plt.figure()
        fig.set_size_inches([18, 9])
        mng = plt.get_current_fig_manager()
        mng.window.showMaximized()
        gs = gridspec.GridSpec(2, 3, height_ratios=[4,1])
        gs.update(bottom=0.05, top=0.95, left=0.03, right=0.97)
        ax1 = fig.add_subplot(gs[0,0])
        ax2 = fig.add_subplot(gs[0,1], sharex=ax1, sharey=ax1)
        ax3 = fig.add_subplot(gs[0,2], sharex=ax1, sharey=ax1)
        ax4 = fig.add_subplot(gs[1,:])
    ##########################################################################
    # to do: rotate image if too wide
    ##########################################################################

    # change the color of nans to either black (0.0) or white (1.0) or somewhere in between
    nan_color = 1.0
    im_RGB = np.where(np.isnan(im_RGB), nan_color, im_RGB)
    im_class = np.where(np.isnan(im_class), nan_color, im_class)

    # plot image 1 (RGB)
    ax1.imshow(im_RGB)
    ax1.axis('off')
    ax1.set_title('%s - %s'%(sitename, satname), fontsize=12)

    # plot image 2 (classification)
    ax2.imshow(im_class)
    ax2.axis('off')
    orange_patch = mpatches.Patch(color=colours[0,:], label='sand')
    white_patch = mpatches.Patch(color=colours[1,:], label='whitewater')
    blue_patch = mpatches.Patch(color=colours[2,:], label='water')
    black_line = mlines.Line2D([],[],color='k',linestyle='-', label='shoreline')
    ax2.legend(handles=[orange_patch,white_patch,blue_patch, black_line],
               bbox_to_anchor=(1.1, 0.5), fontsize=10)
    ax2.set_title(date_str, fontsize=12)

    # plot image 3 (MNDWI)
    ax3.imshow(im_mndwi, cmap='bwr')
    ax3.axis('off')
    ax3.set_title('MNDWI', fontsize=12)

    # plot histogram of MNDWI values (one histogram per class, skipping empty/all-NaN classes)
    binwidth = 0.01
    ax4.set_facecolor('0.75')
    ax4.yaxis.grid(color='w', linestyle='--', linewidth=0.5)
    ax4.set(ylabel='PDF',yticklabels=[], xlim=[-1,1])
    histograms = [(int_sand, colours[0,:], 'sand', None),
                  (int_ww, colours[1,:], 'whitewater', 0.75),
                  (int_water, colours[2,:], 'water', 0.75),
                  (int_other, 'C4', 'other', 0.5)]
    for values, colour, label, alpha in histograms:
        if len(values) > 0 and np.any(~np.isnan(values)):
            bins = np.arange(np.nanmin(values), np.nanmax(values) + binwidth, binwidth)
            ax4.hist(values, bins=bins, density=True, color=colour, label=label, alpha=alpha)

    # automatically map the shoreline based on the classifier if enough sand pixels
    try:
        if np.sum(im_labels[:,:,0]) > 10:
            # use classification to refine threshold and extract the sand/water interface
            contours_mndwi, t_mndwi = find_wl_contours2(im_ms, im_labels, cloud_mask,
                                                        buffer_size_pixels, im_ref_buffer)
        else:
            # find water contours on MNDWI grayscale image
            contours_mndwi, t_mndwi = find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
    except Exception:
        # Exception (not a bare except) so that Ctrl-C / SystemExit still propagate
        print('Could not map shoreline so image was skipped')
        # clear axes and return skip_image=True, so that image is skipped above
        for ax in fig.axes:
            ax.clear()
        return True,[],[]

    # process the water contours into a shoreline
    shoreline = process_shoreline(contours_mndwi, cloud_mask, georef, image_epsg, settings)
    # convert shoreline to pixels
    sl_pix = shoreline_to_pixels(shoreline)
    # plot the shoreline on the images
    sl_plot1 = ax1.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    sl_plot2 = ax2.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    sl_plot3 = ax3.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
    t_line = ax4.axvline(x=t_mndwi,ls='--', c='k', lw=1.5, label='threshold')
    ax4.legend(loc=1)
    plt.draw() # to update the plot
    # adjust the threshold manually by letting the user change the threshold
    ax4.set_title('Click on the plot below to change the location of the threhsold and adjust the shoreline detection. When finished, press <Enter>')
    while True:
        # let the user click on the threshold plot
        pt = ginput(n=1, show_clicks=True, timeout=-1)
        # no point clicked (user pressed <Enter>): threshold adjustment is finished
        if len(pt) == 0:
            ax4.set_title('MNDWI pixel intensities and threshold')
            break
        # if user clicked somewhere wrong and value is not between -1 and 1
        if np.abs(pt[0][0]) >= 1: continue
        # update the threshold value
        t_mndwi = pt[0][0]
        # update the plot
        t_line.set_xdata([t_mndwi,t_mndwi])
        # map contours with new threshold
        contours = measure.find_contours(im_mndwi_buffer, t_mndwi)
        # remove contours that contain NaNs (due to cloud pixels in the contour)
        contours = process_contours(contours)
        # process the water contours into a shoreline
        shoreline = process_shoreline(contours, cloud_mask, georef, image_epsg, settings)
        # convert shoreline to pixels
        sl_pix = shoreline_to_pixels(shoreline)
        # update the plotted shorelines
        sl_plot1[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        sl_plot2[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        sl_plot3[0].set_data([sl_pix[:,0], sl_pix[:,1]])
        fig.canvas.draw_idle()

    # let user manually accept/reject the image
    skip_image = False
    # set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
    # the dictionary is mutated by the callback so the pressed key can be read afterwards
    key_event = {}
    def press(event):
        # store what key was pressed in the dictionary
        key_event['pressed'] = event.key
    # connect the callback once and disconnect it when done: the figure is reused for the
    # next image, so connecting inside the loop would pile up stale handlers on the canvas
    cid = fig.canvas.mpl_connect('key_press_event', press)
    button_box = dict(boxstyle="square", ec='k',fc='w')
    try:
        # let the user press a key, right arrow to keep the image, left arrow to skip it
        # to break the loop the user can press 'escape'
        while True:
            btn_keep = plt.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
                                transform=ax1.transAxes, bbox=button_box)
            btn_skip = plt.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
                                transform=ax1.transAxes, bbox=button_box)
            btn_esc = plt.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
                               transform=ax1.transAxes, bbox=button_box)
            plt.draw()
            plt.waitforbuttonpress()
            # after button is pressed, remove the buttons
            btn_skip.remove()
            btn_keep.remove()
            btn_esc.remove()

            # keep/skip image according to the pressed key, 'escape' to break the loop
            pressed = key_event.get('pressed')
            if pressed == 'right':
                skip_image = False
                break
            elif pressed == 'left':
                skip_image = True
                break
            elif pressed == 'escape':
                plt.close()
                raise StopIteration('User cancelled checking shoreline detection')
            else:
                plt.waitforbuttonpress()
    finally:
        fig.canvas.mpl_disconnect(cid)

    # if save_figure is True, save a .jpg under /jpg_files/detection
    if settings['save_figure'] and not skip_image:
        fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'), dpi=150)

    # don't close the figure window, but remove all axes and settings, ready for next plot
    for ax in fig.axes:
        ax.clear()

    return skip_image, shoreline, t_mndwi
示例#7
0
def evaluate_classifier(classifier, metadata, settings):
    """
    Visualise the output of a classifier on every image in metadata.

    For each sufficiently cloud-free image, the image is classified, a shoreline is mapped
    and a two-panel figure (RGB, RGB with the classification overlaid) is saved as .jpg in
    an 'evaluation' folder created in the current working directory.

    KV WRL 2018

    Arguments:
    -----------
        classifier: joblib object
            Multilayer Perceptron to be used for image classification
        metadata: dict
            contains all the information about the satellite images that were downloaded
        settings: dict
            contains the following fields:
        inputs: dict
            input parameters, 'sitename' is used to name the saved figures
        cloud_thresh: float
            value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted
        cloud_mask_issue: boolean
            True if there is an issue with the cloud mask and sand pixels are being masked on the images
        buffer_size: number
            converted to pixels by dividing by the pixel size
        min_beach_area: number
            converted to pixels by dividing by the squared pixel size
        output_epsg: int
            spatial reference system of the mapped shoreline

    Returns:
    -----------
        None, the figures are written to disk

    """
    # create folder
    fp = os.path.join(os.getcwd(), 'evaluation')
    if not os.path.exists(fp):
        os.makedirs(fp)

    # initialize figure
    fig, ax = plt.subplots(1,
                           2,
                           figsize=[17, 10],
                           sharex=True,
                           sharey=True,
                           constrained_layout=True)
    mng = plt.get_current_fig_manager()
    mng.window.showMaximized()

    # create colormap for labels
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    cmap = plt.get_cmap('tab20c')
    colorpalette = cmap(np.arange(0, 13, 1))
    colours = np.zeros((3, 4))
    colours[0, :] = colorpalette[5]
    colours[1, :] = np.array([204 / 255, 1, 1, 1])
    colours[2, :] = np.array([0, 91 / 255, 1, 1])
    # loop through satellites
    for satname in metadata.keys():
        filepath = SDS_tools.get_filepath(settings['inputs'], satname)
        filenames = metadata[satname]['filenames']

        # pixel size in metres for the satellite
        # NOTE(review): for any other satname pixel_size is left unset (NameError) or keeps
        # the value of the previous satellite - confirm which satellites can occur here
        if satname in ['L5', 'L7', 'L8']:
            pixel_size = 15
        elif satname == 'S2':
            pixel_size = 10
        # convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
        buffer_size_pixels = np.ceil(settings['buffer_size'] / pixel_size)
        min_beach_area_pixels = np.ceil(settings['min_beach_area'] /
                                        pixel_size**2)

        # loop through images
        for i, image_name in enumerate(filenames):
            # image filename
            fn = SDS_tools.get_filenames(image_name, filepath, satname)
            # read and preprocess image
            im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(
                fn, satname, settings['cloud_mask_issue'])
            image_epsg = metadata[satname]['epsg'][i]
            # calculate cloud cover
            cloud_cover = np.divide(
                np.sum(cloud_mask.astype(int)),
                (cloud_mask.shape[0] * cloud_mask.shape[1]))
            # skip image if cloud cover is above threshold
            if cloud_cover > settings['cloud_thresh']:
                continue
            # calculate a buffer around the reference shoreline (if any has been digitised)
            im_ref_buffer = SDS_shoreline.create_shoreline_buffer(
                cloud_mask.shape, georef, image_epsg, pixel_size, settings)
            # classify image in 4 classes (sand, whitewater, water, other) with NN classifier
            im_classif, im_labels = SDS_shoreline.classify_image_NN(
                im_ms, im_extra, cloud_mask, min_beach_area_pixels, classifier)
            # there are two options to map the contours:
            # if there are pixels in the 'sand' class --> use find_wl_contours2 (enhanced)
            # otherwise use find_wl_contours1 (traditional)
            try:  # use try/except structure for long runs
                if np.sum(im_labels[:, :, 0]) < 10:
                    # compute MNDWI image (SWIR-G)
                    im_mndwi = SDS_tools.nd_index(im_ms[:, :, 4],
                                                  im_ms[:, :, 1], cloud_mask)
                    # find water contours on MNDWI grayscale image
                    contours_mwi = SDS_shoreline.find_wl_contours1(
                        im_mndwi, cloud_mask, im_ref_buffer)
                else:
                    # use classification to refine threshold and extract the sand/water interface
                    _, contours_mwi = SDS_shoreline.find_wl_contours2(
                        im_ms, im_labels, cloud_mask, buffer_size_pixels,
                        im_ref_buffer)
            except Exception:
                # Exception (not a bare except) so that Ctrl-C still stops a long run
                print('Could not map shoreline for this image: ' +
                      image_name)
                continue
            # process the water contours into a shoreline
            shoreline = SDS_shoreline.process_shoreline(
                contours_mwi, cloud_mask, georef, image_epsg, settings)
            try:
                sl_pix = SDS_tools.convert_world2pix(
                    SDS_tools.convert_epsg(shoreline, settings['output_epsg'],
                                           image_epsg)[:, [0, 1]], georef)
            except Exception:
                # if try fails, just add nan into the shoreline vector so the next parts can still run
                sl_pix = np.array([[np.nan, np.nan], [np.nan, np.nan]])
            # make a plot
            im_RGB = SDS_preprocess.rescale_image_intensity(
                im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)
            # create classified image
            im_class = np.copy(im_RGB)
            for k in range(0, im_labels.shape[2]):
                for band in range(3):
                    im_class[im_labels[:, :, k], band] = colours[k, band]
            # show images
            ax[0].imshow(im_RGB)
            ax[1].imshow(im_RGB)
            ax[1].imshow(im_class, alpha=0.5)
            ax[0].axis('off')
            ax[1].axis('off')
            # text before the first '.' of the image name, minus its last 4 characters
            filename = image_name[:image_name.find('.')][:-4]
            ax[0].set_title(filename)
            ax[0].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            ax[1].plot(sl_pix[:, 0], sl_pix[:, 1], 'k.', markersize=3)
            # save figure
            fig.savefig(os.path.join(
                fp, settings['inputs']['sitename'] + filename[:19] + '.jpg'),
                        dpi=150)
            # clear axes
            for cax in fig.axes:
                cax.clear()

    # close the figure at the end (explicitly this one, not whichever figure is current)
    plt.close(fig)