Exemple #1
0
    def __init__(self, quality_control_file, band, with_stats,
                 number_of_processes):
        """Register a quality-control job for one data band.

        :param quality_control_file: quality control file; stored as
            ``self.qcf`` and passed to every checker's ``init_statistics``
        :param band: number of the data band to process
        :type band: int
        :param with_stats: when true, prepare the containers used to collect
            statistics of invalid pixels
        :param number_of_processes: stored as ``self.number_of_processes``
        """
        # every job is tracked in the class-level registry
        QualityControl.list.append(self)

        self.band = band
        band_suffix = fix_zeros(band, 2)
        self.band_name = 'band' + band_suffix

        self.qcf = quality_control_file
        self.with_stats = with_stats
        self.number_of_processes = number_of_processes
        self.qc_check_lists = {}

        # output raster: filled later with one band per processed file
        self.output_driver = None
        self.output_bands = []
        self.output_filename = "{0}_{1}_band{2}.tif".format(
            SatelliteData.tile, SatelliteData.shortname, band_suffix)

        if not self.with_stats:
            return

        # statistics fields saved after checking the quality control
        self.quality_control_statistics = {}

        # let every quality-control checker of every loaded file prepare its
        # statistics
        for sat_data in SatelliteData.list:
            for checker in sat_data.qc_bands.values():
                checker.init_statistics(quality_control_file)
Exemple #2
0
    def save_results(self, output_dir):
        """Save all processed files in one file per data band to process;
        each processed file is written as one band of the output file.

        The georeferencing (geotransform and projection) is copied from the
        matching data band of the first loaded file. Every band array is
        loaded from its joblib dump (memory-mapped, read-only) and the
        directory holding that dump is removed once the band is written.

        :param output_dir: directory to save the output file
        :type output_dir: path
        :raises RuntimeError: if GDAL returns no dataset when opening the
            reference data band or when creating the output raster
        """
        output_path = os.path.join(output_dir, self.output_filename)
        print("\nSaving the result for the band {0} in: {1}".format(
            self.band, self.output_filename))

        # get gdal properties of the matching data band of the first file
        sd = SatelliteData.list[0]
        data_band_name = [
            x for x in sd.sub_datasets if 'b' + fix_zeros(self.band, 2) in x[1]
        ][0][0]
        gdal_data_band = gdal.Open(data_band_name, gdal.GA_ReadOnly)
        if gdal_data_band is None:
            raise RuntimeError(
                "GDAL could not open the data band: {0}".format(data_band_name))
        # keep all six coefficients: rebuilding the tuple from origin and
        # pixel size alone would silently drop the rotation terms (2 and 4)
        geotransform = gdal_data_band.GetGeoTransform()
        projection_wkt = gdal_data_band.GetProjectionRef()
        gdal_data_band = None  # the reference dataset is no longer needed

        # create output raster
        driver = gdal.GetDriverByName('GTiff')
        out_raster = driver.Create(
            output_path,
            sd.get_cols(self.band), sd.get_rows(self.band),
            len(self.output_bands),
            gdal.GDT_Int16, ["COMPRESS=LZW", "PREDICTOR=2", "TILED=YES"])
        if out_raster is None:
            raise RuntimeError(
                "GDAL could not create the output raster: {0}".format(output_path))

        # write bands
        for nband, data_band_raster_mmap_file in enumerate(self.output_bands):
            # load result raster saved in file with memmap (joblib dump)
            data_band_raster = load(data_band_raster_mmap_file, mmap_mode='r')
            outband = out_raster.GetRasterBand(nband + 1)
            outband.WriteArray(data_band_raster)
            outband.SetNoDataValue(self.nodata_value)
            # FlushCache is deliberately not called per band: it caused
            # "WriteEncodedTile/Strip() failed" errors

            # clean: release the band and the memmap before deleting the
            # files that back the memmap
            outband = None
            del data_band_raster
            shutil.rmtree(os.path.dirname(data_band_raster_mmap_file))

        # set georeferencing
        out_raster.SetGeoTransform(geotransform)
        out_raster_srs = osr.SpatialReference()
        out_raster_srs.ImportFromWkt(projection_wkt)
        out_raster.SetProjection(out_raster_srs.ExportToWkt())

        # dropping the last reference closes the output dataset
        out_raster = None
