def run(parameter_object):

    """
    Runs the feature-extraction pipeline for a single input image.

    Args:
        parameter_object: The run-configuration object (input_image,
            output_dir, band_positions, triggers, scales, n_jobs,
            section_size, chunk_size, etc.). Mutated and rebound in place.

    Side effects:
        Writes feature tiles, a YAML status file, and (when every section
        finishes cleanly) a VRT mosaic to disk. Sets the module globals
        ``potsi`` and ``param_dict`` used by the section workers.
    """

    global potsi, param_dict

    # Clamp the requested job count to [1, cpu_count()].
    if parameter_object.n_jobs == 0:
        parameter_object.n_jobs = 1
    elif parameter_object.n_jobs < 0:
        parameter_object.n_jobs = multi.cpu_count()
    elif parameter_object.n_jobs > multi.cpu_count():
        parameter_object.n_jobs = multi.cpu_count()

    sputilities.parameter_checks(parameter_object)

    # Write the parameters to file.
    sputilities.write_log(parameter_object)

    if parameter_object.stack_only:

        new_feas_list = list()

        # If prompted, stack features without processing.
        parameter_object = sputilities.stack_features(parameter_object, new_feas_list)

    else:

        # Create the status object.
        mts = sputilities.ManageStatus()

        parameter_object.remove_files = False

        # Setup the status dictionary.
        if os.path.isfile(parameter_object.status_file):

            mts.load_status(parameter_object.status_file)

            # FIX: validate the loaded status *before* subscripting it.
            # The original indexed 'SECTION_SIZE' first, so a malformed
            # YAML file raised a TypeError/KeyError before this intended
            # diagnostic could fire.
            if not isinstance(mts.status_dict, dict):
                logger.error('The YAML file already existed, but was not properly stored and saved.\nPlease remove and re-run.')
                raise AttributeError

            if parameter_object.section_size != mts.status_dict['SECTION_SIZE']:

                logger.warning('The section size was changed, so all existing tiled images will be removed.')

                parameter_object.remove_files = True

        else:

            mts.status_dict = dict()

            mts.status_dict['ALL_FINISHED'] = 'no'
            mts.status_dict['BAND_ORDER'] = dict()

            # Save the band order (1-based start-end output band
            # range for each trigger).
            for trigger in parameter_object.triggers:

                mts.status_dict['BAND_ORDER']['{}'.format(trigger)] = \
                    '{:d}-{:d}'.format(parameter_object.band_info[trigger] + 1,
                                       parameter_object.band_info[trigger] +
                                       parameter_object.out_bands_dict[trigger] * parameter_object.n_bands)

            mts.status_dict['SECTION_SIZE'] = parameter_object.section_size

            mts.dump_status(parameter_object.status_file)

        # Skip processing entirely when a previous run finished cleanly.
        process_image = True

        if 'ALL_FINISHED' in mts.status_dict:

            if mts.status_dict['ALL_FINISHED'] == 'yes':
                process_image = False

        # Set the output features folder.
        parameter_object = sputilities.set_feas_dir(parameter_object)

        if parameter_object.remove_files:

            # Remove stale tiles left over from a different section size.
            image_list = fnmatch.filter(os.listdir(parameter_object.feas_dir), '*.tif')

            if image_list:

                image_list = [os.path.join(parameter_object.feas_dir, im_) for im_ in image_list]

                for full_image in image_list:
                    os.remove(full_image)

        if not process_image:
            logger.warning('The input image, {}, is set as finished processing.'.format(parameter_object.input_image))
        else:

            original_band_positions = copy.copy(parameter_object.band_positions)

            # Iterate over each feature trigger.
            for trigger in parameter_object.triggers:

                parameter_object.update_info(trigger=trigger,
                                             band_positions=original_band_positions,
                                             band_counter=0)

                # Iterate over each band
                for band_position in parameter_object.band_positions:

                    parameter_object.update_info(band_position=band_position)

                    # Get the input image information.
                    with raster_tools.ropen(parameter_object.input_image) as i_info:

                        # Check if any of the input bands are corrupted.
                        i_info.check_corrupted_bands()

                        if i_info.corrupted_bands:

                            logger.error('\nThe following bands appear to be corrupted:\n{}'.format(', '.join(i_info.corrupted_bands)))

                            raise CorruptedBandsError

                        # Get image statistics.
                        parameter_object = sputilities.get_stats(i_info, parameter_object)

                        # Get the section size.
                        parameter_object = sputilities.get_section_size(i_info, parameter_object)

                        # Get the number of sections in
                        # the image (only used as a counter).
                        parameter_object = sputilities.get_n_sects(i_info, parameter_object)

                        if parameter_object.trigger == 'saliency':

                            # Saliency needs image-wide tile Lab means, so
                            # make one read-only pass before processing.
                            bp = raster_tools.BlockFunc(get_saliency_tile_mean,
                                                        [i_info],
                                                        None,
                                                        None,
                                                        band_list=[[1, 2, 3]],
                                                        d_types=['float32'],
                                                        write_array=False,
                                                        close_files=False,
                                                        be_quiet=True,
                                                        print_statement='\nGetting tile lab means for saliency',
                                                        out_attributes=['lab_means'],
                                                        block_rows=parameter_object.sect_row_size,
                                                        block_cols=parameter_object.sect_col_size,
                                                        min_max=[(parameter_object.image_min, parameter_object.image_max)] * 3,
                                                        vis_order=parameter_object.vis_order)

                            bp.run()

                            parameter_object.update_info(lab_means=np.array(bp.lab_means, dtype='float32').mean(axis=0))

                    del i_info

                    # Pre-mark every section of this trigger/band pair.
                    mts = sputilities.ManageStatus()
                    mts.load_status(parameter_object.status_file)

                    for sect_counter in range(1, parameter_object.n_sects + 1):

                        parameter_object.update_info(section_counter=sect_counter)

                        parameter_object = sputilities.scale_fea_check(parameter_object)

                        # The per-section sub-dictionary only needs to be
                        # created once, on the first trigger.
                        if trigger == parameter_object.triggers[0]:
                            mts.status_dict[parameter_object.out_img_base] = dict()

                        mts.status_dict[parameter_object.out_img_base]['{TR}-{BD}'.format(TR=parameter_object.trigger,
                                                                                          BD=parameter_object.band_position)] = 'unprocessed'

                    mts.dump_status(parameter_object.status_file)

                    # Globals consumed by ``_section_read_write`` workers.
                    param_dict = sputilities.class2dict(parameter_object)
                    potsi = parameter_object.section_idx_pairs

                    # PROCESS IN PARALLEL CHUNKS
                    #
                    # NOTE(review): section processing is currently serial.
                    # FIX: the original created a ``multi.Pool`` on every
                    # chunk iteration and never used, closed, or joined it
                    # (its ``pool.map`` call was commented out) -- a
                    # per-iteration resource leak. The unused Pool has been
                    # removed; re-introduce it when pool.map is re-enabled.
                    parallel_chunk_counter = 1

                    for parallel_chunk in range(1,
                                                parameter_object.n_sects + parameter_object.n_jobs,
                                                parameter_object.n_jobs):

                        # Clip the last chunk to the number of sections.
                        if parallel_chunk + parameter_object.n_jobs < parameter_object.n_sects:
                            parallel_chunk_end = parallel_chunk + parameter_object.n_jobs
                        else:
                            parallel_chunk_end = parallel_chunk + (parameter_object.n_sects - parallel_chunk) + 1

                        results = map(_section_read_write,
                                      range(parallel_chunk, parallel_chunk_end))

                        logger.info(' Updating status ...')

                        for result in results:

                            parameter_object.update_info(section_counter=parallel_chunk_counter)

                            parameter_object = sputilities.scale_fea_check(parameter_object)

                            # Open the status YAML file.
                            mts = sputilities.ManageStatus()

                            # Load the status dictionary
                            mts.load_status(parameter_object.status_file)

                            if parameter_object.out_img_base in mts.status_dict:

                                # A truthy worker result signals corruption.
                                if result:
                                    mts.status_dict[parameter_object.out_img_base]['{TR}-{BD}'.format(TR=parameter_object.trigger,
                                                                                                      BD=parameter_object.band_position)] = 'corrupt'
                                else:
                                    mts.status_dict[parameter_object.out_img_base]['{TR}-{BD}'.format(TR=parameter_object.trigger,
                                                                                                      BD=parameter_object.band_position)] = 'complete'

                            mts.dump_status(parameter_object.status_file)

                            parallel_chunk_counter += 1

                    parameter_object.band_counter += parameter_object.out_bands_dict[parameter_object.trigger]

            # Check the corruption status.
            mts.load_status(parameter_object.status_file)

            n_corrupt = 0

            # Count sections that did not finish cleanly.
            # FIX: dict.iteritems() is Python 2-only (AttributeError on
            # Python 3); .values()/.items() work on both, and the unused
            # key variables are dropped.
            for v in mts.status_dict.values():

                if isinstance(v, dict):

                    for vsub in v.values():

                        if vsub in ['corrupt', 'incomplete']:
                            n_corrupt += 1

            if n_corrupt == 0:

                mts.status_dict['ALL_FINISHED'] = 'yes'

                mts.dump_status(parameter_object.status_file)

                # Finally, mosaic the image tiles.
                logger.info(' Creating the VRT mosaic ...')

                comp_dict = dict()

                # Get the image list.
                parameter_object = sputilities.scale_fea_check(parameter_object, is_image=False)

                image_list = fnmatch.filter(os.listdir(parameter_object.feas_dir),
                                            parameter_object.search_wildcard)

                image_list = [os.path.join(parameter_object.feas_dir, im) for im in image_list]

                comp_dict['001'] = image_list

                vrt_mosaic = parameter_object.status_file.replace('.yaml', '.vrt')

                vrt_builder(comp_dict,
                            vrt_mosaic,
                            force_type='float32',
                            be_quiet=True,
                            overwrite=True)

                if parameter_object.overviews:

                    logger.info('\nBuilding VRT overviews ...')

                    with raster_tools.ropen(vrt_mosaic, open2read=False) as vrt_info:

                        vrt_info.remove_overviews()
                        vrt_info.build_overviews(levels=[2, 4, 8, 16])

                    del vrt_info

            else:

                if n_corrupt == 1:
                    logger.warning('\nThere was {:d} corrupt or incomplete tile.\nRe-run the command with the same parameters.'.format(n_corrupt))
                else:
                    logger.warning('\nThere were {:d} corrupt or incomplete tiles.\nRe-run the command with the same parameters.'.format(n_corrupt))
