예제 #1
0
def get_spatial_ref_from_wkt(wkt_or_crs_name):
    '''
    Function to return SpatialReference object for supplied WKT
    @param wkt: Well-known text or CRS name for SpatialReference, including "EPSG:XXXX"
    @return spatial_ref: SpatialReference from WKT
    '''
    if not wkt_or_crs_name:
        return None

    spatial_ref = SpatialReference()

    result = spatial_ref.SetFromUserInput(wkt_or_crs_name)
    if not result:
        logger.debug(
            'CRS determined using SpatialReference.SetFromUserInput({})'.
            format(wkt_or_crs_name))
        return spatial_ref

    # Try to resolve WKT
    result = spatial_ref.ImportFromWkt(wkt_or_crs_name)
    if not result:
        logger.debug(
            'CRS determined using SpatialReference.ImportFromWkt({})'.format(
                wkt_or_crs_name))
        return spatial_ref

    # Try to resolve CRS name - either mapped or original
    modified_crs_name = CRS_NAME_MAPPING.get(
        wkt_or_crs_name) or wkt_or_crs_name
    result = spatial_ref.SetWellKnownGeogCS(modified_crs_name)
    if not result:
        logger.debug(
            'CRS determined using SpatialReference.SetWellKnownGeogCS({})'.
            format(modified_crs_name))
        return spatial_ref

    match = re.match('EPSG:(\d+)$', wkt_or_crs_name, re.IGNORECASE)
    if match:
        epsg_code = int(match.group(1))
        result = spatial_ref.ImportFromEPSG(epsg_code)
        if not result:
            logger.debug(
                'CRS determined using SpatialReference.ImportFromEPSG({})'.
                format(epsg_code))
            return spatial_ref

    # Try common formulations for UTM zones
    #TODO: Fix this so it works in the Northern hemisphere
    modified_crs_name = re.sub('\s+', '', wkt_or_crs_name.strip().upper())
    utm_match = (re.match('(\w+)/MGAZONE(\d+)$', modified_crs_name)
                 or re.match('(\w+)/(\d+)S$', modified_crs_name)
                 or re.match('(EPSG:283)(\d{2})$', modified_crs_name)
                 or re.match('(MGA)(\d{2}$)', modified_crs_name))
    if utm_match:
        modified_crs_name = utm_match.group(1)
        modified_crs_name = CRS_NAME_MAPPING.get(
            modified_crs_name) or modified_crs_name
        utm_zone = int(utm_match.group(2))
        result = spatial_ref.SetWellKnownGeogCS(modified_crs_name)
    if not result:
        spatial_ref.SetUTM(
            utm_zone, False
        )  # Put this here to avoid potential side effects in downstream code
        logger.debug(
            'UTM CRS determined using SpatialReference.SetWellKnownGeogCS({}) (zone{})'
            .format(modified_crs_name, utm_zone))
        return spatial_ref

    assert not result, 'Invalid WKT or CRS name: "{}"'.format(wkt_or_crs_name)
