Ejemplo n.º 1
0
    def to_vectorfield(self, path_out=None, fmt=None, mode='md'):
        # type: (str, str, str) -> GeoArray
        """Save the calculated X-/Y-shifts to a 2-band raster file that can be used to visualize a vectorfield.

        NOTE: For example ArcGIS is able to visualize such 2-band raster files as a vectorfield.

        :param path_out:    <str> the output path. If not given, it is automatically defined.
        :param fmt:         <str> output raster format string (default: 'Gtiff')
        :param mode:        <str> The mode how the output is written ('uv' or 'md'; default: 'md')
                                    'uv': outputs X-/Y shifts
                                    'md': outputs magnitude and direction
        :return:            <GeoArray> the written 2-band raster
                            (band 1: X shift or magnitude, band 2: Y shift or direction)
        """
        assert mode in ['uv', 'md'], "'mode' must be either 'uv' (outputs X-/Y shifts) or 'md' " \
                                     "(outputs magnitude and direction). Got %s." % mode

        # attribute columns of the tie point table that are written to band 1 and band 2
        band_attrs = ('X_SHIFT_M', 'Y_SHIFT_M') if mode == 'uv' else ('ABS_SHIFT', 'ANGLE')

        # the rasterization settings are identical for both bands
        points = self.CoRegPoints_table['geometry']
        tgt_res = self.shift.xgsd * self.grid_res
        prj_wkt = self.CoRegPoints_table.crs.to_wkt()

        band_arrs = []
        for attr in band_attrs:
            band_arr, gt, prj = points_to_raster(points=points,
                                                 values=self.CoRegPoints_table[attr],
                                                 tgt_res=tgt_res,
                                                 prj=prj_wkt,
                                                 fillVal=self.outFillVal)
            band_arrs.append(band_arr)

        # gt/prj of the last rasterized band are used; both bands are derived from the same points and settings
        out_GA = GeoArray(np.dstack(band_arrs), gt, prj, nodata=self.outFillVal)

        if not path_out:
            path_out = get_generic_outpath(
                dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                fName_out="CoRegVectorfield%s_ws(%s_%s)__T_%s__R_%s.tif"
                          % (self.grid_res, self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1],
                             self.shift.basename, self.ref.basename))

        out_GA.save(path_out, fmt=fmt if fmt else 'Gtiff')

        return out_GA