def stack_features(parameter_object, new_feas_list):

    """
    Stacks all existing feature images into a single VRT.

    Args:
        parameter_object: The run-configuration object; mutated via
            ``update_info`` and rebound by the helper calls.
        new_feas_list (list): Accumulates the on-disk paths of the feature
            images that are found (mutated in place).

    Returns:
        ``parameter_object`` with ``out_vrt`` set to the stacked VRT path.

    Side effects:
        Writes a ``*_fea_list.txt`` layer listing and a ``*.vrt`` stack to
        ``parameter_object.output_dir``, replacing any existing copies.
    """

    for trigger in parameter_object.triggers:

        parameter_object.update_info(trigger=trigger)

        # Set the output features folder.
        parameter_object = set_feas_dir(parameter_object)

        for band_p in parameter_object.band_positions:

            parameter_object.update_info(band_position=band_p)

            # Get feature names for every scale/feature combination.
            for scale in parameter_object.scales:

                parameter_object.update_info(scale=scale)

                # FIX: ``xrange`` is Python 2-only; ``range`` behaves the
                # same here and works on Python 2 and 3.
                for feature in range(1, parameter_object.features_dict[trigger] + 1):

                    parameter_object.update_info(feature=feature)

                    parameter_object = scale_fea_check(parameter_object)

                    # skip the feature if it doesn't exist
                    if not os.path.isfile(parameter_object.out_img):
                        continue

                    new_feas_list.append(parameter_object.out_img)

    scs_str = [str(sc) for sc in parameter_object.scales]
    band_pos_str = [str(bp) for bp in parameter_object.band_positions]

    # write band list to text
    fea_list_txt = os.path.join(parameter_object.output_dir,
                                '{}.{}.stk.bd{}.block{}.scales{}_fea_list.txt'.format(parameter_object.f_base,
                                                                                      '-'.join(parameter_object.triggers),
                                                                                      '-'.join(band_pos_str),
                                                                                      parameter_object.block,
                                                                                      '-'.join(scs_str)))

    # remove stacked VRT list
    if os.path.isfile(fea_list_txt):
        os.remove(fea_list_txt)

    # FIX: the listing is written as ``str``, so the file must be opened
    # in text mode ('w'); the original's binary mode ('wb') raises a
    # TypeError under Python 3.
    with open(fea_list_txt, 'w') as fea_list_txt_wr:

        fea_list_txt_wr.write('Layer Name\n')

        for fea_ctr, fea_name in enumerate(new_feas_list):
            fea_list_txt_wr.write('{:d} {}\n'.format(fea_ctr + 1, fea_name))

    # stack features here
    out_vrt = os.path.join(parameter_object.output_dir,
                           '{}.{}.stk.bd{}.block{}.scales{}.vrt'.format(parameter_object.f_base,
                                                                        '-'.join(parameter_object.triggers),
                                                                        '-'.join(band_pos_str),
                                                                        parameter_object.block,
                                                                        '-'.join(scs_str)))

    if os.path.isfile(out_vrt):
        os.remove(out_vrt)

    # One VRT band per feature image, keyed by 1-based band index.
    stack_dict = dict()

    for ni, new_feas in enumerate(new_feas_list):
        stack_dict[str(ni + 1)] = [new_feas]

    logger.info('Stacking variables ...')

    vrt_builder(stack_dict, out_vrt, force_type='float32', be_quiet=True)

    parameter_object.update_info(out_vrt=out_vrt)

    return parameter_object