예제 #1
0
def _test_array(image, dtype='float64'):
    """
    Opens an image and reads it into an array

    Args:
        image: The image passed to ``raster_tools.ropen``.
        dtype (Optional[str]): Forwarded to ``read`` as ``d_type``. Default is 'float64'.

    Returns:
        Whatever ``read`` returns for ``bands2open=-1``
        (presumably all bands -- TODO confirm against ``raster_tools``).
    """

    with raster_tools.ropen(image) as src:
        pixels = src.read(bands2open=-1, d_type=dtype)

    return pixels
예제 #2
0
def _test_object(image):
    """
    Opens an image and reports its dimensions

    Args:
        image: The image passed to ``raster_tools.ropen``.

    Returns:
        A tuple of the opened object's ``bands``, ``rows`` and ``cols`` attributes, in that order.
    """

    with raster_tools.ropen(image) as src:
        dimensions = (src.bands, src.rows, src.cols)

    return dimensions
예제 #3
0
def sfs_orfeo(parameter_object):
    """
    Runs the ``otbcli_SFSTextureExtraction`` command and writes one resampled image per output band

    The command is only executed when ``parameter_object.out_img`` does not exist yet.
    Each band of the output image is then translated to a temporary VRT,
    warped to ``parameter_object.sfs_resample`` and the VRT is removed.

    Args:
        parameter_object (object): Must provide ``input_image``, ``band_position``, ``sfs_threshold``,
            ``scales``, ``out_img`` and ``sfs_resample``.
    """

    out_img = parameter_object.out_img

    # Temporary single-band VRT, reused (and deleted) on every loop iteration.
    vrt_path = out_img.replace('.tif', '.vrt')

    command = ('otbcli_SFSTextureExtraction -in {} -channel {:d} -ram 512 '
               '-parameters.spethre {:d} -parameters.spathre {:d} -parameters.nbdir 40 '
               '-out {}').format(parameter_object.input_image,
                                 int(parameter_object.band_position),
                                 int(parameter_object.sfs_threshold),
                                 int(parameter_object.scales[-1]),
                                 out_img)

    # NOTE(review): the command runs through the shell, so paths containing
    # spaces or shell metacharacters are not escaped.
    if not os.path.isfile(out_img):
        subprocess.call(command, shell=True)

    with raster_tools.ropen(out_img) as texture_info:

        # The original note here says the output has 6 layers.
        for band in range(1, texture_info.bands + 1):

            # Extract the current band into the temporary VRT.
            raster_tools.translate(out_img,
                                   vrt_path,
                                   bandList=[band],
                                   cell_size=texture_info.cellY,
                                   format='VRT',
                                   d_type='float32')

            # Derive the per-band file name from the output name.
            band_image = out_img.replace(
                'fea100', 'fea{:03d}'.format(band)).replace(
                    'bd{:d}'.format(parameter_object.band_position), 'bd-rgb')

            raster_tools.warp(vrt_path,
                              band_image,
                              cell_size=parameter_object.sfs_resample,
                              resampleAlg='average',
                              warpMemoryLimit=256,
                              multithread=True,
                              creationOptions=[
                                  'COMPRESS=DEFLATE', 'BIGTIFF=YES',
                                  'TILED=YES'
                              ])

            os.remove(vrt_path)
예제 #4
0
    def _setup_out_infos(self, **kwargs):
        """
        Builds ``self.o_infos``, one output image object per entry of ``self.image_infos``

        Existing output files are re-opened with ``open2read=False``; missing ones are
        created from a copy of the input image information. Stale ``.ovr`` and
        ``.aux.xml`` sidecar files of the output are removed first.

        Args:
            **kwargs: Forwarded to ``raster_tools.create_raster`` for newly created outputs.
        """

        has_out_dir = isinstance(self.out_dir, str)

        if has_out_dir and not os.path.isdir(self.out_dir):
            os.makedirs(self.out_dir)

        self.o_infos = list()

        for image_info in self.image_infos:

            out_folder, file_name = os.path.split(image_info.file_name)
            file_base, file_ext = os.path.splitext(file_name)

            # Without a string output directory, write next to the input image.
            if has_out_dir:
                out_folder = self.out_dir

            out_name = os.path.join(out_folder,
                                    '{}_hmm{}'.format(file_base, file_ext))

            for sidecar_ext in ('.ovr', '.aux.xml'):

                sidecar = out_name + sidecar_ext

                if os.path.isfile(sidecar):
                    os.remove(sidecar)

            if os.path.isfile(out_name):
                out_info = raster_tools.ropen(out_name, open2read=False)
            else:

                template_info = image_info.copy()

                if self.assign_class:
                    template_info.update_info(storage='byte', bands=1)

                out_info = raster_tools.create_raster(out_name, template_info,
                                                      **kwargs)

            self.o_infos.append(out_info)

        # The block log sits in the folder used for the last image.
        # NOTE(review): `out_folder` is unbound if `self.image_infos` is empty.
        self.out_blocks = os.path.join(out_folder, 'hmm_BLOCK.txt')
