def _create_best_tracks(start_time_string, end_time_string, top_input_dir_name,
                        data_source, top_output_dir_name,
                        tracking_scale_metres2):
    """Runs the best-track algorithm with default parameters.

    :param start_time_string: See documentation at top of file.
    :param end_time_string: Same.
    :param top_input_dir_name: Same.
    :param data_source: Same.
    :param top_output_dir_name: Same.
    :param tracking_scale_metres2: Same.
    """

    tracking_utils.check_data_source(data_source)

    start_time_unix_sec = time_conversion.string_to_unix_sec(
        start_time_string, INPUT_TIME_FORMAT)
    end_time_unix_sec = time_conversion.string_to_unix_sec(
        end_time_string, INPUT_TIME_FORMAT)

    first_date_string = time_conversion.time_to_spc_date_string(
        start_time_unix_sec)
    last_date_string = time_conversion.time_to_spc_date_string(
        end_time_unix_sec)

    file_dictionary = best_tracks_smart_io.find_files_for_smart_io(
        start_time_unix_sec=start_time_unix_sec,
        start_spc_date_string=first_date_string,
        end_time_unix_sec=end_time_unix_sec,
        end_spc_date_string=last_date_string, data_source=data_source,
        tracking_scale_metres2=tracking_scale_metres2,
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name)

    best_tracks_smart_io.run_best_track(smart_file_dict=file_dictionary)
def _download_rap_analyses(first_init_time_string, last_init_time_string,
                           top_local_directory_name):
    """Downloads zero-hour analyses from the RAP (Rapid Refresh) model.

    :param first_init_time_string: See documentation at top of file.
    :param last_init_time_string: Same.
    :param top_local_directory_name: Same.
    """

    first_init_time_unix_sec = time_conversion.string_to_unix_sec(
        first_init_time_string, INPUT_TIME_FORMAT)
    last_init_time_unix_sec = time_conversion.string_to_unix_sec(
        last_init_time_string, INPUT_TIME_FORMAT)

    time_interval_sec = HOURS_TO_SECONDS * nwp_model_utils.get_time_steps(
        nwp_model_utils.RAP_MODEL_NAME)[1]

    init_times_unix_sec = time_periods.range_and_interval_to_list(
        start_time_unix_sec=first_init_time_unix_sec,
        end_time_unix_sec=last_init_time_unix_sec,
        time_interval_sec=time_interval_sec)

    init_time_strings = [
        time_conversion.unix_sec_to_string(t, DEFAULT_TIME_FORMAT)
        for t in init_times_unix_sec
    ]

    num_init_times = len(init_times_unix_sec)
    local_file_names = [None] * num_init_times

    for i in range(num_init_times):
        local_file_names[i] = nwp_model_io.find_rap_file_any_grid(
            top_directory_name=top_local_directory_name,
            init_time_unix_sec=init_times_unix_sec[i], lead_time_hours=0,
            raise_error_if_missing=False)

        if local_file_names[i] is not None:
            continue

        local_file_names[i] = nwp_model_io.download_rap_file_any_grid(
            top_local_directory_name=top_local_directory_name,
            init_time_unix_sec=init_times_unix_sec[i], lead_time_hours=0,
            raise_error_if_fails=False)

        if local_file_names[i] is None:
            print('\nPROBLEM. Download failed for {0:s}.\n\n'.format(
                init_time_strings[i]))
        else:
            print('\nSUCCESS. File was downloaded to "{0:s}".\n\n'.format(
                local_file_names[i]))

        time.sleep(SECONDS_TO_PAUSE_BETWEEN_FILES)

    num_downloaded = numpy.sum(numpy.array(
        [f is not None for f in local_file_names]))

    print('{0:d} of {1:d} files were downloaded successfully!'.format(
        num_downloaded, num_init_times))
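# Illustrative sketch (not from the original module): the loop above relies on
# `time_periods.range_and_interval_to_list` to build evenly spaced init times.
# A rough standard-library stand-in for that helper, assuming integer Unix
# times and an interval that divides the period evenly:
def _init_times_sketch(first_time_unix_sec, last_time_unix_sec,
                       time_interval_sec=3600):
    """Returns init times (Unix sec) from first to last at a fixed interval."""

    return list(range(
        first_time_unix_sec, last_time_unix_sec + 1, time_interval_sec
    ))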
def _run(top_input_dir_name, top_output_dir_name, first_spc_date_string,
         last_spc_date_string, first_time_string, last_time_string,
         radar_source_name, for_storm_climatology):
    """Reanalyzes storm tracks across many SPC dates.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param top_output_dir_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param radar_source_name: Same.
    :param for_storm_climatology: Same.
    """

    if (for_storm_climatology and
            radar_source_name == radar_utils.MYRORSS_SOURCE_ID):
        myrorss_start_time_unix_sec = time_conversion.string_to_unix_sec(
            MYRORSS_START_TIME_STRING, TIME_FORMAT)
        myrorss_end_time_unix_sec = time_conversion.string_to_unix_sec(
            MYRORSS_END_TIME_STRING, TIME_FORMAT)

        tracking_start_time_unix_sec = myrorss_start_time_unix_sec + 0
        tracking_end_time_unix_sec = myrorss_end_time_unix_sec + 0
    else:
        tracking_start_time_unix_sec = None
        tracking_end_time_unix_sec = None

    if first_time_string in ['', 'None'] or last_time_string in ['', 'None']:
        first_time_unix_sec = None
        last_time_unix_sec = None
    else:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

    echo_top_tracking.reanalyze_tracks_across_spc_dates(
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name,
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string,
        first_time_unix_sec=first_time_unix_sec,
        last_time_unix_sec=last_time_unix_sec,
        tracking_start_time_unix_sec=tracking_start_time_unix_sec,
        tracking_end_time_unix_sec=tracking_end_time_unix_sec,
        min_track_duration_seconds=0)
def _convert_files(top_probsevere_dir_name, date_string, top_output_dir_name):
    """Converts probSevere tracking files for one day.

    :param top_probsevere_dir_name: See documentation at top of file.
    :param date_string: Same.
    :param top_output_dir_name: Same.
    """

    date_unix_sec = time_conversion.string_to_unix_sec(
        date_string, DATE_FORMAT)

    raw_file_names = probsevere_io.find_raw_files_one_day(
        top_directory_name=top_probsevere_dir_name,
        unix_time_sec=date_unix_sec,
        file_extension=probsevere_io.ASCII_FILE_EXTENSION,
        raise_error_if_all_missing=True)

    for this_raw_file_name in raw_file_names:
        print('Reading data from "{0:s}"...'.format(this_raw_file_name))

        this_storm_object_table = probsevere_io.read_raw_file(
            this_raw_file_name)
        this_time_unix_sec = probsevere_io.raw_file_name_to_time(
            this_raw_file_name)

        this_new_file_name = tracking_io.find_processed_file(
            unix_time_sec=this_time_unix_sec,
            data_source=tracking_utils.PROBSEVERE_SOURCE_ID,
            top_processed_dir_name=top_output_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            raise_error_if_missing=False)

        print('Writing data to "{0:s}"...'.format(this_new_file_name))
        tracking_io.write_processed_file(
            storm_object_table=this_storm_object_table,
            pickle_file_name=this_new_file_name)
def _time_string_to_unix_sec(time_string):
    """Converts time from string to Unix format.

    :param time_string: Time string (format "dd-mmm-yy HH:MM:SS").
    :return: unix_time_sec: Time in Unix format.
    """

    time_string = _capitalize_months(time_string)
    return time_conversion.string_to_unix_sec(time_string, TIME_FORMAT)
def _run(top_input_dir_name, first_spc_date_string, last_spc_date_string,
         first_time_string, last_time_string, max_velocity_diff_m_s01,
         max_link_distance_m_s01, max_join_time_seconds, max_join_error_m_s01,
         min_duration_seconds, top_output_dir_name):
    """Reanalyzes storm tracks (preferably over many SPC dates).

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param max_velocity_diff_m_s01: Same.
    :param max_link_distance_m_s01: Same.
    :param max_join_time_seconds: Same.
    :param max_join_error_m_s01: Same.
    :param min_duration_seconds: Same.
    :param top_output_dir_name: Same.
    """

    if first_time_string in ['', 'None'] or last_time_string in ['', 'None']:
        first_time_unix_sec = None
        last_time_unix_sec = None
    else:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

    echo_top_tracking.reanalyze_across_spc_dates(
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name,
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string,
        first_time_unix_sec=first_time_unix_sec,
        last_time_unix_sec=last_time_unix_sec,
        max_velocity_diff_m_s01=max_velocity_diff_m_s01,
        max_link_distance_m_s01=max_link_distance_m_s01,
        max_join_time_seconds=max_join_time_seconds,
        max_join_error_m_s01=max_join_error_m_s01,
        min_track_duration_seconds=min_duration_seconds)
def _date_string_to_unix_sec(date_string):
    """Converts date from string to Unix format.

    :param date_string: String in format "xx yyyy mm dd HH MM SS", where the
        xx is meaningless.
    :return: unix_date_sec: Date in Unix format.
    """

    words = date_string.split()
    date_string = words[1] + '-' + words[2] + '-' + words[3]
    return time_conversion.string_to_unix_sec(date_string, TIME_FORMAT_DATE)
def _local_time_string_to_unix_sec(local_time_string, utc_offset_hours):
    """Converts time from local string to Unix format.

    :param local_time_string: Local time (format "yyyymmddHHMM").
    :param utc_offset_hours: Local time minus UTC.
    :return: unix_time_sec: Time in Unix format.
    """

    return time_conversion.string_to_unix_sec(
        local_time_string, TIME_FORMAT_HOUR_MINUTE
    ) - (utc_offset_hours * HOURS_TO_SECONDS)
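# Illustrative sketch (not from the original module): the same local-to-UTC
# arithmetic with only the standard library.  For example, local time
# "201807041330" with utc_offset_hours = -6 maps to 2018-07-04 19:30 UTC.
import calendar
from datetime import datetime, timedelta


def _local_string_to_utc_unix_sec_sketch(local_time_string, utc_offset_hours):
    """Converts local time string ("yyyymmddHHMM") to Unix seconds (UTC)."""

    local_time = datetime.strptime(local_time_string, '%Y%m%d%H%M')
    utc_time = local_time - timedelta(hours=utc_offset_hours)
    return calendar.timegm(utc_time.timetuple())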
def read_gridrad_stats_from_thea(csv_file_name):
    """Reads radar statistics created by GridRad software (file format by Thea).

    :param csv_file_name: Path to input file.
    :return: gridrad_statistic_table: pandas DataFrame with mandatory columns
        listed below.  Other column names come from the list
        `GRIDRAD_STATISTIC_NAMES`.
    gridrad_statistic_table.storm_number: Numeric ID (integer) for storm cell.
    gridrad_statistic_table.unix_time_sec: Valid time of storm object.
    """

    error_checking.assert_file_exists(csv_file_name)
    gridrad_statistic_table = pandas.read_csv(csv_file_name, header=0, sep=',')

    # Convert times from Thea's format to Unix format.
    unix_times_sec = numpy.array([
        time_conversion.string_to_unix_sec(s, GRIDRAD_TIME_FORMAT)
        for s in gridrad_statistic_table[TIME_NAME_GRIDRAD_ORIG].values
    ])
    gridrad_statistic_table = gridrad_statistic_table.assign(
        **{tracking_utils.TIME_COLUMN: unix_times_sec})

    columns_to_keep = GRIDRAD_STATISTIC_NAMES_ORIG + [
        STORM_NUMBER_NAME_GRIDRAD_ORIG, tracking_utils.TIME_COLUMN
    ]
    gridrad_statistic_table = gridrad_statistic_table[columns_to_keep]

    # Rename columns.
    column_dict_old_to_new = {
        STORM_NUMBER_NAME_GRIDRAD_ORIG: STORM_NUMBER_NAME_GRIDRAD,
        ECHO_TOP_40DBZ_NAME_GRIDRAD_ORIG: ECHO_TOP_40DBZ_NAME_GRIDRAD,
        SPECTRUM_WIDTH_NAME_GRIDRAD_ORIG: SPECTRUM_WIDTH_NAME_GRIDRAD,
        MAX_DIVERGENCE_NAME_GRIDRAD_ORIG: MAX_DIVERGENCE_NAME_GRIDRAD,
        UPPER_LEVEL_DIVERGENCE_NAME_GRIDRAD_ORIG:
            UPPER_LEVEL_DIVERGENCE_NAME_GRIDRAD,
        LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD_ORIG:
            LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD,
        DIVERGENCE_AREA_NAME_GRIDRAD_ORIG: DIVERGENCE_AREA_NAME_GRIDRAD,
        MAX_ROTATION_NAME_GRIDRAD_ORIG: MAX_ROTATION_NAME_GRIDRAD,
        UPPER_LEVEL_ROTATION_NAME_GRIDRAD_ORIG:
            UPPER_LEVEL_ROTATION_NAME_GRIDRAD,
        LOW_LEVEL_ROTATION_NAME_GRIDRAD_ORIG: LOW_LEVEL_ROTATION_NAME_GRIDRAD
    }
    gridrad_statistic_table.rename(
        columns=column_dict_old_to_new, inplace=True)

    # Convert units of divergence/convergence.
    gridrad_statistic_table[LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD] *= -1

    for this_name in GRIDRAD_DIVERGENCE_NAMES:
        gridrad_statistic_table[
            this_name] *= CONVERSION_RATIO_FOR_GRIDRAD_DIVERGENCE

    return gridrad_statistic_table
def file_name_to_time(gridrad_file_name):
    """Parses valid time from name of GridRad file.

    :param gridrad_file_name: Path to GridRad file.
    :return: unix_time_sec: Valid time.
    """

    _, pathless_file_name = os.path.split(gridrad_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)
    time_string = extensionless_file_name.split('_')[-1]

    return time_conversion.string_to_unix_sec(
        time_string, TIME_FORMAT_IN_FILE_NAMES)
def _file_name_to_valid_time(bulletin_file_name):
    """Parses valid time from file name.

    :param bulletin_file_name: Path to input file (text file in WPC format).
    :return: valid_time_unix_sec: Valid time.
    """

    _, pathless_file_name = os.path.split(bulletin_file_name)
    valid_time_string = pathless_file_name.replace(
        PATHLESS_FILE_NAME_PREFIX + '_', '')

    return time_conversion.string_to_unix_sec(
        valid_time_string, TIME_FORMAT_IN_FILE_NAME)
def processed_file_name_to_time(processed_file_name):
    """Parses time from name of processed file.

    :param processed_file_name: Name of processed file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(processed_file_name)
    _, pathless_file_name = os.path.split(processed_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)
    extensionless_file_name_parts = extensionless_file_name.split('_')

    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], TIME_FORMAT)
def raw_file_name_to_time(raw_file_name):
    """Parses time from file name.

    :param raw_file_name: Path to raw file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(raw_file_name)
    _, time_string = os.path.split(raw_file_name)
    time_string = time_string.replace(ZIPPED_FILE_EXTENSION, '').replace(
        UNZIPPED_FILE_EXTENSION, '')

    return time_conversion.string_to_unix_sec(time_string, TIME_FORMAT_SECONDS)
def file_name_to_time(tracking_file_name):
    """Parses valid time from tracking file.

    :param tracking_file_name: Path to tracking file.
    :return: valid_time_unix_sec: Valid time.
    """

    error_checking.assert_is_string(tracking_file_name)
    _, pathless_file_name = os.path.split(tracking_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)
    extensionless_file_name_parts = extensionless_file_name.split('_')

    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], FILE_NAME_TIME_FORMAT)
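# Illustrative sketch (not from the original modules): several of the parsers
# above read the valid time from the last underscore-separated token of the
# file name.  A standalone standard-library version of that pattern, with a
# hypothetical time format:
import calendar
import os.path
from datetime import datetime


def _file_name_to_time_sketch(file_name, time_format='%Y-%m-%d-%H%M%S'):
    """E.g. "/foo/tracking_2011-04-27-120000.p" -> Unix time of that moment."""

    pathless_file_name = os.path.split(file_name)[-1]
    extensionless_file_name = os.path.splitext(pathless_file_name)[0]
    time_string = extensionless_file_name.split('_')[-1]

    return calendar.timegm(
        datetime.strptime(time_string, time_format).timetuple()
    )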
def find_file(top_directory_name, field_name, month_string, is_surface=False,
              raise_error_if_missing=True):
    """Finds NetCDF file on the local machine.

    :param top_directory_name: Name of top-level directory with NetCDF files
        containing NARR data.
    :param field_name: See doc for `_get_pathless_file_name`.
    :param month_string: Same.
    :param is_surface: Same.
    :param raise_error_if_missing: Boolean flag.  If file is missing and
        `raise_error_if_missing = True`, this method will error out.
    :return: netcdf_file_name: Path to NetCDF file.  If file is missing and
        `raise_error_if_missing = False`, will return the *expected* path.
    :raises: ValueError: if file is missing and
        `raise_error_if_missing = True`.
    """

    error_checking.assert_is_string(top_directory_name)
    time_conversion.string_to_unix_sec(month_string, TIME_FORMAT_MONTH)
    error_checking.assert_is_boolean(is_surface)
    error_checking.assert_is_boolean(raise_error_if_missing)

    pathless_file_name = _get_pathless_file_name(
        field_name=field_name, month_string=month_string,
        is_surface=is_surface)
    netcdf_file_name = '{0:s}/{1:s}'.format(
        top_directory_name, pathless_file_name)

    if raise_error_if_missing and not os.path.isfile(netcdf_file_name):
        error_string = 'Cannot find file.  Expected at: "{0:s}"'.format(
            netcdf_file_name)
        raise ValueError(error_string)

    return netcdf_file_name
def raw_file_name_to_time(raw_file_name):
    """Parses valid time from name of raw (either ASCII or JSON) file.

    :param raw_file_name: Path to raw file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(raw_file_name)
    _, pathless_file_name = os.path.split(raw_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)

    time_string = extensionless_file_name.replace(
        RAW_FILE_NAME_PREFIX + '_', '')
    time_string = time_string.replace(ALT_RAW_FILE_NAME_PREFIX + '_', '')

    return time_conversion.string_to_unix_sec(
        time_string, RAW_FILE_TIME_FORMAT)
def _get_num_days_in_month(month, year):
    """Returns number of days in month.

    :param month: Month (integer from 1...12).
    :param year: Year (integer).
    :return: num_days_in_month: Number of days in month.
    """

    time_string = '{0:04d}-{1:02d}'.format(year, month)
    unix_time_sec = time_conversion.string_to_unix_sec(time_string, '%Y-%m')

    _, last_time_in_month_unix_sec = (
        time_conversion.first_and_last_times_in_month(unix_time_sec)
    )

    day_of_month_string = time_conversion.unix_sec_to_string(
        last_time_in_month_unix_sec, '%d')

    return int(day_of_month_string)
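# Illustrative sketch (not from the original module): `calendar.monthrange`
# gives the same answer as `_get_num_days_in_month` without a round trip
# through time strings.
import calendar


def _num_days_in_month_sketch(month, year):
    """E.g. month=2, year=2020 -> 29."""

    return calendar.monthrange(year, month)[1]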
def processed_file_name_to_time(processed_file_name):
    """Parses time from name of processed tracking file.

    This file should contain storm objects (bounding polygons) and tracking
    statistics for one time step and one tracking scale.

    :param processed_file_name: Path to processed tracking file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(processed_file_name)
    _, pathless_file_name = os.path.split(processed_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)
    extensionless_file_name_parts = extensionless_file_name.split('_')

    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], TIME_FORMAT_IN_FILE_NAMES)
def _write_metadata_one_model(argument_dict):
    """Writes metadata for one upconvnet to file.

    :param argument_dict: See doc for `_train_one_upconvnet`.
    :return: metadata_dict: See doc for `upconvnet.write_model_metadata`.
    """

    from gewittergefahr.deep_learning import cnn
    from gewittergefahr.deep_learning import upconvnet
    from gewittergefahr.deep_learning import input_examples
    from gewittergefahr.scripts import train_upconvnet

    # Read input args.
    cnn_file_name = argument_dict[train_upconvnet.CNN_FILE_ARG_NAME]
    cnn_feature_layer_name = argument_dict[
        train_upconvnet.FEATURE_LAYER_ARG_NAME]

    top_training_dir_name = argument_dict[
        train_upconvnet.TRAINING_DIR_ARG_NAME]
    first_training_time_string = argument_dict[
        train_upconvnet.FIRST_TRAINING_TIME_ARG_NAME]
    last_training_time_string = argument_dict[
        train_upconvnet.LAST_TRAINING_TIME_ARG_NAME]

    top_validation_dir_name = argument_dict[
        train_upconvnet.VALIDATION_DIR_ARG_NAME]
    first_validation_time_string = argument_dict[
        train_upconvnet.FIRST_VALIDATION_TIME_ARG_NAME]
    last_validation_time_string = argument_dict[
        train_upconvnet.LAST_VALIDATION_TIME_ARG_NAME]

    num_examples_per_batch = argument_dict[
        train_upconvnet.NUM_EX_PER_BATCH_ARG_NAME]
    num_epochs = argument_dict[train_upconvnet.NUM_EPOCHS_ARG_NAME]
    num_training_batches_per_epoch = argument_dict[
        train_upconvnet.NUM_TRAINING_BATCHES_ARG_NAME]
    num_validation_batches_per_epoch = argument_dict[
        train_upconvnet.NUM_VALIDATION_BATCHES_ARG_NAME]
    output_dir_name = argument_dict[train_upconvnet.OUTPUT_DIR_ARG_NAME]

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Write metadata.
    upconvnet_metafile_name = cnn.find_metafile(
        model_file_name='{0:s}/foo.h5'.format(output_dir_name),
        raise_error_if_missing=False
    )

    print('Writing upconvnet metadata to: "{0:s}"...'.format(
        upconvnet_metafile_name
    ))

    return upconvnet.write_model_metadata(
        cnn_file_name=cnn_file_name,
        cnn_feature_layer_name=cnn_feature_layer_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec,
        pickle_file_name=upconvnet_metafile_name)
from gewittergefahr.gg_utils import soundings
from gewittergefahr.gg_utils import echo_top_tracking
from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.gg_utils import number_rounding
from gewittergefahr.gg_utils import storm_tracking_utils as tracking_utils

HOURS_TO_SECONDS = 3600
STORM_TIME_FORMAT = '%Y-%m-%d-%H%M%S'
MODEL_INIT_TIME_FORMAT = '%Y-%m-%d-%H'
SEPARATOR_STRING = '\n\n' + '*' * 50 + '\n\n'

WGRIB_EXE_NAME = '/condo/swatwork/ralager/wgrib/wgrib'
WGRIB2_EXE_NAME = '/condo/swatwork/ralager/grib2/wgrib2/wgrib2'

FIRST_RAP_TIME_STRING = '2012-05-01-00'
FIRST_RAP_TIME_UNIX_SEC = time_conversion.string_to_unix_sec(
    FIRST_RAP_TIME_STRING, MODEL_INIT_TIME_FORMAT)

SPC_DATE_ARG_NAME = 'spc_date_string'
LEAD_TIMES_ARG_NAME = 'lead_times_seconds'
LAG_TIME_ARG_NAME = 'lag_time_for_convective_contamination_sec'
RUC_DIRECTORY_ARG_NAME = 'input_ruc_directory_name'
RAP_DIRECTORY_ARG_NAME = 'input_rap_directory_name'
TRACKING_DIR_ARG_NAME = 'input_tracking_dir_name'
TRACKING_SCALE_ARG_NAME = 'tracking_scale_metres2'
OUTPUT_DIR_ARG_NAME = 'output_sounding_dir_name'

SPC_DATE_HELP_STRING = (
    'SPC (Storm Prediction Center) date in format "yyyymmdd".  The RUC (Rapid '
    'Update Cycle) sounding will be interpolated to each storm object on this '
    'date, at each lead time in `{0:s}`.'
).format(LEAD_TIMES_ARG_NAME)
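# Hedged sketch (an assumption about how the script wires these constants into
# its argument parser; the real parser may define more arguments and defaults):
import argparse

_EXAMPLE_ARG_PARSER = argparse.ArgumentParser()
_EXAMPLE_ARG_PARSER.add_argument(
    '--' + SPC_DATE_ARG_NAME, type=str, required=True,
    help=SPC_DATE_HELP_STRING)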
def _run(net_type_string, training_dir_name, validation_dir_name,
         input_model_file_name, output_model_dir_name,
         use_generator_for_training, use_generator_for_validn,
         predictor_names, target_names, heights_m_agl, omit_heating_rate,
         first_training_time_string, last_training_time_string,
         first_validn_time_string, last_validn_time_string,
         normalization_file_name, predictor_norm_type_string,
         predictor_min_norm_value, predictor_max_norm_value,
         vector_target_norm_type_string, vector_target_min_norm_value,
         vector_target_max_norm_value, scalar_target_norm_type_string,
         scalar_target_min_norm_value, scalar_target_max_norm_value,
         num_examples_per_batch, num_epochs, num_training_batches_per_epoch,
         num_validn_batches_per_epoch, plateau_lr_multiplier):
    """Trains neural net.

    :param net_type_string: See documentation at top of training_args.py.
    :param training_dir_name: Same.
    :param validation_dir_name: Same.
    :param input_model_file_name: Same.
    :param output_model_dir_name: Same.
    :param use_generator_for_training: Same.
    :param use_generator_for_validn: Same.
    :param predictor_names: Same.
    :param target_names: Same.
    :param heights_m_agl: Same.
    :param omit_heating_rate: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param first_validn_time_string: Same.
    :param last_validn_time_string: Same.
    :param normalization_file_name: Same.
    :param predictor_norm_type_string: Same.
    :param predictor_min_norm_value: Same.
    :param predictor_max_norm_value: Same.
    :param vector_target_norm_type_string: Same.
    :param vector_target_min_norm_value: Same.
    :param vector_target_max_norm_value: Same.
    :param scalar_target_norm_type_string: Same.
    :param scalar_target_min_norm_value: Same.
    :param scalar_target_max_norm_value: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validn_batches_per_epoch: Same.
    :param plateau_lr_multiplier: Same.
    """

    if predictor_norm_type_string in NONE_STRINGS:
        predictor_norm_type_string = None
    if vector_target_norm_type_string in NONE_STRINGS:
        vector_target_norm_type_string = None
    if scalar_target_norm_type_string in NONE_STRINGS:
        scalar_target_norm_type_string = None

    neural_net.check_net_type(net_type_string)

    if len(heights_m_agl) and heights_m_agl[0] <= 0:
        heights_m_agl = (
            training_args.NET_TYPE_TO_DEFAULT_HEIGHTS_M_AGL[net_type_string]
        )

    for n in predictor_names:
        example_utils.check_field_name(n)

    scalar_predictor_names = [
        n for n in predictor_names
        if n in example_utils.ALL_SCALAR_PREDICTOR_NAMES
    ]
    vector_predictor_names = [
        n for n in predictor_names
        if n in example_utils.ALL_VECTOR_PREDICTOR_NAMES
    ]

    for n in target_names:
        example_utils.check_field_name(n)

    scalar_target_names = [
        n for n in target_names
        if n in example_utils.ALL_SCALAR_TARGET_NAMES
    ]
    vector_target_names = [
        n for n in target_names
        if n in example_utils.ALL_VECTOR_TARGET_NAMES
    ]

    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, training_args.TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, training_args.TIME_FORMAT)
    first_validn_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validn_time_string, training_args.TIME_FORMAT)
    last_validn_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validn_time_string, training_args.TIME_FORMAT)

    training_option_dict = {
        neural_net.EXAMPLE_DIRECTORY_KEY: training_dir_name,
        neural_net.BATCH_SIZE_KEY: num_examples_per_batch,
        neural_net.SCALAR_PREDICTOR_NAMES_KEY: scalar_predictor_names,
        neural_net.VECTOR_PREDICTOR_NAMES_KEY: vector_predictor_names,
        neural_net.SCALAR_TARGET_NAMES_KEY: scalar_target_names,
        neural_net.VECTOR_TARGET_NAMES_KEY: vector_target_names,
        neural_net.HEIGHTS_KEY: heights_m_agl,
        neural_net.OMIT_HEATING_RATE_KEY: omit_heating_rate,
        neural_net.NORMALIZATION_FILE_KEY: normalization_file_name,
        neural_net.PREDICTOR_NORM_TYPE_KEY: predictor_norm_type_string,
        neural_net.PREDICTOR_MIN_NORM_VALUE_KEY: predictor_min_norm_value,
        neural_net.PREDICTOR_MAX_NORM_VALUE_KEY: predictor_max_norm_value,
        neural_net.VECTOR_TARGET_NORM_TYPE_KEY: vector_target_norm_type_string,
        neural_net.VECTOR_TARGET_MIN_VALUE_KEY: vector_target_min_norm_value,
        neural_net.VECTOR_TARGET_MAX_VALUE_KEY: vector_target_max_norm_value,
        neural_net.SCALAR_TARGET_NORM_TYPE_KEY: scalar_target_norm_type_string,
        neural_net.SCALAR_TARGET_MIN_VALUE_KEY: scalar_target_min_norm_value,
        neural_net.SCALAR_TARGET_MAX_VALUE_KEY: scalar_target_max_norm_value,
        neural_net.FIRST_TIME_KEY: first_training_time_unix_sec,
        neural_net.LAST_TIME_KEY: last_training_time_unix_sec,
        # neural_net.MIN_COLUMN_LWP_KEY: 0.05,
        # neural_net.MAX_COLUMN_LWP_KEY: 1e12
    }

    validation_option_dict = {
        neural_net.EXAMPLE_DIRECTORY_KEY: validation_dir_name,
        neural_net.BATCH_SIZE_KEY: num_examples_per_batch,
        neural_net.FIRST_TIME_KEY: first_validn_time_unix_sec,
        neural_net.LAST_TIME_KEY: last_validn_time_unix_sec
    }

    if input_model_file_name in NONE_STRINGS:
        model_object = u_net_architecture.create_model(
            option_dict=DEFAULT_ARCHITECTURE_OPTION_DICT,
            vector_loss_function=DEFAULT_VECTOR_LOSS_FUNCTION,
            scalar_loss_function=DEFAULT_SCALAR_LOSS_FUNCTION,
            num_output_channels=1)

        loss_function_or_dict = {
            'conv_output': DEFAULT_VECTOR_LOSS_FUNCTION,
            'dense_output': DEFAULT_SCALAR_LOSS_FUNCTION
        }
    else:
        print('Reading untrained model from: "{0:s}"...'.format(
            input_model_file_name))
        model_object = neural_net.read_model(input_model_file_name)

        input_metafile_name = neural_net.find_metafile(
            model_dir_name=os.path.split(input_model_file_name)[0])

        print('Reading loss function(s) from: "{0:s}"...'.format(
            input_metafile_name))
        loss_function_or_dict = neural_net.read_metafile(input_metafile_name)[
            neural_net.LOSS_FUNCTION_OR_DICT_KEY]

    print(SEPARATOR_STRING)

    if use_generator_for_training:
        neural_net.train_model_with_generator(
            model_object=model_object, output_dir_name=output_model_dir_name,
            num_epochs=num_epochs,
            num_training_batches_per_epoch=num_training_batches_per_epoch,
            training_option_dict=training_option_dict,
            use_generator_for_validn=use_generator_for_validn,
            num_validation_batches_per_epoch=num_validn_batches_per_epoch,
            validation_option_dict=validation_option_dict,
            net_type_string=net_type_string,
            loss_function_or_dict=loss_function_or_dict,
            do_early_stopping=True,
            plateau_lr_multiplier=plateau_lr_multiplier)
    else:
        neural_net.train_model_sans_generator(
            model_object=model_object, output_dir_name=output_model_dir_name,
            num_epochs=num_epochs,
            training_option_dict=training_option_dict,
            validation_option_dict=validation_option_dict,
            net_type_string=net_type_string,
            loss_function_or_dict=loss_function_or_dict,
            do_early_stopping=True,
            plateau_lr_multiplier=plateau_lr_multiplier)
Since the NARR is a reanalysis, valid time = initialization time always.  In
other words, all "forecasts" are zero-hour forecasts (analyses).
"""

import os.path
import numpy
from gewittergefahr.gg_io import netcdf_io
from gewittergefahr.gg_io import downloads
from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.gg_utils import error_checking
from generalexam.ge_io import processed_narr_io

SENTINEL_VALUE = -9e36
HOURS_TO_SECONDS = 3600
NARR_ZERO_TIME_UNIX_SEC = time_conversion.string_to_unix_sec(
    '1800-01-01-00', '%Y-%m-%d-%H')

TIME_FORMAT_MONTH = '%Y%m'
NETCDF_FILE_EXTENSION = '.nc'

ONLINE_SURFACE_DIR_NAME = 'ftp://ftp.cdc.noaa.gov/Datasets/NARR/monolevel'
ONLINE_PRESSURE_LEVEL_DIR_NAME = (
    'ftp://ftp.cdc.noaa.gov/Datasets/NARR/pressure')

TEMPERATURE_NAME_NETCDF = 'air'
HEIGHT_NAME_NETCDF = 'hgt'
VERTICAL_VELOCITY_NAME_NETCDF = 'omega'
SPECIFIC_HUMIDITY_NAME_NETCDF = 'shum'
U_WIND_NAME_NETCDF = 'uwnd'
V_WIND_NAME_NETCDF = 'vwnd'

VALID_FIELD_NAMES_NETCDF = [
def _write_metadata_one_cnn(model_object, argument_dict):
    """Writes metadata for one CNN to file.

    :param model_object: Untrained CNN (instance of `keras.models.Model` or
        `keras.models.Sequential`).
    :param argument_dict: See doc for `_train_one_cnn`.
    :return: metadata_dict: See doc for `cnn.write_model_metadata`.
    :return: training_option_dict: Same.
    """

    from gewittergefahr.deep_learning import cnn
    from gewittergefahr.deep_learning import input_examples
    from gewittergefahr.deep_learning import \
        training_validation_io as trainval_io
    from gewittergefahr.scripts import deep_learning_helper as dl_helper

    # Read input args.
    sounding_field_names = argument_dict[dl_helper.SOUNDING_FIELDS_ARG_NAME]
    radar_field_name_by_channel = argument_dict[RADAR_FIELDS_KEY]
    layer_op_name_by_channel = argument_dict[LAYER_OPERATIONS_KEY]
    min_height_by_channel_m_agl = argument_dict[MIN_HEIGHTS_KEY]
    max_height_by_channel_m_agl = argument_dict[MAX_HEIGHTS_KEY]

    normalization_type_string = argument_dict[
        dl_helper.NORMALIZATION_TYPE_ARG_NAME]
    normalization_file_name = argument_dict[
        dl_helper.NORMALIZATION_FILE_ARG_NAME]
    min_normalized_value = argument_dict[dl_helper.MIN_NORM_VALUE_ARG_NAME]
    max_normalized_value = argument_dict[dl_helper.MAX_NORM_VALUE_ARG_NAME]

    target_name = argument_dict[dl_helper.TARGET_NAME_ARG_NAME]
    downsampling_classes = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_CLASSES_ARG_NAME], dtype=int)
    downsampling_fractions = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_FRACTIONS_ARG_NAME], dtype=float)

    monitor_string = argument_dict[dl_helper.MONITOR_ARG_NAME]
    weight_loss_function = bool(argument_dict[dl_helper.WEIGHT_LOSS_ARG_NAME])

    x_translations_pixels = numpy.array(
        argument_dict[dl_helper.X_TRANSLATIONS_ARG_NAME], dtype=int)
    y_translations_pixels = numpy.array(
        argument_dict[dl_helper.Y_TRANSLATIONS_ARG_NAME], dtype=int)
    ccw_rotation_angles_deg = numpy.array(
        argument_dict[dl_helper.ROTATION_ANGLES_ARG_NAME], dtype=float)
    noise_standard_deviation = argument_dict[dl_helper.NOISE_STDEV_ARG_NAME]
    num_noisings = argument_dict[dl_helper.NUM_NOISINGS_ARG_NAME]
    flip_in_x = bool(argument_dict[dl_helper.FLIP_X_ARG_NAME])
    flip_in_y = bool(argument_dict[dl_helper.FLIP_Y_ARG_NAME])

    top_training_dir_name = argument_dict[dl_helper.TRAINING_DIR_ARG_NAME]
    first_training_time_string = argument_dict[
        dl_helper.FIRST_TRAINING_TIME_ARG_NAME]
    last_training_time_string = argument_dict[
        dl_helper.LAST_TRAINING_TIME_ARG_NAME]
    num_examples_per_train_batch = argument_dict[
        dl_helper.NUM_EX_PER_TRAIN_ARG_NAME]

    top_validation_dir_name = argument_dict[dl_helper.VALIDATION_DIR_ARG_NAME]
    first_validation_time_string = argument_dict[
        dl_helper.FIRST_VALIDATION_TIME_ARG_NAME]
    last_validation_time_string = argument_dict[
        dl_helper.LAST_VALIDATION_TIME_ARG_NAME]
    num_examples_per_validn_batch = argument_dict[
        dl_helper.NUM_EX_PER_VALIDN_ARG_NAME]

    num_epochs = argument_dict[dl_helper.NUM_EPOCHS_ARG_NAME]
    num_training_batches_per_epoch = argument_dict[
        dl_helper.NUM_TRAINING_BATCHES_ARG_NAME]
    num_validation_batches_per_epoch = argument_dict[
        dl_helper.NUM_VALIDATION_BATCHES_ARG_NAME]
    output_dir_name = argument_dict[dl_helper.OUTPUT_DIR_ARG_NAME]

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)
    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    num_channels = len(radar_field_name_by_channel)
    layer_operation_dicts = [{}] * num_channels

    for k in range(num_channels):
        layer_operation_dicts[k] = {
            input_examples.RADAR_FIELD_KEY: radar_field_name_by_channel[k],
            input_examples.OPERATION_NAME_KEY: layer_op_name_by_channel[k],
            input_examples.MIN_HEIGHT_KEY: min_height_by_channel_m_agl[k],
            input_examples.MAX_HEIGHT_KEY: max_height_by_channel_m_agl[k]
        }

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions))
        )
    else:
        downsampling_dict = None

    translate_flag = (
        len(x_translations_pixels) > 1
        or x_translations_pixels[0] != 0
        or y_translations_pixels[0] != 0
    )

    if not translate_flag:
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.LAYER_OPERATIONS_KEY: layer_operation_dicts,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    input_tensor = model_object.input
    if isinstance(input_tensor, list):
        input_tensor = input_tensor[0]

    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y
    }

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)
    metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    print('Writing metadata to: "{0:s}"...'.format(metafile_name))
    cnn.write_model_metadata(
        pickle_file_name=metafile_name, metadata_dict=metadata_dict,
        training_option_dict=training_option_dict)

    return metadata_dict, training_option_dict
def _run(input_cnn_file_name, input_upconvnet_file_name,
         cnn_feature_layer_name, top_training_dir_name,
         first_training_time_string, last_training_time_string,
         top_validation_dir_name, first_validation_time_string,
         last_validation_time_string, num_examples_per_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains upconvnet.

    This is effectively the main method.

    :param input_cnn_file_name: See documentation at top of file.
    :param input_upconvnet_file_name: Same.
    :param cnn_feature_layer_name: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # argument_file_name = '{0:s}/input_args.p'.format(output_dir_name)
    # print('Writing input args to: "{0:s}"...'.format(argument_file_name))
    #
    # argument_file_handle = open(argument_file_name, 'wb')
    # pickle.dump(INPUT_ARG_OBJECT.__dict__, argument_file_handle)
    # argument_file_handle.close()
    #
    # return

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read trained CNN.
    print('Reading trained CNN from: "{0:s}"...'.format(input_cnn_file_name))
    cnn_model_object = cnn.read_model(input_cnn_file_name)
    cnn_model_object.summary()
    print(SEPARATOR_STRING)

    cnn_metafile_name = cnn.find_metafile(
        model_file_name=input_cnn_file_name, raise_error_if_missing=True)

    print('Reading CNN metadata from: "{0:s}"...'.format(cnn_metafile_name))
    cnn_metadata_dict = cnn.read_model_metadata(cnn_metafile_name)

    # Read architecture.
    print('Reading upconvnet architecture from: "{0:s}"...'.format(
        input_upconvnet_file_name))
    upconvnet_model_object = cnn.read_model(input_upconvnet_file_name)
    # upconvnet_model_object = keras.models.clone_model(upconvnet_model_object)

    # TODO(thunderhoser): This is a HACK.
    upconvnet_model_object.compile(
        loss=keras.losses.mean_squared_error,
        optimizer=keras.optimizers.Adam())

    upconvnet_model_object.summary()
    print(SEPARATOR_STRING)

    upconvnet_metafile_name = cnn.find_metafile(
        model_file_name='{0:s}/foo.h5'.format(output_dir_name),
        raise_error_if_missing=False)

    print('Writing upconvnet metadata to: "{0:s}"...'.format(
        upconvnet_metafile_name))

    upconvnet.write_model_metadata(
        cnn_file_name=input_cnn_file_name,
        cnn_feature_layer_name=cnn_feature_layer_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec,
        pickle_file_name=upconvnet_metafile_name)
    print(SEPARATOR_STRING)

    upconvnet.train_upconvnet(
        upconvnet_model_object=upconvnet_model_object,
        output_dir_name=output_dir_name, cnn_model_object=cnn_model_object,
        cnn_metadata_dict=cnn_metadata_dict,
        cnn_feature_layer_name=cnn_feature_layer_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec)
def _run(input_model_file_name, narr_predictor_names, pressure_level_mb,
         dilation_distance_metres, num_lead_time_steps,
         predictor_time_step_offsets, num_examples_per_time,
         weight_loss_function, class_fractions, top_narr_directory_name,
         top_frontal_grid_dir_name, narr_mask_file_name,
         first_training_time_string, last_training_time_string,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_batch, num_epochs, num_training_batches_per_epoch,
         num_validation_batches_per_epoch, output_model_file_name):
    """Trains CNN from scratch.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param narr_predictor_names: Same.
    :param pressure_level_mb: Same.
    :param dilation_distance_metres: Same.
    :param num_lead_time_steps: Same.
    :param predictor_time_step_offsets: Same.
    :param num_examples_per_time: Same.
    :param weight_loss_function: Same.
    :param class_fractions: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param narr_mask_file_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_model_file_name: Same.
    :raises: ValueError: if `num_lead_time_steps > 1`.
    """

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if narr_mask_file_name == '':
        narr_mask_file_name = None
        narr_mask_matrix = None

    if num_lead_time_steps <= 1:
        num_lead_time_steps = None
        predictor_time_step_offsets = None
    else:
        error_string = (
            'This script cannot yet handle num_lead_time_steps > 1 '
            '(specifically {0:d}).'
        ).format(num_lead_time_steps)

        raise ValueError(error_string)

    # Read architecture.
    print('Reading architecture from: "{0:s}"...'.format(
        input_model_file_name))
    model_object = traditional_cnn.read_keras_model(input_model_file_name)
    model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer=keras.optimizers.Adam(),
        metrics=traditional_cnn.LIST_OF_METRIC_FUNCTIONS)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    input_tensor = model_object.input
    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    num_half_rows = int(numpy.round((num_grid_rows - 1) / 2))
    num_half_columns = int(numpy.round((num_grid_columns - 1) / 2))

    if narr_mask_file_name is not None:
        print('Reading NARR mask from: "{0:s}"...'.format(
            narr_mask_file_name))
        narr_mask_matrix = ml_utils.read_narr_mask(narr_mask_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=output_model_file_name, raise_error_if_missing=False)

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    traditional_cnn.write_model_metadata(
        pickle_file_name=model_metafile_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_examples_per_target_time=num_examples_per_time,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        num_rows_in_half_grid=num_half_rows,
        num_columns_in_half_grid=num_half_columns,
        dilation_distance_metres=dilation_distance_metres,
        class_fractions=class_fractions,
        weight_loss_function=weight_loss_function,
        narr_predictor_names=narr_predictor_names,
        pressure_level_mb=pressure_level_mb,
        training_start_time_unix_sec=first_training_time_unix_sec,
        training_end_time_unix_sec=last_training_time_unix_sec,
        validation_start_time_unix_sec=first_validation_time_unix_sec,
        validation_end_time_unix_sec=last_validation_time_unix_sec,
        num_lead_time_steps=num_lead_time_steps,
        predictor_time_step_offsets=predictor_time_step_offsets,
        narr_mask_matrix=narr_mask_matrix)
    print(SEPARATOR_STRING)

    traditional_cnn.train_with_3d_examples(
        model_object=model_object, output_file_name=output_model_file_name,
        num_examples_per_batch=num_examples_per_batch, num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        num_examples_per_target_time=num_examples_per_time,
        training_start_time_unix_sec=first_training_time_unix_sec,
        training_end_time_unix_sec=last_training_time_unix_sec,
        top_narr_directory_name=top_narr_directory_name,
        top_frontal_grid_dir_name=top_frontal_grid_dir_name,
        narr_predictor_names=narr_predictor_names,
        pressure_level_mb=pressure_level_mb,
        dilation_distance_metres=dilation_distance_metres,
        class_fractions=class_fractions,
        num_rows_in_half_grid=num_half_rows,
        num_columns_in_half_grid=num_half_columns,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_start_time_unix_sec=first_validation_time_unix_sec,
        validation_end_time_unix_sec=last_validation_time_unix_sec,
        narr_mask_matrix=narr_mask_matrix)
def _run(model_file_name, first_time_string, last_time_string,
         randomize_times, num_target_times, use_isotonic_regression,
         top_narr_directory_name, top_frontal_grid_dir_name, output_dir_name):
    """Applies traditional CNN to full grids.

    This is effectively the main method.

    :param model_file_name: See documentation at top of file.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param randomize_times: Same.
    :param num_target_times: Same.
    :param use_isotonic_regression: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param output_dir_name: Same.
    """

    first_time_unix_sec = time_conversion.string_to_unix_sec(
        first_time_string, INPUT_TIME_FORMAT)
    last_time_unix_sec = time_conversion.string_to_unix_sec(
        last_time_string, INPUT_TIME_FORMAT)

    target_times_unix_sec = time_periods.range_and_interval_to_list(
        start_time_unix_sec=first_time_unix_sec,
        end_time_unix_sec=last_time_unix_sec,
        time_interval_sec=NARR_TIME_INTERVAL_SEC, include_endpoint=True)

    if randomize_times:
        error_checking.assert_is_leq(
            num_target_times, len(target_times_unix_sec))
        numpy.random.shuffle(target_times_unix_sec)
        target_times_unix_sec = target_times_unix_sec[:num_target_times]

    print('Reading model from: "{0:s}"...'.format(model_file_name))
    model_object = traditional_cnn.read_keras_model(model_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=model_file_name, raise_error_if_missing=True)

    print('Reading model metadata from: "{0:s}"...'.format(
        model_metafile_name))
    model_metadata_dict = traditional_cnn.read_model_metadata(
        model_metafile_name)

    if use_isotonic_regression:
        isotonic_file_name = isotonic_regression.find_model_file(
            base_model_file_name=model_file_name, raise_error_if_missing=True)

        print('Reading isotonic-regression models from: "{0:s}"...'.format(
            isotonic_file_name))
        isotonic_model_object_by_class = (
            isotonic_regression.read_model_for_each_class(isotonic_file_name)
        )
    else:
        isotonic_model_object_by_class = None

    if model_metadata_dict[traditional_cnn.NUM_LEAD_TIME_STEPS_KEY] is None:
        num_dimensions = 3
    else:
        num_dimensions = 4

    num_classes = len(
        model_metadata_dict[traditional_cnn.CLASS_FRACTIONS_KEY])
    num_target_times = len(target_times_unix_sec)
    print(SEPARATOR_STRING)

    for i in range(num_target_times):
        if num_dimensions == 3:
            (this_class_probability_matrix, this_target_matrix
            ) = traditional_cnn.apply_model_to_3d_example(
                model_object=model_object,
                target_time_unix_sec=target_times_unix_sec[i],
                top_narr_directory_name=top_narr_directory_name,
                top_frontal_grid_dir_name=top_frontal_grid_dir_name,
                narr_predictor_names=model_metadata_dict[
                    traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
                pressure_level_mb=model_metadata_dict[
                    traditional_cnn.PRESSURE_LEVEL_KEY],
                dilation_distance_metres=model_metadata_dict[
                    traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY],
                num_rows_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
                num_columns_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
                num_classes=num_classes,
                isotonic_model_object_by_class=isotonic_model_object_by_class,
                narr_mask_matrix=model_metadata_dict[
                    traditional_cnn.NARR_MASK_MATRIX_KEY])
        else:
            (this_class_probability_matrix, this_target_matrix
            ) = traditional_cnn.apply_model_to_4d_example(
                model_object=model_object,
                target_time_unix_sec=target_times_unix_sec[i],
                predictor_time_step_offsets=model_metadata_dict[
                    traditional_cnn.PREDICTOR_TIME_STEP_OFFSETS_KEY],
                num_lead_time_steps=model_metadata_dict[
                    traditional_cnn.NUM_LEAD_TIME_STEPS_KEY],
                top_narr_directory_name=top_narr_directory_name,
                top_frontal_grid_dir_name=top_frontal_grid_dir_name,
                narr_predictor_names=model_metadata_dict[
                    traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
                pressure_level_mb=model_metadata_dict[
                    traditional_cnn.PRESSURE_LEVEL_KEY],
                dilation_distance_metres=model_metadata_dict[
                    traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY],
                num_rows_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
                num_columns_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
                num_classes=num_classes,
                isotonic_model_object_by_class=isotonic_model_object_by_class,
                narr_mask_matrix=model_metadata_dict[
                    traditional_cnn.NARR_MASK_MATRIX_KEY])

        this_target_matrix[this_target_matrix == -1] = 0
        print(MINOR_SEPARATOR_STRING)

        this_prediction_file_name = ml_utils.find_gridded_prediction_file(
            directory_name=output_dir_name,
            first_target_time_unix_sec=target_times_unix_sec[i],
            last_target_time_unix_sec=target_times_unix_sec[i],
            raise_error_if_missing=False)

        print('Writing gridded predictions to file: "{0:s}"...'.format(
            this_prediction_file_name))
        ml_utils.write_gridded_predictions(
            pickle_file_name=this_prediction_file_name,
            class_probability_matrix=this_class_probability_matrix,
            target_times_unix_sec=target_times_unix_sec[[i]],
            model_file_name=model_file_name,
            used_isotonic_regression=use_isotonic_regression,
            target_matrix=this_target_matrix)

        if i != num_target_times - 1:
            print(SEPARATOR_STRING)
def _run(input_model_file_name, sounding_field_names,
         normalization_type_string, normalization_param_file_name,
         min_normalized_value, max_normalized_value, target_name,
         downsampling_classes, downsampling_fractions, monitor_string,
         weight_loss_function, x_translations_pixels, y_translations_pixels,
         ccw_rotation_angles_deg, noise_standard_deviation, num_noisings,
         flip_in_x, flip_in_y, top_training_dir_name,
         first_training_time_string, last_training_time_string,
         num_examples_per_train_batch, top_validation_dir_name,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_validn_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains CNN with 2-D and 3-D MYRORSS images.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param target_name: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param x_translations_pixels: Same.
    :param y_translations_pixels: Same.
    :param ccw_rotation_angles_deg: Same.
    :param noise_standard_deviation: Same.
    :param num_noisings: Same.
    :param flip_in_x: Same.
    :param flip_in_y: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param num_examples_per_train_batch: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_validn_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Leftover debug block (disabled): dumps input args to a Pickle file and
    # returns early.
    # argument_file_name = '{0:s}/input_args.p'.format(output_dir_name)
    # print('Writing input args to: "{0:s}"...'.format(argument_file_name))
    #
    # argument_file_handle = open(argument_file_name, 'wb')
    # pickle.dump(INPUT_ARG_OBJECT.__dict__, argument_file_handle)
    # argument_file_handle.close()
    #
    # return

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        downsampling_dict = None

    if (len(x_translations_pixels) == 1 and
            x_translations_pixels[0] + y_translations_pixels[0] == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Set output locations.
    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read architecture.
    print(
        'Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    # model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(
        loss=keras.losses.binary_crossentropy,
        optimizer=keras.optimizers.Adam(),
        metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: True,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    if isinstance(model_object.input, list):
        list_of_input_tensors = model_object.input
    else:
        list_of_input_tensors = [model_object.input]

    upsample_refl = len(list_of_input_tensors) == 2
    num_grid_rows = list_of_input_tensors[0].get_shape().as_list()[1]
    num_grid_columns = list_of_input_tensors[0].get_shape().as_list()[2]

    if upsample_refl:
        num_grid_rows = int(numpy.round(num_grid_rows / 2))
        num_grid_columns = int(numpy.round(num_grid_columns / 2))

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.RADAR_FIELDS_KEY:
            input_examples.AZIMUTHAL_SHEAR_FIELD_NAMES,
        trainval_io.RADAR_HEIGHTS_KEY: REFLECTIVITY_HEIGHTS_M_AGL,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y,
        trainval_io.UPSAMPLE_REFLECTIVITY_KEY: upsample_refl
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(
        pickle_file_name=model_metafile_name, metadata_dict=metadata_dict,
        training_option_dict=training_option_dict)

    cnn.train_cnn_2d3d_myrorss(
        model_object=model_object, model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name, num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec,
        num_examples_per_validn_batch=num_examples_per_validn_batch)
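# The argument-processing block above converts command-line sentinels into the
# values expected by the training generator: a single zero translation and
# rotation, or a non-positive noising count, means "no augmentation", and fewer
# than two downsampling classes means "no downsampling".  The stand-alone
# sketch below demonstrates the same pattern on dummy values.  The function and
# variable names here are illustrative only, not part of the GewitterGefahr
# API.

import numpy


def _normalize_augmentation_args(
        x_translations_pixels, y_translations_pixels, ccw_rotation_angles_deg,
        num_noisings, noise_standard_deviation):
    """Converts sentinel values for data augmentation to None."""

    if (len(x_translations_pixels) == 1 and
            x_translations_pixels[0] + y_translations_pixels[0] == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    return (x_translations_pixels, y_translations_pixels,
            ccw_rotation_angles_deg, num_noisings, noise_standard_deviation)


# With all-zero sentinels, every augmentation option comes back as None or 0.
print(_normalize_augmentation_args(
    x_translations_pixels=numpy.array([0], dtype=int),
    y_translations_pixels=numpy.array([0], dtype=int),
    ccw_rotation_angles_deg=numpy.array([0.]),
    num_noisings=0, noise_standard_deviation=1.
))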
def _find_baseline_and_test_examples(
        top_example_dir_name, first_time_string, last_time_string,
        num_baseline_examples, num_test_examples, cnn_model_object,
        cnn_metadata_dict):
    """Finds examples for baseline and test sets.

    :param top_example_dir_name: See documentation at top of file.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param num_baseline_examples: Same.
    :param num_test_examples: Same.
    :param cnn_model_object: Trained CNN (instance of `keras.models.Model`).
    :param cnn_metadata_dict: Dictionary with CNN metadata, in the format
        returned by `traditional_cnn.read_model_metadata`.
    :return: baseline_image_matrix: B-by-M-by-N-by-C numpy array of baseline
        images (input examples for the CNN), where B = `num_baseline_examples`.
    :return: test_image_matrix: Same but for test images, with
        `num_test_examples` along the first axis.
    """

    first_time_unix_sec = time_conversion.string_to_unix_sec(
        first_time_string, TIME_FORMAT)
    last_time_unix_sec = time_conversion.string_to_unix_sec(
        last_time_string, TIME_FORMAT)

    example_file_names = trainval_io.find_downsized_3d_example_files(
        top_directory_name=top_example_dir_name, shuffled=False,
        first_target_time_unix_sec=first_time_unix_sec,
        last_target_time_unix_sec=last_time_unix_sec)

    file_indices = numpy.array([], dtype=int)
    file_position_indices = numpy.array([], dtype=int)
    cold_front_probabilities = numpy.array([], dtype=float)

    for k in range(len(example_file_names)):
        print 'Reading data from: "{0:s}"...'.format(example_file_names[k])

        this_example_dict = trainval_io.read_downsized_3d_examples(
            netcdf_file_name=example_file_names[k], metadata_only=False,
            predictor_names_to_keep=cnn_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            num_half_rows_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_half_columns_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            first_time_to_keep_unix_sec=first_time_unix_sec,
            last_time_to_keep_unix_sec=last_time_unix_sec)

        this_num_examples = len(
            this_example_dict[trainval_io.TARGET_TIMES_KEY])
        if this_num_examples == 0:
            continue

        these_file_indices = numpy.full(this_num_examples, k, dtype=int)
        these_position_indices = numpy.linspace(
            0, this_num_examples - 1, num=this_num_examples, dtype=int)

        these_cold_front_probs = _get_cnn_predictions(
            cnn_model_object=cnn_model_object,
            predictor_matrix=this_example_dict[
                trainval_io.PREDICTOR_MATRIX_KEY],
            target_class=front_utils.COLD_FRONT_INTEGER_ID, verbose=True)

        print '\n'

        file_indices = numpy.concatenate((file_indices, these_file_indices))
        file_position_indices = numpy.concatenate(
            (file_position_indices, these_position_indices))
        cold_front_probabilities = numpy.concatenate(
            (cold_front_probabilities, these_cold_front_probs))

    print SEPARATOR_STRING

    # Find test set.
    test_indices = numpy.argsort(
        -1 * cold_front_probabilities)[:num_test_examples]
    file_indices_for_test = file_indices[test_indices]
    file_position_indices_for_test = file_position_indices[test_indices]

    print 'Cold-front probabilities for the {0:d} test examples are:'.format(
        num_test_examples)
    for i in test_indices:
        print cold_front_probabilities[i]
    print SEPARATOR_STRING

    # Find baseline set.
    baseline_indices = numpy.linspace(
        0, num_baseline_examples - 1, num=num_baseline_examples, dtype=int)
    baseline_indices = (
        set(baseline_indices.tolist()) - set(test_indices.tolist()))
    baseline_indices = numpy.array(list(baseline_indices), dtype=int)
    baseline_indices = numpy.random.choice(
        baseline_indices, size=num_baseline_examples, replace=False)

    file_indices_for_baseline = file_indices[baseline_indices]
    file_position_indices_for_baseline = file_position_indices[
        baseline_indices]

    print 'Cold-front probabilities for the {0:d} baseline examples are:'.format(
        num_baseline_examples)
    for i in baseline_indices:
        print cold_front_probabilities[i]
    print SEPARATOR_STRING

    # Read test and baseline sets.
    baseline_image_matrix = None
    test_image_matrix = None

    for k in range(len(example_file_names)):
        if not (k in file_indices_for_test or k in file_indices_for_baseline):
            continue

        print 'Reading data from: "{0:s}"...'.format(example_file_names[k])

        this_example_dict = trainval_io.read_downsized_3d_examples(
            netcdf_file_name=example_file_names[k], metadata_only=False,
            predictor_names_to_keep=cnn_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            num_half_rows_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_half_columns_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            first_time_to_keep_unix_sec=first_time_unix_sec,
            last_time_to_keep_unix_sec=last_time_unix_sec)

        this_predictor_matrix = this_example_dict[
            trainval_io.PREDICTOR_MATRIX_KEY]

        if baseline_image_matrix is None:
            baseline_image_matrix = numpy.full(
                (num_baseline_examples,) + this_predictor_matrix.shape[1:],
                numpy.nan)
            test_image_matrix = numpy.full(
                (num_test_examples,) + this_predictor_matrix.shape[1:],
                numpy.nan)

        these_baseline_indices = numpy.where(file_indices_for_baseline == k)[0]
        if len(these_baseline_indices) > 0:
            baseline_image_matrix[these_baseline_indices, ...] = (
                this_predictor_matrix[
                    file_position_indices_for_baseline[these_baseline_indices],
                    ...]
            )

        these_test_indices = numpy.where(file_indices_for_test == k)[0]
        if len(these_test_indices) > 0:
            test_image_matrix[these_test_indices, ...] = (
                this_predictor_matrix[
                    file_position_indices_for_test[these_test_indices], ...]
            )

    return baseline_image_matrix, test_image_matrix
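# The selection above can be summarized as: the test set takes the examples
# with the highest CNN-predicted cold-front probabilities, and the baseline set
# is a random sample of examples disjoint from the test set.  The
# self-contained sketch below demonstrates that idea on synthetic
# probabilities.  It is a simplified illustration, not a line-for-line copy of
# the logic above; here the baseline candidates are simply all non-test
# examples.

import numpy

NUM_TEST_EXAMPLES = 3
NUM_BASELINE_EXAMPLES = 4

numpy.random.seed(6695)
cold_front_probabilities = numpy.random.uniform(size=20)

# Test set = examples with the highest cold-front probabilities.
test_indices = numpy.argsort(
    -1 * cold_front_probabilities)[:NUM_TEST_EXAMPLES]

# Baseline set = random sample of the remaining examples.
candidate_indices = numpy.setdiff1d(
    numpy.arange(len(cold_front_probabilities)), test_indices)
baseline_indices = numpy.random.choice(
    candidate_indices, size=NUM_BASELINE_EXAMPLES, replace=False)

print('Test indices: {0:s}'.format(str(test_indices)))
print('Baseline indices: {0:s}'.format(str(baseline_indices)))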
import unittest

import numpy

from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.deep_learning import prediction_io

# The following constants are used to test subset_ungridded_predictions.
TARGET_NAME = 'foo'
THESE_ID_STRINGS = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L'
]
THESE_TIME_STRINGS = [
    '4001-01-01-01', '4002-02-02-02', '4003-03-03-03', '4004-04-04-04',
    '4005-05-05-05', '4006-06-06-06', '4007-07-07-07', '4008-08-08-08',
    '4009-09-09-09', '4010-10-10-10', '4011-11-11-11', '4012-12-12-12'
]
THESE_TIMES_UNIX_SEC = [
    time_conversion.string_to_unix_sec(t, '%Y-%m-%d-%H')
    for t in THESE_TIME_STRINGS
]

THIS_PROBABILITY_MATRIX = numpy.array([
    [1, 0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5],
    [0.4, 0.6], [0.3, 0.7], [0.2, 0.8], [0.1, 0.9], [0, 1], [0.75, 0.25]
])
THESE_OBSERVED_LABELS = numpy.array(
    [0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], dtype=int)

FULL_PREDICTION_DICT_SANS_OBS = {
    prediction_io.TARGET_NAME_KEY: TARGET_NAME,
    prediction_io.STORM_IDS_KEY: THESE_ID_STRINGS,
    prediction_io.STORM_TIMES_KEY: THESE_TIMES_UNIX_SEC,
    prediction_io.PROBABILITY_MATRIX_KEY: THIS_PROBABILITY_MATRIX,
    prediction_io.OBSERVED_LABELS_KEY: None
}
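# For context, the constants above exercise subsetting of an ungridded-
# prediction dictionary.  The operation being tested -- keeping only selected
# storm objects -- can be sketched as below.  This helper is purely
# illustrative (generic string keys, not the real `prediction_io` keys, and not
# the real `subset_ungridded_predictions` signature).

import numpy


def _subset_prediction_dict(prediction_dict, desired_indices):
    """Keeps only the storm objects at `desired_indices` (illustrative)."""

    small_dict = dict(prediction_dict)
    small_dict['storm_ids'] = [
        prediction_dict['storm_ids'][i] for i in desired_indices
    ]
    small_dict['storm_times_unix_sec'] = [
        prediction_dict['storm_times_unix_sec'][i] for i in desired_indices
    ]
    small_dict['probability_matrix'] = (
        prediction_dict['probability_matrix'][desired_indices, :]
    )

    if prediction_dict['observed_labels'] is not None:
        small_dict['observed_labels'] = (
            prediction_dict['observed_labels'][desired_indices]
        )

    return small_dict


THIS_SMALL_DICT = _subset_prediction_dict(
    prediction_dict={
        'storm_ids': ['A', 'B', 'C'],
        'storm_times_unix_sec': [0, 1, 2],
        'probability_matrix': numpy.array([[1, 0], [0.5, 0.5], [0, 1]]),
        'observed_labels': None
    },
    desired_indices=numpy.array([0, 2], dtype=int)
)
print(THIS_SMALL_DICT)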
def _run(model_file_name, first_eval_time_string, last_eval_time_string,
         num_times, num_examples_per_time, dilation_distance_metres,
         use_isotonic_regression, top_narr_directory_name,
         top_frontal_grid_dir_name, output_dir_name):
    """Evaluates CNN trained by patch classification.

    This is effectively the main method.

    :param model_file_name: See documentation at top of file.
    :param first_eval_time_string: Same.
    :param last_eval_time_string: Same.
    :param num_times: Same.
    :param num_examples_per_time: Same.
    :param dilation_distance_metres: Same.
    :param use_isotonic_regression: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param output_dir_name: Same.
    """

    first_eval_time_unix_sec = time_conversion.string_to_unix_sec(
        first_eval_time_string, INPUT_TIME_FORMAT)
    last_eval_time_unix_sec = time_conversion.string_to_unix_sec(
        last_eval_time_string, INPUT_TIME_FORMAT)

    print 'Reading model from: "{0:s}"...'.format(model_file_name)
    model_object = traditional_cnn.read_keras_model(model_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=model_file_name, raise_error_if_missing=True)

    print 'Reading model metadata from: "{0:s}"...'.format(model_metafile_name)
    model_metadata_dict = traditional_cnn.read_model_metadata(
        model_metafile_name)

    if dilation_distance_metres < 0:
        dilation_distance_metres = model_metadata_dict[
            traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY] + 0.

    if use_isotonic_regression:
        isotonic_file_name = isotonic_regression.find_model_file(
            base_model_file_name=model_file_name, raise_error_if_missing=True)

        print 'Reading isotonic-regression models from: "{0:s}"...'.format(
            isotonic_file_name)
        isotonic_model_object_by_class = (
            isotonic_regression.read_model_for_each_class(isotonic_file_name))
    else:
        isotonic_model_object_by_class = None

    num_classes = len(model_metadata_dict[traditional_cnn.CLASS_FRACTIONS_KEY])
    print SEPARATOR_STRING

    class_probability_matrix, observed_labels = (
        eval_utils.downsized_examples_to_eval_pairs(
            model_object=model_object,
            first_target_time_unix_sec=first_eval_time_unix_sec,
            last_target_time_unix_sec=last_eval_time_unix_sec,
            num_target_times_to_sample=num_times,
            num_examples_per_time=num_examples_per_time,
            top_narr_directory_name=top_narr_directory_name,
            top_frontal_grid_dir_name=top_frontal_grid_dir_name,
            narr_predictor_names=model_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            pressure_level_mb=model_metadata_dict[
                traditional_cnn.PRESSURE_LEVEL_KEY],
            dilation_distance_metres=dilation_distance_metres,
            num_rows_in_half_grid=model_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_columns_in_half_grid=model_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            num_classes=num_classes,
            predictor_time_step_offsets=model_metadata_dict[
                traditional_cnn.PREDICTOR_TIME_STEP_OFFSETS_KEY],
            num_lead_time_steps=model_metadata_dict[
                traditional_cnn.NUM_LEAD_TIME_STEPS_KEY],
            isotonic_model_object_by_class=isotonic_model_object_by_class,
            narr_mask_matrix=model_metadata_dict[
                traditional_cnn.NARR_MASK_MATRIX_KEY]))

    print SEPARATOR_STRING

    model_eval_helper.run_evaluation(
        class_probability_matrix=class_probability_matrix,
        observed_labels=observed_labels, output_dir_name=output_dir_name)
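# `eval_utils.downsized_examples_to_eval_pairs` returns a class-probability
# matrix (one row per example, one column per class) plus the observed labels,
# which are then handed to `model_eval_helper.run_evaluation`.  The sketch
# below shows, on toy values, the kind of summary statistics such a pair
# supports; it is illustrative only, not the library's evaluation code.

import numpy

CLASS_PROBABILITY_MATRIX = numpy.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1]
])
OBSERVED_LABELS = numpy.array([0, 1, 2, 1], dtype=int)

# Accuracy: fraction of examples whose highest-probability class is correct.
predicted_labels = numpy.argmax(CLASS_PROBABILITY_MATRIX, axis=1)
accuracy = numpy.mean(predicted_labels == OBSERVED_LABELS)

# Cross-entropy: mean negative log-probability assigned to the observed class.
these_probabilities = CLASS_PROBABILITY_MATRIX[
    numpy.arange(len(OBSERVED_LABELS)), OBSERVED_LABELS
]
cross_entropy = -numpy.mean(numpy.log(these_probabilities))

print('Accuracy = {0:.3f} ... cross-entropy = {1:.3f}'.format(
    accuracy, cross_entropy))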