Exemple #3
0
    def get_data_band(self, band):
        """Read the whole raster of the data band ``band`` of this file.

        The first entry of ``self.sub_datasets`` whose second item contains
        ``'b' + fix_zeros(band, 2)`` is selected; its first item is the name
        opened with GDAL.

        :param band: band to process
        :type band: int
        :return: raster of the data band
        :rtype: ndarray
        :raises IndexError: if no sub-dataset matches the band
        """

        # TODO: optimize/performance the table open/access in memory (pytables?)

        matching_names = [
            sub_dataset[0] for sub_dataset in self.sub_datasets
            if 'b' + fix_zeros(band, 2) in sub_dataset[1]
        ]
        dataset = gdal.Open(matching_names[0], gdal.GA_ReadOnly)
        raster = dataset.ReadAsArray()
        # drop the GDAL handle before returning the array
        del dataset
        return raster
Exemple #4
0
    def __init__(self, file, xml_file):
        """Initialize the class of MODIS products.

        Registers the instance in the shared ``list``, runs the parent
        initializer with ``file``, reads the product metadata from
        ``xml_file`` and copies satellite, short name and tile to
        ``SatelliteData`` as class-wide values.

        :param file: path to input file
        :type file: str
        :param xml_file: path to input xml file
        :type xml_file: str
        """
        super().list.append(self)
        super().__init__(file)

        # load metadata
        metadata = ET.parse(xml_file)

        def first_text(tag):
            # text of the first element named `tag` (IndexError if missing)
            return list(metadata.iter(tag))[0].text

        self.satellite = first_text('PlatformShortName')  # Terra
        self.shortname = first_text('ShortName')  # MOD09A1
        # the tile is the third '.'-separated field of the granule id
        self.tile = first_text('LocalGranuleID').split('.')[2]  # h10v07

        # beginning date, written as year-month-day
        ymd = [int(part) for part in first_text('RangeBeginningDate').split('-')]
        self.start_date = date(ymd[0], ymd[1], ymd[2])
        # day of the year (1-based) of the beginning date
        first_day_of_year = date(self.start_date.year, 1, 1)
        self.start_jday = (self.start_date.toordinal()
                           - first_day_of_year.toordinal() + 1)
        # year followed by the day of year (ie 2015034), equal to filename string
        self.start_year_and_jday = "{0}{1}".format(
            self.start_date.year, fix_zeros(self.start_jday, 3))

        # publish the product identification as class-wide values
        SatelliteData.satellite = self.satellite
        SatelliteData.shortname = self.shortname
        SatelliteData.tile = self.tile

        self.make_qc = self.set_quality_control_bands()
Exemple #5
0
 def get_quality_control_bands(self, band):
     """Return the name of the sub-dataset matching ``band``.

     The first entry of ``self.sub_datasets`` whose second item contains
     ``'b' + fix_zeros(band, 2)`` is selected and its first item returned.

     NOTE(review): the match uses the plain ``'b'`` + band-number tag, with
     nothing quality-control specific; confirm this matches the intent of
     the method name.

     :param band: band number
     :return: name of the matching sub-dataset
     :raises IndexError: if no sub-dataset matches
     """
     matching_names = [
         sub_dataset[0] for sub_dataset in self.sub_datasets
         if 'b' + fix_zeros(band, 2) in sub_dataset[1]
     ]
     return matching_names[0]
Exemple #6
0
 def get_nodata_value(self, band):
     """Return the no-data value of raster band 1 of the sub-dataset
     matching ``band``.

     :param band: band number
     :return: the value reported by GDAL ``GetNoDataValue()``
     :raises IndexError: if no sub-dataset matches
     """
     matching_names = [
         sub_dataset[0] for sub_dataset in self.sub_datasets
         if 'b' + fix_zeros(band, 2) in sub_dataset[1]
     ]
     # keep the dataset in a local while its band is used; a band whose
     # parent dataset was already released is unsafe in GDAL's bindings
     dataset = gdal.Open(matching_names[0], gdal.GA_ReadOnly)
     first_band = dataset.GetRasterBand(1)
     return first_band.GetNoDataValue()