예제 #5
0
파일: spprocess.py 프로젝트: adbeda/spfeas
def _section_chunk_bounds(n_sects, n_jobs):
    """
    Splits the 1-based section counters ``1..n_sects`` into consecutive chunks

    Args:
        n_sects (int): The total number of sections.
        n_jobs (int): The maximum number of sections per chunk. Must be >= 1.

    Returns:
        A list of ``(start, stop)`` tuples, where ``stop`` is exclusive. Every
        section counter appears in exactly one chunk.
    """

    return [(start, min(start + n_jobs, n_sects + 1))
            for start in range(1, n_sects + 1, n_jobs)]


def run(parameter_object):
    """
    Processes the input image for every feature trigger and band, section by section,
    tracks the per-tile status in the status file and, when no tile is
    corrupt or incomplete, builds the VRT mosaic

    Args:
        parameter_object (object): The run parameters. Among others, this function reads
            ``n_jobs``, ``stack_only``, ``status_file``, ``section_size``, ``triggers``,
            ``band_positions``, ``band_info``, ``out_bands_dict``, ``n_bands``,
            ``input_image`` and ``overviews``. The object is updated in place
            (``n_jobs``, ``remove_files``, ``band_counter`` and the values set
            through ``update_info`` and the ``sputilities`` helpers).

    Raises:
        AttributeError: If an existing status file did not load as a dictionary.
        CorruptedBandsError: If the input image reports corrupted bands.
    """

    # `_section_read_write` reads these module-level names.
    global potsi, param_dict

    if parameter_object.n_jobs == 0:
        parameter_object.n_jobs = 1
    elif parameter_object.n_jobs < 0:
        parameter_object.n_jobs = multi.cpu_count()
    elif parameter_object.n_jobs > multi.cpu_count():
        parameter_object.n_jobs = multi.cpu_count()

    sputilities.parameter_checks(parameter_object)

    # Write the parameters to file.
    sputilities.write_log(parameter_object)

    if parameter_object.stack_only:

        new_feas_list = list()

        # If prompted, stack features without processing.
        parameter_object = sputilities.stack_features(parameter_object,
                                                      new_feas_list)

    else:

        # Create the status object.
        mts = sputilities.ManageStatus()

        parameter_object.remove_files = False

        # Setup the status dictionary.
        if os.path.isfile(parameter_object.status_file):

            mts.load_status(parameter_object.status_file)

            # Validate the loaded status before indexing it, so that a
            # badly stored file produces the intended error message.
            if not isinstance(mts.status_dict, dict):

                logger.error(
                    'The YAML file already existed, but was not properly stored and saved.\nPlease remove and re-run.'
                )
                raise AttributeError

            if parameter_object.section_size != mts.status_dict['SECTION_SIZE']:

                logger.warning(
                    'The section size was changed, so all existing tiled images will be removed.'
                )

                parameter_object.remove_files = True

        else:

            mts.status_dict = dict()

            mts.status_dict['ALL_FINISHED'] = 'no'
            mts.status_dict['BAND_ORDER'] = dict()

            # Save the band order.
            for trigger in parameter_object.triggers:

                mts.status_dict['BAND_ORDER']['{}'.format(
                    trigger)] = '{:d}-{:d}'.format(
                        parameter_object.band_info[trigger] + 1,
                        parameter_object.band_info[trigger] +
                        parameter_object.out_bands_dict[trigger] *
                        parameter_object.n_bands)

            mts.status_dict['SECTION_SIZE'] = parameter_object.section_size

            mts.dump_status(parameter_object.status_file)

        # Skip the processing when the status file marks the image as done.
        process_image = mts.status_dict.get('ALL_FINISHED') != 'yes'

        # Set the output features folder.
        parameter_object = sputilities.set_feas_dir(parameter_object)

        if parameter_object.remove_files:

            for im_ in fnmatch.filter(os.listdir(parameter_object.feas_dir),
                                      '*.tif'):
                os.remove(os.path.join(parameter_object.feas_dir, im_))

        if not process_image:
            logger.warning(
                'The input image, {}, is set as finished processing.'.format(
                    parameter_object.input_image))
        else:

            original_band_positions = copy.copy(
                parameter_object.band_positions)

            # Iterate over each feature trigger.
            for trigger in parameter_object.triggers:

                parameter_object.update_info(
                    trigger=trigger,
                    band_positions=original_band_positions,
                    band_counter=0)

                # Iterate over each band
                for band_position in parameter_object.band_positions:

                    parameter_object.update_info(band_position=band_position)

                    # Get the input image information.
                    with raster_tools.ropen(
                            parameter_object.input_image) as i_info:

                        # Check if any of the input
                        #   bands are corrupted.
                        i_info.check_corrupted_bands()

                        if i_info.corrupted_bands:

                            logger.error(
                                '\nThe following bands appear to be corrupted:\n{}'
                                .format(', '.join(i_info.corrupted_bands)))
                            raise CorruptedBandsError

                        # Get image statistics.
                        parameter_object = sputilities.get_stats(
                            i_info, parameter_object)

                        # Get the section size.
                        parameter_object = sputilities.get_section_size(
                            i_info, parameter_object)

                        # Get the number of sections in
                        #   the image (only used as a counter).
                        parameter_object = sputilities.get_n_sects(
                            i_info, parameter_object)

                        if parameter_object.trigger == 'saliency':

                            bp = raster_tools.BlockFunc(
                                get_saliency_tile_mean, [i_info],
                                None,
                                None,
                                band_list=[[1, 2, 3]],
                                d_types=['float32'],
                                write_array=False,
                                close_files=False,
                                be_quiet=True,
                                print_statement=
                                '\nGetting tile lab means for saliency',
                                out_attributes=['lab_means'],
                                block_rows=parameter_object.sect_row_size,
                                block_cols=parameter_object.sect_col_size,
                                min_max=[(parameter_object.image_min,
                                          parameter_object.image_max)] * 3,
                                vis_order=parameter_object.vis_order)

                            bp.run()

                            parameter_object.update_info(lab_means=np.array(
                                bp.lab_means, dtype='float32').mean(axis=0))

                    del i_info

                    # Register every section of this trigger/band
                    # as 'unprocessed' before any work starts.
                    mts = sputilities.ManageStatus()
                    mts.load_status(parameter_object.status_file)

                    for sect_counter in range(1, parameter_object.n_sects + 1):

                        parameter_object.update_info(
                            section_counter=sect_counter)
                        parameter_object = sputilities.scale_fea_check(
                            parameter_object)

                        if trigger == parameter_object.triggers[0]:
                            mts.status_dict[
                                parameter_object.out_img_base] = dict()

                        mts.status_dict[parameter_object.out_img_base][
                            '{TR}-{BD}'.format(TR=parameter_object.trigger,
                                               BD=parameter_object.
                                               band_position)] = 'unprocessed'

                    mts.dump_status(parameter_object.status_file)

                    param_dict = sputilities.class2dict(parameter_object)
                    potsi = parameter_object.section_idx_pairs

                    # PROCESS IN CHUNKS
                    #
                    # NOTE: the sections of a chunk are processed serially
                    # with the builtin `map`, so no worker pool is created.

                    parallel_chunk_counter = 1

                    for parallel_chunk, parallel_chunk_end in _section_chunk_bounds(
                            parameter_object.n_sects, parameter_object.n_jobs):

                        results = map(
                            _section_read_write,
                            range(parallel_chunk, parallel_chunk_end))

                        logger.info('  Updating status ...')

                        for result in results:

                            parameter_object.update_info(
                                section_counter=parallel_chunk_counter)
                            parameter_object = sputilities.scale_fea_check(
                                parameter_object)

                            # Open the status YAML file.
                            mts = sputilities.ManageStatus()

                            # Load the status dictionary
                            mts.load_status(parameter_object.status_file)

                            if parameter_object.out_img_base in mts.status_dict:

                                # A truthy result flags a corrupt tile.
                                mts.status_dict[
                                    parameter_object.out_img_base][
                                        '{TR}-{BD}'.format(
                                            TR=parameter_object.trigger,
                                            BD=parameter_object.band_position
                                        )] = 'corrupt' if result else 'complete'

                            mts.dump_status(parameter_object.status_file)

                            parallel_chunk_counter += 1

                    parameter_object.band_counter += parameter_object.out_bands_dict[
                        parameter_object.trigger]

        # Check the corruption status.
        mts.load_status(parameter_object.status_file)

        n_corrupt = 0

        # Only the nested per-tile dictionaries hold tile statuses.
        for v in mts.status_dict.values():

            if isinstance(v, dict):

                for vsub in v.values():

                    if vsub in ['corrupt', 'incomplete']:
                        n_corrupt += 1

        if n_corrupt == 0:

            mts.status_dict['ALL_FINISHED'] = 'yes'
            mts.dump_status(parameter_object.status_file)

            # Finally, mosaic the image tiles.

            logger.info('  Creating the VRT mosaic ...')

            comp_dict = dict()

            # Get the image list.
            parameter_object = sputilities.scale_fea_check(parameter_object,
                                                           is_image=False)

            image_list = fnmatch.filter(os.listdir(parameter_object.feas_dir),
                                        parameter_object.search_wildcard)
            image_list = [
                os.path.join(parameter_object.feas_dir, im)
                for im in image_list
            ]

            comp_dict['001'] = image_list

            vrt_mosaic = parameter_object.status_file.replace('.yaml', '.vrt')

            vrt_builder(comp_dict,
                        vrt_mosaic,
                        force_type='float32',
                        be_quiet=True,
                        overwrite=True)

            if parameter_object.overviews:

                logger.info('\nBuilding VRT overviews ...')

                with raster_tools.ropen(vrt_mosaic,
                                        open2read=False) as vrt_info:

                    vrt_info.remove_overviews()
                    vrt_info.build_overviews(levels=[2, 4, 8, 16])

                del vrt_info

        else:

            if n_corrupt == 1:
                logger.warning(
                    '\nThere was {:d} corrupt or incomplete tile.\nRe-run the command with the same parameters.'
                    .format(n_corrupt))
            else:
                logger.warning(
                    '\nThere were {:d} corrupt or incomplete tiles.\nRe-run the command with the same parameters.'
                    .format(n_corrupt))
