Beispiel #1
0
def run(parameter_object):
    """
    Compute (or only stack) features for an image, tracking per-section
    progress in a YAML status file so that an interrupted run can be resumed.

    Args:
        parameter_object: The parameter container. Attributes (with the
            defaults listed by the original interface):
            input_image, output_dir, band_positions=[1], use_rgb=False, block=2, scales=[8], triggers=['mean'],
            threshold=20, min_len=10, line_gap=2, weighted=False, sfs_thresh=80, resamp_sfs=0.,
            equalize=False, equalize_adapt=False, smooth=0, visualize=False, convert_stk=False, gdal_cache=256,
            do_pca=False, stack_feas=True, stack_only=False, neighbors=False, n_jobs=-1,
            reset_sects=False, image_max=0, lac_r=2, section_size=8000, chunk_size=512

    Behaviour:
        * ``n_jobs`` is normalised: 0 -> 1; negative or larger than the CPU
          count -> the CPU count.
        * If ``stack_only`` is set, existing features are stacked and nothing
          else happens.
        * Otherwise every trigger/band is processed section by section. Each
          section is recorded in the status file as 'corrupt' when
          ``_section_read_write`` returns a truthy value, else 'complete'.
        * If no section is 'corrupt' or 'incomplete', the status is marked
          finished and the tiles are mosaicked into a VRT whose name is the
          status file name with '.yaml' replaced by '.vrt' (with overviews
          when ``parameter_object.overviews`` is set).

    Raises:
        AttributeError: If an existing status file did not load as a dict.
        CorruptedBandsError: If the input image has corrupted bands.
    """

    # Module-level state read by _section_read_write.
    global potsi, param_dict

    # Normalise the worker count: 0 -> 1; negative or too large -> all CPUs.
    if parameter_object.n_jobs == 0:
        parameter_object.n_jobs = 1
    elif parameter_object.n_jobs < 0 or parameter_object.n_jobs > multi.cpu_count():
        parameter_object.n_jobs = multi.cpu_count()

    sputilities.parameter_checks(parameter_object)

    # Write the parameters to file.
    sputilities.write_log(parameter_object)

    if parameter_object.stack_only:

        new_feas_list = list()

        # If prompted, stack features without processing.
        sputilities.stack_features(parameter_object, new_feas_list)

        return

    # Create the status object.
    mts = sputilities.ManageStatus()

    parameter_object.remove_files = False

    # Set up the status dictionary.
    if os.path.isfile(parameter_object.status_file):

        mts.load_status(parameter_object.status_file)

        # Validate before indexing into the loaded status; otherwise a
        # malformed file fails with an unrelated TypeError below.
        if not isinstance(mts.status_dict, dict):

            logger.error(
                'The YAML file already existed, but was not properly stored and saved.\nPlease remove and re-run.'
            )
            raise AttributeError

        if parameter_object.section_size != mts.status_dict['SECTION_SIZE']:

            logger.warning(
                'The section size was changed, so all existing tiled images will be removed.'
            )

            parameter_object.remove_files = True

            # The existing tiles are about to be deleted, so the image must be
            # reprocessed. Recording the new section size keeps the next run
            # with the same settings from deleting the new tiles again.
            mts.status_dict['SECTION_SIZE'] = parameter_object.section_size
            mts.status_dict['ALL_FINISHED'] = 'no'
            mts.dump_status(parameter_object.status_file)

    else:

        mts.status_dict = dict()

        mts.status_dict['ALL_FINISHED'] = 'no'
        mts.status_dict['BAND_ORDER'] = dict()

        # Save the band order (1-based, inclusive range per trigger).
        for trigger in parameter_object.triggers:

            first_band = parameter_object.band_info[trigger] + 1
            last_band = (parameter_object.band_info[trigger] +
                         parameter_object.out_bands_dict[trigger] *
                         parameter_object.n_bands)

            mts.status_dict['BAND_ORDER']['{}'.format(trigger)] = '{:d}-{:d}'.format(
                first_band, last_band)

        mts.status_dict['SECTION_SIZE'] = parameter_object.section_size

        mts.dump_status(parameter_object.status_file)

    process_image = mts.status_dict.get('ALL_FINISHED') != 'yes'

    # Set the output features folder.
    parameter_object = sputilities.set_feas_dir(parameter_object)

    if parameter_object.remove_files:

        for tif_name in fnmatch.filter(os.listdir(parameter_object.feas_dir), '*.tif'):
            os.remove(os.path.join(parameter_object.feas_dir, tif_name))

    if not process_image:

        logger.warning(
            'The input image, {}, is set as finished processing.'.format(
                parameter_object.input_image))

    else:

        original_band_positions = copy.copy(parameter_object.band_positions)

        # Iterate over each feature trigger.
        for trigger in parameter_object.triggers:

            parameter_object.update_info(
                trigger=trigger,
                band_positions=original_band_positions,
                band_counter=0)

            # Iterate over each band.
            for band_position in parameter_object.band_positions:

                parameter_object.update_info(band_position=band_position)

                # Get the input image information.
                with raster_tools.ropen(parameter_object.input_image) as i_info:

                    # Check if any of the input bands are corrupted.
                    i_info.check_corrupted_bands()

                    if i_info.corrupted_bands:

                        # str() each entry so the message can be built whatever
                        # type the band identifiers have.
                        logger.error(
                            '\nThe following bands appear to be corrupted:\n{}'
                            .format(', '.join(str(band) for band in i_info.corrupted_bands)))
                        raise CorruptedBandsError

                    # Get image statistics.
                    parameter_object = sputilities.get_stats(i_info, parameter_object)

                    # Get the section size.
                    parameter_object = sputilities.get_section_size(i_info, parameter_object)

                    # Get the number of sections in the image (only used as a counter).
                    parameter_object = sputilities.get_n_sects(i_info, parameter_object)

                    if parameter_object.trigger == 'saliency':

                        bp = raster_tools.BlockFunc(
                            get_saliency_tile_mean, [i_info],
                            None,
                            None,
                            band_list=[[1, 2, 3]],
                            d_types=['float32'],
                            write_array=False,
                            close_files=False,
                            be_quiet=True,
                            print_statement=
                            '\nGetting tile lab means for saliency',
                            out_attributes=['lab_means'],
                            block_rows=parameter_object.sect_row_size,
                            block_cols=parameter_object.sect_col_size,
                            min_max=[(parameter_object.image_min,
                                      parameter_object.image_max)] * 3,
                            vis_order=parameter_object.vis_order)

                        bp.run()

                        parameter_object.update_info(lab_means=np.array(
                            bp.lab_means, dtype='float32').mean(axis=0))

                del i_info

                # Mark every section of this trigger/band as unprocessed.
                mts = sputilities.ManageStatus()
                mts.load_status(parameter_object.status_file)

                for sect_counter in range(1, parameter_object.n_sects + 1):

                    parameter_object.update_info(section_counter=sect_counter)
                    parameter_object = sputilities.scale_fea_check(parameter_object)

                    if trigger == parameter_object.triggers[0]:
                        mts.status_dict[parameter_object.out_img_base] = dict()

                    mts.status_dict[parameter_object.out_img_base][
                        '{TR}-{BD}'.format(TR=parameter_object.trigger,
                                           BD=parameter_object.band_position)] = 'unprocessed'

                mts.dump_status(parameter_object.status_file)

                param_dict = sputilities.class2dict(parameter_object)
                potsi = parameter_object.section_idx_pairs

                # PROCESS IN CHUNKS OF n_jobs SECTIONS (section ids are 1-based).
                n_sects = parameter_object.n_sects
                n_jobs = parameter_object.n_jobs

                for chunk_start in range(1, n_sects + 1, n_jobs):

                    # The last chunk may be shorter than n_jobs. Capping at
                    # n_sects + 1 visits every section exactly once.
                    section_ids = range(chunk_start, min(chunk_start + n_jobs, n_sects + 1))

                    # NOTE: sections are processed serially here; no worker pool
                    # is created because none would be used.
                    results = [_section_read_write(section_id) for section_id in section_ids]

                    logger.info('  Updating status ...')

                    for section_id, result in zip(section_ids, results):

                        parameter_object.update_info(section_counter=section_id)
                        parameter_object = sputilities.scale_fea_check(parameter_object)

                        # Open the status YAML file and load the status dictionary.
                        mts = sputilities.ManageStatus()
                        mts.load_status(parameter_object.status_file)

                        if parameter_object.out_img_base in mts.status_dict:

                            status_key = '{TR}-{BD}'.format(
                                TR=parameter_object.trigger,
                                BD=parameter_object.band_position)

                            # A truthy result from _section_read_write marks the section as corrupt.
                            mts.status_dict[parameter_object.out_img_base][status_key] = (
                                'corrupt' if result else 'complete')

                        mts.dump_status(parameter_object.status_file)

                parameter_object.band_counter += parameter_object.out_bands_dict[
                    parameter_object.trigger]

    # Check the corruption status across every per-image status dictionary.
    mts.load_status(parameter_object.status_file)

    n_corrupt = sum(
        1
        for section_status in mts.status_dict.values()
        if isinstance(section_status, dict)
        for band_status in section_status.values()
        if band_status in ('corrupt', 'incomplete'))

    if n_corrupt == 0:

        mts.status_dict['ALL_FINISHED'] = 'yes'
        mts.dump_status(parameter_object.status_file)

        # Finally, mosaic the image tiles.
        logger.info('  Creating the VRT mosaic ...')

        # Get the image list.
        parameter_object = sputilities.scale_fea_check(parameter_object, is_image=False)

        image_list = [
            os.path.join(parameter_object.feas_dir, im)
            for im in fnmatch.filter(os.listdir(parameter_object.feas_dir),
                                     parameter_object.search_wildcard)
        ]

        comp_dict = {'001': image_list}

        vrt_mosaic = parameter_object.status_file.replace('.yaml', '.vrt')

        vrt_builder(comp_dict,
                    vrt_mosaic,
                    force_type='float32',
                    be_quiet=True,
                    overwrite=True)

        if parameter_object.overviews:

            logger.info('\nBuilding VRT overviews ...')

            with raster_tools.ropen(vrt_mosaic, open2read=False) as vrt_info:

                vrt_info.remove_overviews()
                vrt_info.build_overviews(levels=[2, 4, 8, 16])

            del vrt_info

    elif n_corrupt == 1:
        logger.warning(
            '\nThere was {:d} corrupt or incomplete tile.\nRe-run the command with the same parameters.'
            .format(n_corrupt))
    else:
        logger.warning(
            '\nThere were {:d} corrupt or incomplete tiles.\nRe-run the command with the same parameters.'
            .format(n_corrupt))