Ejemplo n.º 2
0
class RefCube(object):
    """Data model class for reference cubes holding the training data for later fitted machine learning classifiers.

    The spectral data are held in the 3D GeoArray ``self.data`` with the dimensions
    (spectral samples / signatures x training images x spectral bands).
    """
    def __init__(self,
                 filepath='',
                 satellite='',
                 sensor='',
                 LayerBandsAssignment=None):
        # type: (str, str, str, list) -> None
        """Get instance of RefCube.

        :param filepath:                file path for importing an existing reference cube from disk
        :param satellite:               the satellite for which the reference cube holds its spectral data
        :param sensor:                  the sensor for which the reference cube holds its spectral data
        :param LayerBandsAssignment:    the LayerBandsAssignment for which the reference cube holds its spectral data
        """
        # privates
        self._col_imName_dict = dict()
        self._wavelenths = []  # cache of the wavelengths property

        # defaults: an empty cube with one layer per band of the LayerBandsAssignment
        self.data = GeoArray(np.empty(
            (0, 0, len(LayerBandsAssignment) if LayerBandsAssignment else 0)),
                             nodata=-9999)
        self.srcImNames = []

        # args/ kwargs
        self.filepath = filepath
        self.satellite = satellite
        self.sensor = sensor
        self.LayerBandsAssignment = LayerBandsAssignment or []

        if filepath:
            self.read_data_from_disk(filepath)

        if self.satellite and self.sensor and self.LayerBandsAssignment:
            self._add_bandnames_wavelenghts_to_meta()

    def _add_bandnames_wavelenghts_to_meta(self):
        """Write band names ('Band <LBA entry>') and wavelengths into the metadata of self.data."""
        # set bandnames
        self.data.bandnames = [
            'Band %s' % b for b in self.LayerBandsAssignment
        ]

        # set wavelengths
        self.data.metadata.band_meta['wavelength'] = self.wavelengths

    @property
    def n_images(self):
        """Return the number training images from which the reference cube contains spectral samples."""
        return self.data.shape[1]

    @property
    def n_signatures(self):
        """Return the number spectral signatures per training image included in the reference cube."""
        return self.data.shape[0]

    @property
    def n_clusters(self):
        """Return the number spectral clusters used for clustering source images for the reference cube.

        The number is parsed from the file name, which is expected to look like
        'refcube__<field1>__<field2>__nclust<N>[__...].bsq'. None is returned if no file path is set or if the
        file name does not follow this pattern (e.g., when saving the cube under an arbitrary name).
        """
        if not self.filepath:
            return None

        match = re.search(r'refcube__(.*)\.bsq', os.path.basename(self.filepath))
        if not match:
            return None

        try:
            # the third '__'-separated field of the identifier is expected to be 'nclust<N>'
            return int(match.group(1).split('__')[2].split('nclust')[1])
        except (IndexError, ValueError):
            return None

    @property
    def n_signatures_per_cluster(self):
        """Return the number of spectral signatures per cluster (None if the number of clusters is unknown)."""
        if self.n_clusters:
            return self.n_signatures // self.n_clusters

    @property
    def col_imName_dict(self):
        # type: () -> OrderedDict
        """Return an ordered dict containing the file base names of the original training images for each column."""
        return OrderedDict(
            (col, imName)
            for col, imName in zip(range(self.n_images), self.srcImNames))

    @col_imName_dict.setter
    def col_imName_dict(self, col_imName_dict):
        # type: (dict) -> None
        self._col_imName_dict = col_imName_dict
        self.srcImNames = list(col_imName_dict.values())

    @property
    def wavelengths(self):
        """Return the band wavelengths; derived lazily from the sensor's RSR if not set explicitly."""
        if not self._wavelenths and self.satellite and self.sensor and self.LayerBandsAssignment:
            self._wavelenths = list(
                RSR(self.satellite,
                    self.sensor,
                    LayerBandsAssignment=self.LayerBandsAssignment).wvl)

        return self._wavelenths

    @wavelengths.setter
    def wavelengths(self, wavelengths):
        self._wavelenths = wavelengths

    def add_refcube_array(self, refcube_array, src_imnames,
                          LayerBandsAssignment):
        # type: (Union[str, np.ndarray], list, list) -> None
        """Add the given array to the RefCube instance.

        :param refcube_array:           3D array or file path  of the reference cube to be added
                                        (spectral samples /signatures x training images x spectral bands)
        :param src_imnames:             list of training image file base names from which the given cube received data
        :param LayerBandsAssignment:    LayerBandsAssignment of the spectral bands of the given 3D array
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        if self.data.size:
            if isinstance(refcube_array, str):
                # np.hstack cannot handle file paths -> read the given cube from disk first
                refcube_array = GeoArray(refcube_array)[:]

            new_cube = np.hstack([self.data, refcube_array])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)
        else:
            self.data = GeoArray(refcube_array, nodata=self.data.nodata)

        self.srcImNames.extend(src_imnames)

    def add_spectra(self, spectra, src_imname, LayerBandsAssignment):
        # type: (np.ndarray, str, list) -> None
        """Add a set of spectral signatures to the reference cube.

        :param spectra:              2D numpy array with rows: spectral samples / columns: spectral information (bands)
        :param src_imname:           image basename of the source hyperspectral image
        :param LayerBandsAssignment: LayerBandsAssignment for the spectral dimension of the passed spectra,
                                     e.g., ['1', '2', '3', '4', '5', '6L', '6H', '7', '8']
        :raises ValueError:          if the number of signatures does not match the existing reference cube
        """
        # validation
        assert LayerBandsAssignment == self.LayerBandsAssignment, \
            "%s != %s" % (LayerBandsAssignment, self.LayerBandsAssignment)

        # reshape 2D spectra array to one image column (refcube is an image with spectral information in the 3rd dim.)
        im_col = spectra.reshape((spectra.shape[0], 1, spectra.shape[1]))

        meta = self.data.metadata  # needs to be copied to the new GeoArray

        if self.data.size:
            # validation
            if spectra.shape[0] != self.data.shape[0]:
                raise ValueError(
                    'The number of signatures in the given spectra array does not match the dimensions of '
                    'the reference cube.')

            # append spectra to existing reference cube
            new_cube = np.hstack([self.data, im_col])
            self.data = GeoArray(new_cube, nodata=self.data.nodata)

        else:
            self.data = GeoArray(im_col, nodata=self.data.nodata)

        # copy previous metadata to the new GeoArray instance
        self.data.metadata = meta

        # add source image name to list of image names
        self.srcImNames.append(src_imname)

    @property
    def metadata(self):
        """Return an ordered dictionary holding the metadata of the reference cube."""
        attrs2include = [
            'satellite', 'sensor', 'filepath', 'n_signatures', 'n_images',
            'n_clusters', 'n_signatures_per_cluster', 'col_imName_dict',
            'LayerBandsAssignment', 'wavelengths'
        ]
        return OrderedDict((k, getattr(self, k)) for k in attrs2include)

    def get_band_combination(self, tgt_LBA):
        # type: (List[str]) -> GeoArray
        """Get an array according to the bands order given by a target LayerBandsAssignment.

        :param tgt_LBA:     target LayerBandsAssignment
        :return:            self.data itself if tgt_LBA equals the current LayerBandsAssignment,
                            otherwise a new GeoArray containing the selected bands in the target order
        """
        if tgt_LBA != self.LayerBandsAssignment:
            cur_LBA_dict = dict(
                zip(self.LayerBandsAssignment,
                    range(len(self.LayerBandsAssignment))))
            tgt_bIdxList = [cur_LBA_dict[lr] for lr in tgt_LBA]

            return GeoArray(np.take(self.data, tgt_bIdxList, axis=2),
                            nodata=self.data.nodata)
        else:
            return self.data

    def get_spectra_dataframe(self, tgt_LBA):
        # type: (List[str]) -> DataFrame
        """Return a pandas.DataFrame [sample x band] according to the given LayerBandsAssignment.

        :param tgt_LBA: target LayerBandsAssignment
        :return:        DataFrame with one column per band, named 'B<band>'
        """
        imdata = self.get_band_combination(tgt_LBA)
        spectra = im2spectra(imdata)
        df = DataFrame(spectra, columns=['B%s' % band for band in tgt_LBA])

        return df

    def rearrange_layers(self, tgt_LBA):
        # type: (List[str]) -> None
        """Rearrange the spectral bands of the reference cube according to the given LayerBandsAssignment.

        Cached wavelengths are rearranged accordingly, so that they stay consistent with the band order.

        :param tgt_LBA:     target LayerBandsAssignment
        """
        # rearrange the cached wavelengths the same way as the bands (otherwise they would keep the old order)
        if self._wavelenths and len(self._wavelenths) == len(self.LayerBandsAssignment):
            wvl_by_band = dict(zip(self.LayerBandsAssignment, self._wavelenths))
            new_wavelengths = [wvl_by_band[band] for band in tgt_LBA]
        else:
            new_wavelengths = []  # empty cache -> lazily re-derived by the wavelengths property

        self.data = self.get_band_combination(tgt_LBA)
        self.LayerBandsAssignment = tgt_LBA
        self._wavelenths = new_wavelengths

    def read_data_from_disk(self, filepath):
        """Read the reference cube from the given file and set the attributes stored in the accompanying .meta file.

        NOTE: All keys of the .meta file except pure getters (including 'filepath') are set as instance attributes.

        :param filepath:    path of the reference cube on disk
        """
        self.data = GeoArray(filepath)

        with open(os.path.splitext(filepath)[0] + '.meta', 'r') as metaF:
            meta = json.load(metaF)
            for k, v in meta.items():
                if k in [
                        'n_signatures', 'n_images', 'n_clusters',
                        'n_signatures_per_cluster'
                ]:
                    continue  # skip pure getters
                else:
                    setattr(self, k, v)

    def save(self, path_out, fmt='ENVI'):
        # type: (str, str) -> None
        """Save the reference cube to disk (plus its metadata as JSON file with the extension '.meta').

        :param path_out:    output path on disk
        :param fmt:         output format as GDAL format code
        """
        self.filepath = self.filepath or path_out
        self.data.save(out_path=path_out, fmt=fmt)

        # save metadata as JSON file
        with open(os.path.splitext(path_out)[0] + '.meta', 'w') as metaF:
            json.dump(self.metadata.copy(),
                      metaF,
                      separators=(',', ': '),
                      indent=4)

    def _get_spectra_by_label_imname(self,
                                     cluster_label,
                                     image_basename,
                                     n_spectra2get=100,
                                     random_state=0):
        """Return randomly drawn spectra of one cluster from the column of the given training image.

        The rows of a cluster are the block [label * n_signatures_per_cluster, (label + 1) * n_signatures_per_cluster).
        The returned rows are drawn with replacement (default of RandomState.choice) using a fixed seed.
        """
        cluster_start_pos_all = list(
            range(0, self.n_signatures, self.n_signatures_per_cluster))
        cluster_start_pos = cluster_start_pos_all[cluster_label]
        spectra = self.data[cluster_start_pos:cluster_start_pos +
                            self.n_signatures_per_cluster,
                            self.srcImNames.index(image_basename)]
        idxs_specIncl = np.random.RandomState(seed=random_state).choice(
            range(self.n_signatures_per_cluster), n_spectra2get)
        return spectra[idxs_specIncl, :]

    def plot_sample_spectra(self,
                            image_basename,
                            cluster_label='all',
                            include_mean_spectrum=True,
                            include_median_spectrum=True,
                            ncols=5,
                            **kw_fig):
        # type: (str, Union[str, int, List], bool, bool, int, dict) -> 'plt.figure'
        """Plot 100 randomly drawn sample spectra per cluster from the given training image.

        :param image_basename:          file base name of the training image (must be contained in self.srcImNames)
        :param cluster_label:           cluster label(s) to plot: an integer, a list of integers or 'all'
        :param include_mean_spectrum:   whether to additionally plot the mean spectrum (solid black line)
        :param include_median_spectrum: whether to additionally plot the median spectrum (dashed black line)
        :param ncols:                   maximum number of subplot columns in case multiple clusters are plotted
        :param kw_fig:                  keyword arguments passed to matplotlib when creating the figure
                                        (override the defaults for figsize/sharex/sharey in the multi-plot case)
        :return:                        the matplotlib figure
        """
        from matplotlib import pyplot as plt

        if isinstance(cluster_label, int):
            lbls2plot = [cluster_label]
        elif isinstance(cluster_label, list):
            lbls2plot = cluster_label
        elif cluster_label == 'all':
            lbls2plot = list(range(self.n_clusters))
        else:
            raise ValueError(cluster_label)

        if not lbls2plot:
            raise ValueError('No cluster labels to plot given.')

        # create a single plot
        if len(lbls2plot) == 1:
            # use the resolved label, so that cluster_label=[n] and 'all' (with a single cluster) work as well
            lbl = lbls2plot[0]

            fig = plt.figure(**kw_fig)
            spectra = self._get_spectra_by_label_imname(
                lbl, image_basename, 100)
            for i in range(100):
                plt.plot(self.wavelengths, spectra[i, :])

            plt.xlabel('wavelength [nm]')
            plt.ylabel('%s %s\nreflectance [0-10000]' %
                       (self.satellite, self.sensor))
            plt.title('Cluster #%s' % lbl)

            if include_mean_spectrum:
                plt.plot(self.wavelengths,
                         np.mean(spectra, axis=0),
                         c='black',
                         lw=3)
            if include_median_spectrum:
                plt.plot(self.wavelengths,
                         np.median(spectra, axis=0),
                         '--',
                         c='black',
                         lw=3)

        # create a plot with multiple subplots
        else:
            nplots = len(lbls2plot)
            ncols = min(nplots, ncols)
            nrows = -(-nplots // ncols)  # ceiling division
            fig_kwargs = dict(figsize=(4 * ncols, 3 * nrows),
                              sharex='all',
                              sharey='all')
            fig_kwargs.update(kw_fig)  # user-given keywords take precedence over the defaults
            fig, axes = plt.subplots(nrows=nrows,
                                     ncols=ncols,
                                     **fig_kwargs)

            for lbl, ax in tqdm(zip(lbls2plot, axes.flatten()), total=nplots):
                spectra = self._get_spectra_by_label_imname(
                    lbl, image_basename, 100)

                for i in range(100):
                    ax.plot(self.wavelengths, spectra[i, :], lw=1)

                if include_mean_spectrum:
                    ax.plot(self.wavelengths,
                            np.mean(spectra, axis=0),
                            c='black',
                            lw=2)
                if include_median_spectrum:
                    ax.plot(self.wavelengths,
                            np.median(spectra, axis=0),
                            '--',
                            c='black',
                            lw=3)

                ax.grid(lw=0.2)
                ax.set_ylim(0, 10000)

                if ax.get_subplotspec().is_last_row():
                    ax.set_xlabel('wavelength [nm]')
                if ax.get_subplotspec().is_first_col():
                    ax.set_ylabel('%s %s\nreflectance [0-10000]' %
                                  (self.satellite, self.sensor))
                ax.set_title('Cluster #%s' % lbl)

        fig.suptitle("Refcube spectra from image '%s':" % image_basename,
                     fontsize=15)
        plt.tight_layout(rect=(0, 0, 1, .95))
        plt.show()

        return fig
Ejemplo n.º 3
0
    def correct_shifts(self) -> collections.OrderedDict:
        """Correct the detected geometric shifts of the target image.

        If no warping is needed, the shift is corrected by only updating the geotransform / map info of the target
        image (optionally clipping it to the output extent). Otherwise, the image is resampled onto the output grid
        using warp_ndarray(). In both cases, the result is stored in self.GeoArray_shifted (and the related
        attributes) and written to self.path_out, if given.

        :return:                self.deshift_results
        :raises RuntimeError:   if the output geotransform does not match the desired target pixel grid
        """
        if not self.q:
            print('Correcting geometric shifts...')

        t_start = time.time()

        if not self.warping_needed:
            # NO RESAMPLING NEEDED -> only the map info of the target image is updated
            self.is_shifted = True
            self.is_resampled = False
            xmin, ymin, xmax, ymax = self._get_out_extent()

            if not self.q:
                print(
                    "NOTE: The detected shift is corrected by updating the map info of the target image only, i.e., "
                    "without any resampling. Set the 'align_grids' parameter to True if you need the target and the "
                    "reference coordinate grids to be aligned.")

            if self.cliptoextent:
                # TODO validate results
                # TODO -> output extent does not seem to be the requested one! (only relevant if align_grids=False)
                # get shifted array
                shifted_geoArr = GeoArray(self.im2shift[:],
                                          tuple(self.updated_gt),
                                          self.shift_prj)

                # clip with target extent
                #  NOTE: get_mapPos() does not perform any resampling as long as source and target projection are equal
                self.arr_shifted, self.updated_gt, self.updated_projection = \
                    shifted_geoArr.get_mapPos((xmin, ymin, xmax, ymax),
                                              self.shift_prj,
                                              fillVal=self.nodata,
                                              band2get=self.band2process)

                self.updated_map_info = geotransform2mapinfo(
                    self.updated_gt, self.updated_projection)

            else:
                # array keeps the same; updated gt and prj are taken from coreg_info
                self.arr_shifted = self.im2shift[:, :, self.band2process] \
                    if self.band2process is not None else self.im2shift[:]

        else:
            # RESAMPLING NEEDED
            # FIXME equal_prj==False is not implemented yet
            # FIXME avoid reading the whole band if clip_extent is passed
            in_arr = self.im2shift[:, :, self.band2process] \
                if self.band2process is not None and self.im2shift.ndim == 3 else self.im2shift[:]

            if not self.GCPList:
                # apply XY-shifts to input image gt 'shift_gt' in order to correct the shifts before warping
                self.shift_gt[0], self.shift_gt[3] = self.updated_gt[0], self.updated_gt[3]

            # get resampled array
            out_arr, out_gt, out_prj = \
                warp_ndarray(in_arr, self.shift_gt, self.shift_prj, self.ref_prj,
                             rspAlg=_dict_rspAlg_rsp_Int[self.rspAlg],
                             in_nodata=self.nodata,
                             out_nodata=self.nodata,
                             out_gsd=self.out_gsd,
                             out_bounds=self._get_out_extent(),  # always returns an extent snapped to the target grid
                             gcpList=self.GCPList,
                             CPUs=self.CPUs,
                             progress=self.progress,
                             q=self.q)

            self.arr_shifted = out_arr
            self.updated_gt = out_gt
            self.updated_projection = out_prj
            self.updated_map_info = geotransform2mapinfo(out_gt, out_prj)
            self.is_shifted = True
            self.is_resampled = True

        # build the output GeoArray (identical for both cases)
        out_geoArr = GeoArray(self.arr_shifted,
                              self.updated_gt,
                              self.updated_projection,
                              q=self.q)
        out_geoArr.nodata = self.nodata  # equals self.im2shift.nodata after __init__()
        out_geoArr.metadata = self.im2shift.metadata[[self.band2process]] \
            if self.band2process is not None else self.im2shift.metadata

        self.GeoArray_shifted = out_geoArr

        if self.path_out:
            # the creation options are respected in both cases (they were ignored if no resampling was needed)
            out_geoArr.save(self.path_out,
                            fmt=self.fmt_out,
                            creationOptions=self.out_creaOpt)

        # validation
        # NOTE(review): a tolerance of 1.e8 map units is very large - confirm whether 1.e-8 was intended
        if not is_coord_grid_equal(
                self.updated_gt, *self.out_grid, tolerance=1.e8):
            raise RuntimeError(
                'DESHIFTER output dataset has not the desired target pixel grid. Target grid '
                'was %s. Output geotransform is %s.' %
                (str(self.out_grid), str(self.updated_gt)))
        # TODO to be continued (extent, map info, ...)

        if self.v:
            print('Time for shift correction: %.2fs' % (time.time() - t_start))
        return self.deshift_results