예제 #6
0
파일: spprocess.py 프로젝트: adbeda/spfeas
def _write_section2file(this_parameter_object__, meta_info, section2write,
                        i_sect, j_sect, out_rows, out_cols, section_counter):
    """
    Writes the feature bands of one section to the output tile and checks the tile for corruption

    Args:
        this_parameter_object__ (class): The section parameters (``out_img``, ``trigger``,
            ``band_info``, ``band_counter``, ``out_bands_dict`` and ``n_sects`` are read here).
        meta_info (`rinfo` object): The input image information, copied for the output tile.
        section2write (ndarray): The layers to write, indexed by layer on the first axis. Any
            other type is replaced by a zero-filled ``uint8`` array of the output tile shape.
        i_sect (int): Forwarded to ``sputilities.get_output_info_tile``.
        j_sect (int): Forwarded to ``sputilities.get_output_info_tile``.
        out_rows (int): Forwarded to ``sputilities.get_output_info_tile``.
        out_cols (int): Forwarded to ``sputilities.get_output_info_tile``.
        section_counter (int): The section number, only used in the log message.

    Returns:
        ``True`` if the written tile reports corrupted bands, otherwise ``False``.
    """

    params = this_parameter_object__

    logger.info('  Writing section {:d} of {:d} to file ...'.format(
        section_counter, params.n_sects))

    o_info = sputilities.get_output_info_tile(meta_info, meta_info.copy(),
                                              params, i_sect, j_sect,
                                              out_rows, out_cols)

    if not isinstance(section2write, np.ndarray):

        section2write = np.zeros((o_info.bands, o_info.rows, o_info.cols),
                                 dtype='uint8')

    # First output band of this trigger, and the number of bands it produces.
    first_band = params.band_info[params.trigger] + params.band_counter + 1
    band_count = params.out_bands_dict[params.trigger]

    first_layer = section2write[0]

    # Zero-length sections are not written to file.
    if first_layer.shape[0] != 0 and first_layer.shape[1] != 0:

        if os.path.isfile(params.out_img):

            # Open the existing file to add the new bands.
            raster_context = raster_tools.ropen(params.out_img,
                                                open2read=False)

        else:

            # Create the output raster.
            raster_context = raster_tools.create_raster(params.out_img,
                                                        o_info,
                                                        bigtiff='yes')

        with raster_context as out_raster:

            # Write each scale and feature.
            for layer_index, feature_band in enumerate(
                    range(first_band, first_band + band_count)):

                out_raster.write_array(section2write[layer_index],
                                       band=feature_band)
                out_raster.close_band()

        del out_raster

    is_corrupt = False

    # The tile won't be written to file
    #   in the case of zero-length sections.
    if os.path.isfile(params.out_img):

        # Check if any of the bands are corrupted.
        with raster_tools.ropen(params.out_img) as ob_info:

            ob_info.check_corrupted_bands()
            is_corrupt = bool(ob_info.corrupted_bands)

        del ob_info

    return is_corrupt