Beispiel #2
0
def stack_features(parameter_object, new_feas_list):
    """
    Stack the existing feature images into a single VRT.

    For every trigger, band position, scale and (1-based) feature index, the
    expected feature image path is resolved with ``scale_fea_check``. Paths
    that exist on disk are appended to ``new_feas_list`` in that loop order.
    The collected paths are then:

    * written, numbered from 1 under a 'Layer Name' header, to
      ``<f_base>.<triggers>.stk.bd<bands>.block<block>.scales<scales>_fea_list.txt``
      in ``output_dir``, and
    * stacked, one layer per path, into the matching ``.vrt`` file with
      ``vrt_builder``.

    Any existing file of either name is removed first.

    Args:
        parameter_object: The parameter container. It is updated in place via
            ``update_info`` and returned with ``out_vrt`` set.
        new_feas_list (list): Output list; existing feature paths are appended.

    Returns:
        The updated parameter object.
    """

    for trigger in parameter_object.triggers:

        parameter_object.update_info(trigger=trigger)

        # Set the output features folder.
        parameter_object = set_feas_dir(parameter_object)

        for band_p in parameter_object.band_positions:

            parameter_object.update_info(band_position=band_p)

            for scale in parameter_object.scales:

                parameter_object.update_info(scale=scale)

                # Feature indices are 1-based. ``range`` (rather than the
                # Python 2-only ``xrange``) keeps this portable.
                for feature in range(1, parameter_object.features_dict[trigger] + 1):

                    parameter_object.update_info(feature=feature)

                    parameter_object = scale_fea_check(parameter_object)

                    # Only keep features that actually exist on disk.
                    if os.path.isfile(parameter_object.out_img):
                        new_feas_list.append(parameter_object.out_img)

    scs_str = '-'.join(str(sc) for sc in parameter_object.scales)
    band_pos_str = '-'.join(str(bp) for bp in parameter_object.band_positions)

    # Shared name prefix of the feature-list text file and the stacked VRT.
    out_base = '{}.{}.stk.bd{}.block{}.scales{}'.format(
        parameter_object.f_base, '-'.join(parameter_object.triggers),
        band_pos_str, parameter_object.block, scs_str)

    # Write the band list to text.
    fea_list_txt = os.path.join(parameter_object.output_dir, out_base + '_fea_list.txt')

    if os.path.isfile(fea_list_txt):
        os.remove(fea_list_txt)

    # Text mode: the lines are str, which a binary ('wb') handle rejects on Python 3.
    with open(fea_list_txt, 'w') as fea_list_txt_wr:

        fea_list_txt_wr.write('Layer Name\n')

        for fea_ctr, fea_name in enumerate(new_feas_list, 1):
            fea_list_txt_wr.write('{:d} {}\n'.format(fea_ctr, fea_name))

    # Stack the features into a VRT.
    out_vrt = os.path.join(parameter_object.output_dir, out_base + '.vrt')

    if os.path.isfile(out_vrt):
        os.remove(out_vrt)

    # One single-image entry per output layer, keyed '1', '2', ...
    stack_dict = {str(layer): [new_feas]
                  for layer, new_feas in enumerate(new_feas_list, 1)}

    logger.info('Stacking variables ...')

    vrt_builder(stack_dict, out_vrt, force_type='float32', be_quiet=True)

    parameter_object.update_info(out_vrt=out_vrt)

    return parameter_object