예제 #2
0
class HyImage(HyData):
    """
    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
    """
    def __init__(self, data, **kwds):
        """
        Create an image object from a data array.

        *Arguments*:
         - data = a numpy array such that data[x][y][band] gives each pixel value.
        *Keywords*:
         - affine = an affine transform of the format returned by GDAL.GetGeoTransform().
         - project = string defining the project. Default is None.
         - sensor = sensor name. Default is "unknown".
         - header = path to associated header file. Default is None.
        """

        #call constructor for HyData
        super().__init__(data, **kwds)

        # special case - if dataset only has oneband, slice it so it still has
        # the format data[x,y,b].
        if not self.data is None:
            if len(self.data.shape) == 2:
                self.data = self.data[:, :, np.newaxis]

        #load any additional project information (specific to images)
        self.set_projection(kwds.get("project", None))
        self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

        #special header formatting
        self.header['file type'] = 'ENVI Standard'

    def copy(self, data=True):
        """
        Make a deep copy of this image instance.
        *Arguments*:
         - data = True if a copy of the data should be made, otherwise only copy header.
        *Returns*
          - a new HyImage instance.
        """
        if not data:
            return HyImage(None,
                           header=self.header.copy(),
                           projection=self.projection,
                           affine=self.affine)
        else:
            return HyImage(self.data.copy(),
                           header=self.header.copy(),
                           projection=self.projection,
                           affine=self.affine)

    def xdim(self):
        """
        Return number of pixels in x (first dimension of data array)
        """
        return self.data.shape[0]

    def ydim(self):
        """
        Return number of pixels in y (second dimension of data array)
        """
        return self.data.shape[1]

    def aspx(self):
        """
        Return the aspect ratio of this image (width/height).
        """
        return self.ydim() / self.xdim()

    def get_extent(self):
        """
        Returns the width and height of this image in world coordinates.

        *Returns*
         - tuple with (width, height).
        """
        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]

    def set_projection(self, proj):
        """
        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.

        *Arguments*:
         - proj = the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
        """
        try:
            from osgeo.osr import SpatialReference
        except:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        if proj is None:
            self.projection = None
        elif isinstance(proj, SpatialReference):
            self.projection = proj
        elif isinstance(proj, str):
            self.projection = SpatialReference(proj)
        else:
            print("Invalid project %s" % proj)
            raise

    def set_projection_EPSG(self, EPSG):
        """
        Sets this image project using an EPSG code.

        *Arguments*:
         - EPSG = string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
        """

        try:
            from osgeo.osr import SpatialReference
        except:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        self.projection = SpatialReference()
        self.projection.SetFromUserInput(EPSG)

    def get_projection_EPSG(self):
        """
        Gets a string describing this projections EPSG code (if it is an EPSG project).

        *Returns*:
         - an EPSG code string of the format "EPSG:XXXX".
        """
        if self.projection is None:
            return None
        else:
            return "%s:%s" % (self.projection.GetAttrValue(
                "AUTHORITY", 0), self.projection.GetAttrValue("AUTHORITY", 1))

    def pix_to_world(self, px, py, proj=None):
        """
        Take pixel coordinates and return world coordinates

        *Arguments*:
         - px = the pixel x-coord.
         - py = the pixel y-coord.
         - proj = the coordinate system to use. Default (None) uses the same system as this image. Otherwise
                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
        *Returns*:
         - the world coordinates in the coordinate system defined by get_projection_EPSG(...).
        """

        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        # parse project
        if proj is None:
            proj = self.projection
        elif isinstance(proj, str) or isinstance(proj, int):
            epsg = proj
            if isinstance(epsg, str):
                try:
                    epsg = int(str.split(':')[1])
                except:
                    assert False, "Error - %s is an invalid EPSG code." % proj
            proj = osr.SpatialReference()
            proj.ImportFromEPSG(epsg)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference
                          ), "Error - invalid spatial reference %s" % proj
        assert (not self.affine is None) and (
            not self.projection is None
        ), "Error - project information is undefined."

        #project to world coordinates in this images project/world coords
        x, y = gdal.ApplyGeoTransform(self.affine, px, py)

        #project to target coords (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(
                self.projection)  # tell the point what coordinates it's in
            P.TransformTo(proj)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()

            #do we need to transpose?
            if proj.EPSGTreatsAsLatLong():
                x, y = y, x  #we want lon,lat not lat,lon
        return x, y

    def world_to_pix(self, x, y, proj=None):
        """
        Take world coordinates and return pixel coordinates

        *Arguments*:
         - x = the world x-coord.
         - y = the world y-coord.
         - proj = the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).

        *Returns*:
         - the pixel coordinates based on the affine transform stored in self.affine.
        """

        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        # parse project
        if proj is None:
            proj = self.projection
        elif isinstance(proj, str) or isinstance(proj, int):
            epsg = proj
            if isinstance(epsg, str):
                try:
                    epsg = int(str.split(':')[1])
                except:
                    assert False, "Error - %s is an invalid EPSG code." % proj
            proj = osr.SpatialReference()
            proj.ImportFromEPSG(epsg)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference
                          ), "Error - invalid spatial reference %s" % proj
        assert (not self.affine is None) and (
            not self.projection is None
        ), "Error - project information is undefined."

        # project to this images CS (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(
                proj)  # tell the point what coordinates it's in
            P.AddPoint(x, y)
            P.TransformTo(
                self.projection)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()
            if self.projection.EPSGTreatsAsLatLong(
            ):  # do we need to transpose?
                x, y = y, x  # we want lon,lat not lat,lon

        inv = gdal.InvGeoTransform(self.affine)
        assert not inv is None, "Error - could not invert affine transform?"

        #apply
        return gdal.ApplyGeoTransform(inv, x, y)

    def flip(self, axis='x'):
        """
        Flip the image on the x or y axis.

        *Arguments*:
        - axis = 'x' or 'y' or both 'xy'.
        """

        if 'x' in axis.lower():
            self.data = np.flip(self.data, axis=0)
        if 'y' in axis.lower():
            self.data = np.flip(self.data, axis=1)

    def rot90(self):
        """
        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
        to achieve positive/negative rotations.
        """
        self.data = np.transpose(self.data, (1, 0, 2))
        self.push_to_header()

    #####################################
    ##IMAGE FILTERING
    #####################################
    def fill_holes(self):
        """
        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
        """

        # perform greyscale dilation
        dilate = self.data.copy()
        mask = np.logical_not(np.isfinite(dilate))
        dilate[mask] = 0
        for b in range(self.band_count()):
            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b],
                                                       size=(3, 3))

        # map back to holes in dataset
        self.data[mask] = dilate[mask]
        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans

    def blur(self, n=3):
        """
        Applies a gaussian kernel of size n to the image using OpenCV.

        *Arguments*:
         - n = the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
        """

        nanmask = np.isnan(self.data)
        assert isinstance(
            n, int
        ) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
        kernel = np.ones((n, n), np.float32) / (n**2)
        self.data = cv2.filter2D(self.data, -1, kernel)
        self.data[nanmask] = np.nan  # remove mask

    def erode(self, size=3, iterations=1):
        """
        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
        function for more details.

        *Arguments*:
         - size = the size of the erode filter. Default is a 3x3 kernel.
         - iterations = the number of erode iterations. Default is 1.
        """

        # erode
        kernel = np.ones((size, size), np.uint8)
        if self.is_float():
            mask = np.isfinite(self.data).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8),
                             kernel,
                             iterations=iterations)
            self.data[mask == 0, :] = np.nan
        else:
            mask = (self.data != 0).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8),
                             kernel,
                             iterations=iterations)
            self.data[mask == 0, :] = 0

    def resize(self, newdims, interpolation=cv2.INTER_LINEAR):
        """
        Resize this image with opencv.

        *Arguments*:
         - newdims = the new image dimensions.
        """

        self.data = cv2.resize(self.data, (newdims[1], newdims[0]),
                               interpolation=interpolation)

    def despeckle(self, size=5):
        """
        Despeckle each band of this image (independently) using a median filter.

        *Arguments*:
         - size = the size of the median filter kernel. Default is 5. Must be an odd number.
        """

        assert (size % 2) == 1, "Error - size must be an odd integer"

        if self.is_float():
            self.data = cv2.medianBlur(self.data.astype(np.float32), size)
        else:
            self.data = cv2.medianBlur(self.data, size)

    #####################################
    ##FEATURES AND FEATURE MATCHING
    ######################################
    def get_keypoints(self,
                      band,
                      eq=False,
                      mask=True,
                      method='sift',
                      cfac=0.0,
                      bfac=0.0,
                      **kwds):
        """
         Get feature descriptors from the specified band.

         *Arguments*:
          - band = the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
                    containing a range of bands (min : max) to average before feature matching.
          - eq = True if the image should be histogram equalized first. Default is False.
          - mask = True if 0 value pixels should be masked. Default is True.
          - method = the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
          - cfac = contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
          - bfac = brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.

         *Keywords*:
          - keyword arguments are passed to the opencv feature detector. For SIFT these are:
                - contrastThreshold: default is 0.01.
                - edgeThreshold: default is 10.
                - sigma: default is 1.0

            For ORB these are:
                - nfeatures = the number of features to detect. Default is 5000.

         *Returns*:
          - k, d = the keypoints detected and corresponding feature descriptors
         """

        # get image
        if isinstance(band, int) or isinstance(band, float):  #single band
            image = self.data[:, :, self.get_band_index(band)]
        elif isinstance(band, tuple):  #range of bands (averaged)
            idx0 = self.get_band_index(band[0])
            idx1 = self.get_band_index(band[1])

            #deal with out of range errors
            if idx0 is None:
                idx0 = 0
            if idx1 is None:
                idx1 = self.band_count()

            #average bands
            image = np.nanmean(self.data[:, :, idx0:idx1], axis=2)
        else:
            assert False, "Error, unrecognised band %s" % band

        #normalise image to range 0 - 1
        image -= np.nanmin(image)
        image = image / np.nanmax(image)

        #apply brightness/contrast adjustment
        image = (1.0 + cfac) * image + bfac
        image[image > 1.0] = 1.0
        image[image < 0.0] = 0.0

        #convert image to uint8 for opencv
        image = np.uint8(255 * image)
        if eq:
            image = cv2.equalizeHist(image)

        if mask:
            mask = np.zeros(image.shape, dtype=np.uint8)
            mask[image != 0] = 255  # include only non-zero pixels
        else:
            mask = None

        if 'sift' in method.lower():  # SIFT

            # setup default keywords
            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
            kwds["sigma"] = kwds.get("sigma", 1.0)

            # make feature detector
            alg = cv2.xfeatures2d.SIFT_create(**kwds)
        elif 'orb' in method.lower():  # orb
            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
        else:
            assert False, "Error - %s is not a recognised feature detector." % method

        # detect keypoints
        kp = alg.detect(image, mask)

        # extract and return feature vectors
        return alg.compute(image, kp)

    @classmethod
    def match_keypoints(cls,
                        kp1,
                        kp2,
                        d1,
                        d2,
                        method='SIFT',
                        dist=0.7,
                        tree=5,
                        check=100,
                        min_count=5):
        """
        Compares keypoint feature vectors from two images and returns matching pairs.

        *Arguments*:
         - kp1 = keypoints from the first image
         - kp2 = keypoints from the second image
         - d1 = descriptors for the keypoints from the first image
         - d2 = descriptors for the keypoints from the second image
         - method = the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
         - dist = minimum match distance (0 to 1), default is 0.7
         - tree  = ??
         - check = 100 ?? Default is 100.
         - min_count = the minimum number of matches to consider a valid matching operation. If fewer matches are found,
                       then the function returns None, None. Default is 5.
        """

        if 'sift' in method.lower():
            algorithm = cv2.NORM_INF
        elif 'orb' in method.lower():
            algorithm = cv2.NORM_HAMMING
        else:
            assert False, "Error - unknown matching algorithm %s" % method

        #calculate flann matches
        index_params = dict(algorithm=algorithm, trees=tree)
        search_params = dict(checks=check)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(d1, d2, k=2)

        # store all the good matches as per Lowe's ratio test.
        good = []
        for m, n in matches:
            if m.distance < dist * n.distance:
                good.append(m)

        if len(good) < min_count:
            return None, None
        else:
            src_pts = np.float32([kp1[m.queryIdx].pt
                                  for m in good]).reshape(-1, 1, 2)
            dst_pts = np.float32([kp2[m.trainIdx].pt
                                  for m in good]).reshape(-1, 1, 2)
            return src_pts, dst_pts

    ############################
    ## Visualisation methods
    ############################
    def quick_plot(self,
                   band=0,
                   ax=None,
                   bfac=0.0,
                   cfac=0.0,
                   samples=False,
                   **kwds):
        """
        Plot a band using matplotlib.imshow(...).

        *Arguments*:
         - band = the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
                  each band in the tuple (string or index) will be mapped to rgb.
         - ax = an axis object to plot to. If none, plt.imshow( ... ) is used.
         - bfac = a brightness adjustment to apply to RGB mappings (-1 to 1)
         - cfac = a contrast adjustment to apply to RGB mappings (-1 to 1)
         - samples = True if sample points (defined in the header file) should be plotted. Default is False.
        *Keywords*:
         - keywords are passed to matplotlib.imshow( ... ).

         Additional special keywords include:
         - mask = a 2 boolean mask containing true if pixels should be drawn and false otherwise.
        *Returns*:
         - fig, ax = the figure and axes object created (or passed through the ax keyword). If a colorbar is created,
                     (band is an integer or a float), then this will be stored in ax.cbar.
        """

        #create new axes?
        if ax is None:
            fig, ax = plt.subplots(figsize=(18,
                                            18 * self.ydim() / self.xdim()))

        #map individual band using colourmap
        if isinstance(band, str) or isinstance(band, int) or isinstance(
                band, float):
            #get band
            data = self.data[:, :, self.get_band_index(band)]

            #mask nans (and apply custom mask)
            mask = np.isnan(data)
            if not np.isnan(self.header.get_data_ignore_value()):
                mask = mask + data == self.header.get_data_ignore_value()
            if 'mask' in kwds:
                mask = mask + kwds.get('mask')
                del kwds['mask']
            data = np.ma.array(data, mask=mask > 0)
            ax.cbar = ax.imshow(data.T, **kwds)

        #map 3 bands to RGB
        elif isinstance(band, tuple) or isinstance(band, list):
            #get band indices and range
            rgb = []
            for b in band:
                rgb.append(self.get_band_index(b))

            #slice image (as copy) and map to 0 - 1
            img = np.array(self.data[:, :, rgb])
            if np.isnan(img).all():
                print("Warning - image contains no data.")
                return ax.get_figure(), ax

            mn = kwds.get("vmin", np.nanmin(img))
            mx = kwds.get("vmax", np.nanmax(img))
            img = (img - mn) / (mx - mn)

            #apply brightness/contrast mapping
            img = (1.0 + cfac) * img + bfac
            img[img > 1.0] = 1.0
            img[img < 0.0] = 0.0

            #apply masking so background is white
            img[np.logical_not(np.isfinite(img))] = 1.0
            if 'mask' in kwds:
                img[kwds.get("mask"), :] = 1.0
                del kwds['mask']

            #plot
            ax.imshow(np.transpose(img, (1, 0, 2)), **kwds)
            ax.cbar = None  # no colorbar

        # plot samples?
        if samples:
            for n in self.header.get_class_names():
                points = np.array(self.header.get_sample_points(n))
                ax.scatter(points[:, 0], points[:, 1], s=4)

        return ax.get_figure(), ax

    def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
        """
        Create and save an animated gif that loops through the bands of the image.

        *Arguments*:
         - path = the path to save the .gif
         - bands = Tuple containing the range of band indices to draw. Default is the whole range.
         - figsize = the size of the image to draw. Default is (10,10).
         - fps = the framerate (frames per second) of the gif. Default is 10.
        *Keywords*:
         - keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
        """

        frames = []
        if bands is None:
            bands = (0, self.band_count())
        else:
            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
            assert bands[1] > bands[0], "Error - invalid range."

        #plot frames
        for i in range(bands[0], bands[1]):
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            frames.append(
                np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
            frames[-1] = np.reshape(frames[-1],
                                    (fig.canvas.get_width_height()[1],
                                     fig.canvas.get_width_height()[0], 3))
            plt.close(fig)

        #save gif
        imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)

    ## masking
    def mask(self,
             mask=None,
             flag=np.nan,
             invert=False,
             crop=False,
             bands=None):
        """
         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
         image in-situ.

         *Arguments*:
          - flag = the value to use for masked pixels. Default is np.nan
          - mask = a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
          - invert = if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
          - crop = True if rows/columns containing only zeros should be removed. Default is False.
          - bands = the bands of the image to plot if no mask is specified. If None, the middle band is used.

         *Returns*:
          - mask = a boolean array with True where pixels are masked and False elsewhere.
          - poly = the mask polygon array in the format described above. Useful if the polygon was interactively defined.
         """

        if mask is None:  # pick mask interactively
            if bands is None:
                bands = int(self.band_count() / 2)

            regions = self.pickPolygons(region_names=["mask"], bands=bands)

            # the user bailed without picking a mask?
            if len(regions) == 0:
                print("Warning - no mask picked/applied.")
                return

            # extract polygon mask
            mask = regions[0]

        # convert polygon mask to binary mask
        if mask.shape[1] == 2:

            # build meshgrid with pixel coords
            xx, yy = np.meshgrid(np.arange(self.xdim()),
                                 np.arange(self.ydim()))
            xx = xx.flatten()
            yy = yy.flatten()
            points = np.vstack([xx, yy]).T  # coordinates of each pixel

            # calculate per-pixel mask
            mask = path.Path(mask).contains_points(points)
            mask = mask.reshape((self.ydim(), self.xdim())).T

            # flip as we want to mask (==True) outside points (unless invert is true)
            if not invert:
                mask = np.logical_not(mask)

        # apply binary image mask
        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
        for b in range(self.band_count()):
            self.data[:, :, b][mask] = flag

        # crop image
        if crop:
            # calculate non-masked pixels
            valid = np.logical_not(mask)

            # integrate along axes
            xdata = np.sum(valid, axis=1) > 0.0
            ydata = np.sum(valid, axis=0) > 0.0

            # calculate domain containing valid pixels
            xmin = np.argmax(xdata)
            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
            ymin = np.argmax(ydata)
            ymax = ydata.shape[0] - np.argmax(ydata[::-1])

            # crop
            self.data = self.data[xmin:xmax, ymin:ymax, :]

        return mask

    ##################################################
    ## Interactive tools for picking regions/pixels
    ##################################################
    def pickPolygons(self, region_names, bands=0):
        """
        Creates a matplotlib gui for selecting polygon regions in an image.

        *Arguments*:
         - image = the image to pick on
         - names = a list containing the names of the regions to pick. If a string is passed only one name is used.
         - bands = the bands of the image to plot.
        """

        if isinstance(region_names, str):
            region_names = [region_names]

        assert isinstance(region_names,
                          list), "Error - names must be a list or a string."

        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work

        # plot image and extract roi's
        fig, ax = self.quick_plot(bands)
        roi = MultiRoi(roi_names=region_names)
        plt.close(fig)  # close figure

        # extract regions
        regions = []
        for name, r in roi.rois.items():
            # store region
            x = r.x
            y = r.y
            regions.append(np.vstack([x, y]).T)

        # restore matplotlib backend (if possible)
        try:
            matplotlib.use(backend)
        except:
            print(
                "Warning: could not reset matplotlib backend. Plots will remain interactive..."
            )
            pass

        return regions

    def pickPoints(self,
                   n=-1,
                   bands=hylite.RGB,
                   integer=True,
                   title="Pick Points",
                   **kwds):
        """
        Creates a matplotlib gui for picking pixels from an image.

        *Arguments*:
         - n = the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
         - bands = the bands of the image to plot. Default is HyImage.RGB
         - integer = True if points coordinates should be cast to integers (for use as indices). Default is True.
         - title = The title of the point picking window.
        *Keywords*: Keywords are passed to HyImage.quick_plot( ... )
        *Returns*: A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
        """

        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work

        # create figure
        fig, ax = self.quick_plot(bands, **kwds)
        ax.set_title(title)

        # get points
        points = fig.ginput(n)

        if integer:
            points = [(int(p[0]), int(p[1])) for p in points]

        # restore matplotlib backend (if possible)
        try:
            matplotlib.use(backend)
        except:
            print(
                "Warning: could not reset matplotlib backend. Plots will remain interactive..."
            )
            pass

        return points

    def pickSamples(self, names=None, store=True, **kwds):
        """
        Pick sample probe points and store these in the image header file.

        *Arguments*:
         - names = the name of the sample to pick, or a list of names to pick multiple.
         - store = True if sample should be stored in the image header file (for later access). Default is True.
        *Keywords*: Keywords are passed to HyImage.quick_plot( ... )
        *Returns*:  a list containing a list of points for each sample.
        """

        if isinstance(names, str):
            names = [names]

        # pick points
        points = []
        for s in names:
            pnts = self.pickPoints(title="%s" % s, **kwds)
            if store:
                self.header['sample %s' % s] = pnts  # store in header
            points.append(pnts)
        # add class to header file
        if store:
            cls_names = self.header.get_class_names()
            if cls_names is None:
                cls_names = []
            self.header['class names'] = cls_names + names

        return points

    def getSpectralLibrary(self, samples=None, names=None, s=8):
        """
        Extract a spectral library by sampling and averaging pixels within the specified distance of sample points.

        *Arguments*:
         - samples = a list of samples to use, as returned by pickSamples(...). If None (default) then sample points
                     are pulled from the header file (if defined).
         - Array of names corresponding to the samples (or None).
         - s = the number of pixels on either side of the sample point to extract. Default is 8.

        *Returns*: a HyLibrary instance
        """

        if samples is None:
            names = self.header.get_class_names()
            samples = [self.header.get_sample_points(n) for n in names]

        refl = []
        upper = []
        lower = []
        for n, sample in zip(names, samples):
            spectra = []
            for p in sample:
                spectra.append(
                    self.data[max(p[0] - s, 0):min(p[0] +
                                                   s, self.data.shape[0]),
                              max(p[1] -
                                  s, 0):min(p[1] +
                                            s, self.data.shape[1])].reshape(
                                                -1, self.band_count()))
            spectra = np.vstack(spectra)
            l, m, u = np.nanpercentile(spectra, (25, 50, 75), axis=0)
            lower.append(l)
            refl.append(m)
            upper.append(u)

        return HyLibrary(names,
                         np.vstack(refl),
                         lower=np.vstack(lower),
                         upper=np.vstack(upper),
                         wav=self.get_wavelengths())