예제 #7
0
파일: spprocess.py 프로젝트: adbeda/spfeas
def _section_read_write(section_counter):
    """
    Handles the section reading and writing

    Reads one section of the input image, computes the features of the
    current trigger and writes them to the section's output tile.

    The parameters are not passed in directly: they are rebuilt from the
    module-level ``param_dict``, and the section offsets come from the
    module-level ``potsi`` (both are assigned by ``run``).

    Args:
        section_counter (int): The 1-based section number.

    Returns:
        The corruption flag returned by ``_write_section2file``, or ``None``
        when the output tile is already finished and ``overwrite`` is not set.
    """

    # `section_counter` is 1-based, the pair list is 0-based.
    section_pair = potsi[section_counter - 1]

    # this_parameter_object_ = this_parameter_object.copy()
    # Work on a per-section copy of the parameters.
    this_parameter_object_ = copy.copy(param_dict)
    this_parameter_object_ = sputilities.dict2class(this_parameter_object_)

    # Get the input image information.
    with raster_tools.ropen(
            this_parameter_object_.input_image) as this_image_info:

        this_parameter_object_.update_info(section_counter=section_counter)

        # Set the output name.
        this_parameter_object_ = sputilities.scale_fea_check(
            this_parameter_object_)

        # Open the status YAML file.
        mts_ = sputilities.ManageStatus()

        # Load the status dictionary
        mts_.load_status(this_parameter_object_.status_file)

        # Check file status.
        if os.path.isfile(this_parameter_object_.out_img):

            if this_parameter_object_.out_img_base in mts_.status_dict:

                # NOTE(review): the status keys written below (and in `run`)
                # have the form '{TR}-{BD}', while this test looks up the bare
                # trigger name -- confirm that this branch is reachable.
                if this_parameter_object_.trigger in mts_.status_dict[
                        this_parameter_object_.out_img_base]:

                    # Check every trigger because the
                    #   entire file needs to be removed.
                    status_list = [
                        mts_.status_dict[this_parameter_object_.out_img_base]
                        ['{TR}-{BD}'.format(
                            TR=tr, BD=this_parameter_object_.band_position)]
                        for tr in this_parameter_object_.triggers
                    ]

                    if 'corrupt' in status_list:

                        logger.info('Re-running {} ...'.format(
                            this_parameter_object_.out_img))

                        # Remove the file on the first trigger
                        #   if the file is corrupt.
                        if this_parameter_object_.trigger == this_parameter_object_.triggers[
                                0]:
                            os.remove(this_parameter_object_.out_img)

                        mts_.status_dict[this_parameter_object_.out_img_base][
                            '{TR}-{BD}'.format(
                                TR=this_parameter_object_.trigger,
                                BD=this_parameter_object_.band_position
                            )] = 'incomplete'
                        mts_.dump_status(this_parameter_object_.status_file)

                    elif ('corrupt' not in status_list) and ('incomplete'
                                                             in status_list):

                        # Nothing is removed here: the tile is re-processed
                        # and written into the existing file.
                        logger.info('Re-running {} ...'.format(
                            this_parameter_object_.out_img))

                    else:

                        if this_parameter_object_.overwrite:

                            logger.info('Re-running {} ...'.format(
                                this_parameter_object_.out_img))

                            # Remove the file on the first trigger.
                            if this_parameter_object_.trigger == this_parameter_object_.triggers[
                                    0]:
                                os.remove(this_parameter_object_.out_img)

                            mts_.status_dict[
                                this_parameter_object_.out_img_base][
                                    '{TR}-{BD}'.format(
                                        TR=this_parameter_object_.trigger,
                                        BD=this_parameter_object_.band_position
                                    )] = 'incomplete'
                            mts_.dump_status(
                                this_parameter_object_.status_file)

                        else:

                            # Returns None, which `run` treats as not corrupt.
                            logger.info('{} is already finished ...'.format(
                                this_parameter_object_.out_img))
                            return

            else:

                # The file exists, but the status file has no entry for it.
                # Remove the file on the first trigger.
                if this_parameter_object_.trigger == this_parameter_object_.triggers[
                        0]:
                    os.remove(this_parameter_object_.out_img)

                logger.info('Re-running {} ...'.format(
                    this_parameter_object_.out_img))

        i_sect = section_pair[0]
        j_sect = section_pair[1]

        # Row and column section bounds checking
        n_rows = raster_tools.n_rows_cols(i_sect,
                                          this_parameter_object_.sect_row_size,
                                          this_image_info.rows)

        n_cols = raster_tools.n_rows_cols(j_sect,
                                          this_parameter_object_.sect_col_size,
                                          this_image_info.cols)

        # Open the image array.
        # The way the section is read depends on the trigger.
        if this_parameter_object_.trigger.upper(
        ) in this_parameter_object_.spectral_indices:

            wavelengths = utils.VI_WAVELENGTHS[
                this_parameter_object_.trigger.upper()]

            # Check if the sensor supports the spectral index
            utils.sensor_wavelength_check(this_parameter_object_.sat_sensor,
                                          wavelengths)

            # Get the band positions needed
            #   to process the spectral index.
            spectral_bands = utils.get_index_bands(
                this_parameter_object_.trigger,
                this_parameter_object_.sat_sensor)

            sect_in = this_image_info.read(bands2open=spectral_bands,
                                           i=i_sect,
                                           j=j_sect,
                                           rows=n_rows,
                                           cols=n_cols,
                                           d_type='float32')

            # Clip at the image maximum and divide by it.
            sect_in[sect_in >= this_parameter_object_.
                    image_max] = this_parameter_object_.image_max
            sect_in /= this_parameter_object_.image_max

            vie = VegIndicesEquations(sect_in, chunk_size=-1)
            sect_in = vie.compute(this_parameter_object_.trigger.upper(),
                                  out_type=1)

            this_parameter_object_.update_info(image_min=0, image_max=1)

        elif this_parameter_object_.trigger == 'saliency':

            sect_in = saliency(this_image_info, this_parameter_object_, i_sect,
                               j_sect, n_rows, n_cols)

            this_parameter_object_.update_info(image_min=0, image_max=255)

        elif this_parameter_object_.trigger == 'seg':

            sect_in = this_image_info.read(bands2open=[1, 2, 3],
                                           i=i_sect,
                                           j=j_sect,
                                           rows=n_rows,
                                           cols=n_cols)

            sect_in = segment_image(sect_in, this_parameter_object_)

        elif this_parameter_object_.trigger == 'grad':

            if this_image_info.bands >= 3:

                sect_in = sputilities.convert_rgb2gray(
                    this_image_info, i_sect, j_sect, n_rows, n_cols,
                    this_parameter_object_.sat_sensor)[0]

            else:

                sect_in = this_image_info.read(
                    bands2open=this_parameter_object_.band_position,
                    i=i_sect,
                    j=j_sect,
                    rows=n_rows,
                    cols=n_cols)

            sect_in = get_mag_avg(sect_in)

            this_parameter_object_.update_info(image_min=0, image_max=30)

        elif this_parameter_object_.use_rgb and this_parameter_object_.trigger \
                not in this_parameter_object_.spectral_indices + ['grad', 'saliency', 'seg']:

            sect_in = sputilities.convert_rgb2gray(
                this_image_info, i_sect, j_sect, n_rows, n_cols,
                this_parameter_object_.sat_sensor)[0]

        else:

            # Default: read the single band at `band_position`.
            sect_in = this_image_info.read(
                bands2open=this_parameter_object_.band_position,
                i=i_sect,
                j=j_sect,
                rows=n_rows,
                cols=n_cols)

        # The following triggers transform the section that was read above.
        if this_parameter_object_.trigger == 'dmp':

            # The Differential Morphological Profile
            #   is a [D x M x N] array
            # where,
            #   D = the opening/closing derivative.
            sect_in = get_dmp(sect_in, this_parameter_object_.image_min,
                              this_parameter_object_.image_max)

        if this_parameter_object_.trigger == 'gabor':

            sect_in = convolve_gabor(sect_in, this_parameter_object_.image_min,
                                     this_parameter_object_.image_max,
                                     this_parameter_object_.scales)

        if this_parameter_object_.trigger == 'orb':

            sect_in = get_orb_keypoints(sect_in,
                                        this_parameter_object_.image_min,
                                        this_parameter_object_.image_max)

        this_parameter_object_.update_info(i_sect_blk_ctr=1, j_sect_blk_ctr=1)

        # For 'dmp' and 'gabor' the rows and columns are taken from the first
        # layer; otherwise `sect_in.shape` must unpack into exactly two values.
        if this_parameter_object_.trigger in ['dmp', 'gabor']:
            l_rows, l_cols = sect_in[0].shape
        else:
            l_rows, l_cols = sect_in.shape

        # Compute section statistics.
        section_stats_array = spsplit.get_section_stats(
            sect_in, l_rows, l_cols, this_parameter_object_, section_counter)

        # Get the section output rows and columns.
        out_rows, out_cols = spsplit.get_out_dims(l_rows, l_cols,
                                                  this_parameter_object_)

        # Reshape the list of features into
        #   <features x rows x columns> array.
        out_section_array = spreshape.reshape_feature_list(
            section_stats_array, out_rows, out_cols, this_parameter_object_)

        is_corrupt = _write_section2file(this_parameter_object_,
                                         this_image_info, out_section_array,
                                         i_sect, j_sect, out_rows, out_cols,
                                         section_counter)

    this_parameter_object_ = None
    # NOTE(review): this binds a new name, `this_image_info_` (trailing
    # underscore), not the `this_image_info` opened above -- likely a typo.
    this_image_info_ = None

    return is_corrupt