Exemple #7
0
 def get_total_pixels(self, band):
     """Return the number of pixels (columns times rows) of the sub-dataset
     matching ``band``.

     :param band: band number
     :return: ``RasterXSize * RasterYSize`` of the matching sub-dataset
     :raises IndexError: if no sub-dataset matches
     """
     matching_names = [
         sub_dataset[0] for sub_dataset in self.sub_datasets
         if 'b' + fix_zeros(band, 2) in sub_dataset[1]
     ]
     dataset = gdal.Open(matching_names[0], gdal.GA_ReadOnly)
     n_cols, n_rows = dataset.RasterXSize, dataset.RasterYSize
     return n_cols * n_rows
Exemple #8
0
    def save_statistics(self, output_dir):
        """Save an image with the time series of invalid pixels of all
        filters, as the result after applying QC4SD.

        There is one x position per processed file (the sorted keys of
        ``self.quality_control_statistics``). There is one series for the
        total invalid pixels, one for the no-data pixels and one for every
        filter found in any file.
        - A filter missing from a file is plotted as NaN.
        - Series that are zero for every file are dropped.
        - The first remaining series is annotated with its percentage of the
          band's pixels.
        The PNG is written to ``output_dir`` and trimmed in place with
        ImageMagick ``convert`` when that tool is available.

        :param output_dir: directory to save the image
        :type output_dir: path
        """
        # force matplotlib to not use any Xwindows backend.
        import matplotlib
        matplotlib.use('Agg')

        import matplotlib as mpl
        import matplotlib.cm as cm
        import matplotlib.ticker as mtick
        import matplotlib.pyplot as plt

        img_filename = os.path.join(
            output_dir,
            self.output_filename.split('.tif')[0] + "_stats.png")
        print(
            "Saving the image of statistics of invalid pixels in: {0}".format(
                os.path.basename(img_filename)))

        def merge_filters(invalid_pixels):
            # flatten the nested dict of per-filter counts into a single
            # {filter name: count} dict, skipping empty entries; a later
            # entry overwrites a duplicated filter name
            merged = {}
            for filter_counts in invalid_pixels.values():
                if filter_counts:
                    merged.update(filter_counts)
            return merged

        ################################
        # prepare data
        stats = self.quality_control_statistics
        sd_names_sorted = sorted(stats.keys())
        filters_by_sd = {
            sd_name: merge_filters(stats[sd_name]['invalid_pixels'])
            for sd_name in sd_names_sorted
        }
        filter_names = sorted(
            set().union(*(filters.keys() for filters in filters_by_sd.values())))
        all_filter_names = ['total_invalid_pixels', 'nodata_pixels'] + filter_names

        # one row per file: total, nodata, then each filter (NaN if absent)
        all_invalid_pixels = [
            [stats[sd_name]['total_invalid_pixels'],
             stats[sd_name]['nodata_pixels']] +
            [filters_by_sd[sd_name].get(name, float('nan'))
             for name in filter_names]
            for sd_name in sd_names_sorted
        ]
        # one series per type of invalid pixels, across the files
        all_invalid_pixels_T = [list(series) for series in zip(*all_invalid_pixels)]

        # delete the types of invalid pixels that did not filter any pixel in
        # any file (only zeros apart from NaN)
        keep = [
            idx for idx, values in enumerate(all_invalid_pixels_T)
            if [x for x in set(values) if not isnan(x)] != [0]
        ]
        all_invalid_pixels_T = [all_invalid_pixels_T[idx] for idx in keep]
        all_filter_names = [all_filter_names[idx] for idx in keep]
        all_invalid_pixels = [list(row) for row in zip(*all_invalid_pixels_T)]

        # checked before any figure exists, so this early return does not
        # leave an open matplotlib figure behind
        if not all_invalid_pixels:
            print(
                "\nWARNING: the invalid pixels is zero! nothing pixels was filtered.\n"
            )
            return

        # y max over all times; NaN is skipped because max() gives
        # order-dependent results when NaN is among the values
        max_y = max(value for series in all_invalid_pixels_T
                    for value in series if not isnan(value))

        # y position of each series label: its last non-NaN value
        y_pos_label_fixed = [
            next(value for value in reversed(series) if not isnan(value))
            for series in all_invalid_pixels_T
        ]

        # adjust the label positions (presumably spreading them apart so they
        # do not overlap) until repulsive_items_list reports no more fixes
        repulsive_distance = max_y * 0.035
        fix_list = True
        while fix_list:
            fix_list, repulsive_distance, y_pos_label_fixed = repulsive_items_list(
                y_pos_label_fixed, repulsive_distance)

        # define colors
        norm = mpl.colors.Normalize(vmin=0, vmax=len(all_invalid_pixels_T))
        m = cm.ScalarMappable(norm=norm, cmap=cm.Set1)

        ################################
        # plot
        n_times = len(sd_names_sorted)
        n_files = len(SatelliteData.list)
        single_image = n_times == 1

        width = 7 if single_image else 10 + n_times * 0.4
        fig, ax = plt.subplots(1, 1, figsize=(width, 8), facecolor='white')
        for spine in ('top', 'bottom', 'right', 'left'):
            ax.spines[spine].set_visible(False)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()
        plt.tick_params(axis='both',
                        which='both',
                        bottom=False,
                        top=False,
                        labelbottom=True,
                        left=False,
                        right=False,
                        labelleft=True)

        # pixel count of the band, read once from the first loaded file
        # (it is only needed when there is at least one file to annotate)
        total_pixels = (SatelliteData.list[0].get_total_pixels(self.band)
                        if SatelliteData.list else None)

        # x position of the series labels, right of the last point
        label_x = 0.3 if single_image else n_files - 1 + n_files * 0.02

        for idx, line in enumerate(all_invalid_pixels_T):
            color = m.to_rgba(idx)
            line_width = 3.4 if idx == 0 else 3
            if single_image:
                # a single time step is drawn as markers only
                plt.plot(0, line, 'o', markersize=9, color=color,
                         linewidth=line_width, alpha=1)
            else:
                plt.plot(line, color=color, linewidth=line_width, alpha=1)

            if idx == 0:
                # put the percentage of invalid pixels above each x item (time)
                for x, y in zip(range(n_files), line):
                    ax.text(x,
                            y + max_y * 0.02,
                            "{0}%".format(round(100 * y / total_pixels, 2)),
                            ha='center',
                            va='bottom',
                            color=color,
                            fontweight='bold',
                            fontsize=12,
                            alpha=1)

            # y label of filter name
            plt.text(label_x,
                     y_pos_label_fixed[idx],
                     all_filter_names[idx],
                     fontsize=12,
                     weight='bold',
                     color=color,
                     alpha=1)

        if single_image:
            plt.xlim(-0.5, 0.5)
        else:
            plt.xlim(-n_files * 0.02, n_files - 1 + n_files * 0.02)
        plt.xticks(range(n_times), sd_names_sorted, rotation=90)
        plt.ylim(-max_y * 0.01, max_y + max_y * 0.07)
        plt.title(
            "Invalid pixels for {0} {1} in band {2}\nQC4SD - IDEAM".format(
                SatelliteData.tile, SatelliteData.shortname,
                fix_zeros(self.band, 2)),
            fontsize=18,
            weight='bold',
            color="#3A3A3A")
        plt.xlabel("Date", fontsize=14, weight='bold', color="#3A3A3A")
        plt.ylabel("Number of invalid pixels",
                   fontsize=14,
                   weight='bold',
                   color="#3A3A3A")
        plt.tick_params(axis='both',
                        which='major',
                        labelsize=14,
                        color="#3A3A3A")
        ax.grid(True, color='gray')
        ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%g'))
        fig.tight_layout()
        if single_image:
            fig.subplots_adjust(right=0.6, left=0.4)
        else:
            fig.subplots_adjust(right=1.02 - 3.6 / width)

        plt.savefig(img_filename, dpi=86)
        plt.close('all')

        # trim whitespace with ImageMagick; best effort only
        try:
            return_code = call([
                "convert", img_filename, "-trim", "-bordercolor", "white",
                "-border", "8x8", "+repage", "-alpha", "off", img_filename
            ])
        except OSError:
            # `convert` is not installed: keep the untrimmed image
            return_code = None
        if return_code == 0:
            print('trim image successfully')