예제 #8
0
    def fit_predict(self, lc_probabilities):
        """
        Fits a Hidden Markov Model

        The dimensions are taken from the first image, the transition matrix
        and the output images are set up, and the images are then processed
        block by block.

        Args:
            lc_probabilities (str list): A list of image class conditional probabilities. Each image in the list
                should be shaped [layers x rows x columns], where layers are equal to the number of land cover
                classes.

        Raises:
            ValueError: If ``lc_probabilities`` is empty.
            TypeError: If the band, row or column count of the first image is not an ``int``.
        """

        if not lc_probabilities:

            logger.error('The `fit` method cannot be executed without data.')
            raise ValueError

        if MKL_INSTALLED:
            mkl_rt.MKL_Set_Num_Threads(self.n_jobs)

        self.lc_probabilities = lc_probabilities
        self.n_steps = len(self.lc_probabilities)

        # Get image information.
        with raster_tools.ropen(self.lc_probabilities[0]) as first_info:

            self.n_labels = first_info.bands
            self.rows = first_info.rows
            self.cols = first_info.cols

        # Drop the reference before the long-running block processing.
        first_info = None

        dimension_checks = (
            (self.n_labels,
             'The number of layers was not properly extracted from the image set.'),
            (self.rows,
             'The number of rows was not properly extracted from the image set.'),
            (self.cols,
             'The number of columns was not properly extracted from the image set.'),
        )

        for dimension, message in dimension_checks:

            if not isinstance(dimension, int):

                logger.error(message)
                raise TypeError

        # Setup the transition matrix.
        self._transition_matrix()

        self.methods = {
            'forward-backward': forward_backward,
            'viterbi': viterbi
        }

        # Open the images.
        self.image_infos = [
            raster_tools.ropen(image) for image in self.lc_probabilities
        ]

        self._setup_out_infos(**self.kwargs)

        # Iterate over the image block by block.
        self._block_func()
예제 #9
0
def _prefix_image_names(equation, names, prefix='marrvar_'):
    """
    Prefixes every whole-word occurrence of an image name in an equation

    The substitution is done in a single pass, so text inserted for one
    name is never rescanned for another name, and names embedded in longer
    words (e.g., the 'e' in 'where') are left alone.

    Args:
        equation (str): The equation to rewrite.
        names (iterable of str): The image names used in the equation.
        prefix (Optional[str]): The text to prepend to each name. Default is 'marrvar_'.

    Returns:
        The rewritten equation (str).
    """

    import re

    ordered_names = sorted(names, key=len, reverse=True)

    if not ordered_names:
        return equation

    pattern = re.compile(r'\b(?:{})\b'.format(
        '|'.join(re.escape(name) for name in ordered_names)))

    return pattern.sub(lambda match: prefix + match.group(0), equation)


def _evaluate_expression(expression, arrays):
    """
    Evaluates one equation against a dictionary of named arrays

    Args:
        expression (str): The equation, with image names already prefixed.
        arrays (dict): The arrays referenced by the equation, keyed by prefixed name.

    Returns:
        The result of the equation.
    """

    if 'nan_to_num' in expression:

        # NOTE(security): the equation is evaluated as Python code with the
        # module globals in scope. Never pass an equation from an untrusted source.
        return eval(expression, globals(), arrays)

    return ne.evaluate(expression, local_dict=arrays)


def raster_calc(output,
                equation=None,
                out_type='byte',
                extent=None,
                overwrite=False,
                be_quiet=False,
                out_no_data=0,
                row_block_size=2000,
                col_block_size=2000,
                apply_all_bands=False,
                **kwargs):
    """
    Raster calculator

    Args:
        output (str): The output image.
        equation (Optional[str]): The equation to calculate. Several equations can be separated
            by '&&', in which case the result of each equation is stored in consecutive output layers.
            Equations that contain 'nan_to_num' are evaluated as Python code, all others with ``numexpr``.
        out_type (Optional[str]): The output raster storage type. Default is 'byte'.
        extent (Optional[str]): An image or instance of ``mappy.ropen`` to use for the output extent. Default is None.
        overwrite (Optional[bool]): Whether to overwrite an existing output image. Default is False.
        be_quiet (Optional[bool]): Whether to be quiet and do not report progress. Default is False.
        out_no_data (Optional[int]): The output no data value. Default is 0.
        row_block_size (Optional[int]): The row block chunk size. Default is 2000.
        col_block_size (Optional[int]): The column block chunk size. Default is 2000.
        apply_all_bands (Optional[bool]): Whether to apply the equation to all bands. Default is False.
        **kwargs (str): The rasters to compute. E.g., A='/some_raster1.tif', F='/some_raster2.tif'.
            Band positions default to 1 unless given as [A]_band.

    Examples:
        >>> from mpglue.raster_calc import raster_calc
        >>>
        >>> # Multiply image A x image B
        >>> raster_calc('/output.tif',
        ...             equation='A * B',
        ...             A='/some_raster1.tif',
        ...             B='some_raster2.tif')
        >>>
        >>> # Reads as...
        >>> # Where image A equals 1 AND image B is greater than 5,
        >>> #   THEN write image A, OTHERWISE write 0
        >>> raster_calc('/output.tif',
        ...             equation='where((A == 1) & (B > 5), A, 0)',
        ...             A='/some_raster1.tif',
        ...             B='some_raster2.tif')
        >>>
        >>> # Use different bands from the same image. The letter given for the
        >>> #   image must be the same for the band, followed by _band.
        >>> # E.g., for raster 'n', the corresponding band would be 'n_band'. For
        >>> #   raster 'r', the corresponding band would be 'r_band', etc.
        >>> raster_calc('/output.tif',
        ...             equation='(n - r) / (n + r)',
        ...             n='/some_raster.tif',
        ...             n_band=4,
        ...             r='/some_raster.tif',
        ...             r_band=3)

    Raises:
        ValueError: If ``equation`` is missing or no raster is given as a keyword argument.

    Returns:
        None, writes to ``output``.
    """

    # Fail early, before any file is opened or created.
    if equation is None:
        raise ValueError('An `equation` must be given.')

    if not any(isinstance(vw, str) for vw in kwargs.values()):
        raise ValueError('At least one raster must be given as a keyword argument.')

    info_dict = dict()
    info_list = list()
    band_dict = dict()
    temp_files = list()

    extent_info = None
    out_rst = None

    try:

        if isinstance(extent, str):

            # Opened once and reused below as the output extent.
            extent_info = raster_tools.ropen(extent)

            subset_kwargs = dict(kwargs)

            for kw, vw in kwargs.items():

                if isinstance(vw, str):

                    d_name, f_name = os.path.split(vw)
                    f_base, __ = os.path.splitext(f_name)

                    vw_sub = os.path.join(d_name, '{}_temp.vrt'.format(f_base))

                    # Subset each input to the extent image with a temporary VRT.
                    raster_tools.translate(vw,
                                           vw_sub,
                                           format='VRT',
                                           projWin=[
                                               extent_info.left,
                                               extent_info.top,
                                               extent_info.right,
                                               extent_info.bottom
                                           ])

                    temp_files.append(vw_sub)

                    subset_kwargs[kw] = vw_sub

            kwargs = subset_kwargs

        # Strings are rasters to open, integers are band positions.
        for kw, vw in kwargs.items():

            if isinstance(vw, str):

                i_info = raster_tools.ropen(vw)

                info_dict[kw] = i_info
                info_list.append(i_info)

            elif isinstance(vw, int):
                band_dict[kw] = vw

        equation = _prefix_image_names(equation, list(info_dict))

        # The output image information starts as a copy of the first input.
        o_info = copy(info_list[0])

        n_bands = o_info.bands if apply_all_bands else 1

        if isinstance(extent, raster_tools.ropen):

            # Set the extent from an object.
            overlap_info = extent

        elif extent_info is not None:

            # Set the extent from an existing image.
            overlap_info = extent_info

        else:

            # Get the minimum overlapping extent
            # from all input images.
            overlap_info = info_list[0].copy()

            for other_info in info_list[1:]:

                overlap_info = raster_tools.GetMinExtent(overlap_info,
                                                         other_info)

        o_info.update_info(left=overlap_info.left,
                           right=overlap_info.right,
                           top=overlap_info.top,
                           bottom=overlap_info.bottom,
                           rows=overlap_info.rows,
                           cols=overlap_info.cols,
                           storage=out_type,
                           bands=n_bands)

        if overwrite:
            overwrite_file(output)

        out_rst = raster_tools.create_raster(output, o_info)

        if n_bands == 1:
            out_rst.get_band(1)

        block_rows, block_cols = raster_tools.block_dimensions(
            o_info.rows,
            o_info.cols,
            row_block_size=row_block_size,
            col_block_size=col_block_size)

        if not be_quiet:
            ctr, pbar = _iteration_parameters(o_info.rows, o_info.cols,
                                              block_rows, block_cols)

        # The offset of each image relative to the output origin is computed
        # from the image and the output's upper-left corner only, so it is
        # computed once rather than for every block.
        offsets = dict()

        for key, i_info in info_dict.items():

            offsets[key] = vector_tools.get_xy_offsets(
                image_info=i_info,
                x=overlap_info.left,
                y=overlap_info.top,
                check_position=False)[2:]

        # Iterate over the minimum overlapping extent.
        for i in range(0, o_info.rows, block_rows):

            n_rows = raster_tools.n_rows_cols(i, block_rows, o_info.rows)

            for j in range(0, o_info.cols, block_cols):

                n_cols = raster_tools.n_rows_cols(j, block_cols, o_info.cols)

                # Read the block of each image into a dictionary
                # keyed by the name used in the equation.
                arrays = dict()

                for key, i_info in info_dict.items():

                    x_off, y_off = offsets[key]

                    arrays['marrvar_{}'.format(key)] = i_info.read(
                        bands2open=band_dict.get('{}_band'.format(key), 1),
                        i=i + y_off,
                        j=j + x_off,
                        rows=n_rows,
                        cols=n_cols,
                        d_type='float32')

                if '&&' in equation:

                    out_array = np.empty((n_bands, n_rows, n_cols),
                                         dtype='float32')

                    for eqidx, equation_ in enumerate(equation.split('&&')):

                        # Surrounding blanks would otherwise defeat the
                        # 'np.' prefix check.
                        equation_ = equation_.strip()

                        if 'nan_to_num' in equation_ and not equation_.startswith('np.'):
                            equation_ = 'np.' + equation_

                        out_array[eqidx] = _evaluate_expression(equation_,
                                                                arrays)

                else:
                    out_array = _evaluate_expression(equation, arrays)

                # Set the output no data values.
                out_array[~np.isfinite(out_array)] = out_no_data

                if n_bands == 1:

                    out_rst.write_array(out_array, i=i, j=j)

                else:

                    for lidx in range(0, n_bands):

                        out_rst.write_array(out_array[lidx],
                                            i=i,
                                            j=j,
                                            band=lidx + 1)

                if not be_quiet:

                    pbar.update(ctr)
                    ctr += 1

        if not be_quiet:
            pbar.finish()

    finally:

        # Close the input images.
        for i_info in info_list:
            i_info.close()

        # Close the output drivers.
        if out_rst is not None:
            out_rst.close_all()

        out_rst = None

        # Only close the extent image if it was opened here.
        if extent_info is not None:
            extent_info.close()

        # Cleanup
        for temp_file in temp_files:

            if os.path.isfile(temp_file):
                os.remove(temp_file)