def _create_best_tracks(start_time_string, end_time_string, top_input_dir_name,
                        data_source, top_output_dir_name,
                        tracking_scale_metres2):
    """Runs the best-track algorithm with default parameters.

    :param start_time_string: See documentation at top of file.
    :param end_time_string: Same.
    :param top_input_dir_name: Same.
    :param data_source: Same.
    :param top_output_dir_name: Same.
    :param tracking_scale_metres2: Same.
    """

    tracking_utils.check_data_source(data_source)

    start_time_unix_sec = time_conversion.string_to_unix_sec(
        start_time_string, INPUT_TIME_FORMAT)
    end_time_unix_sec = time_conversion.string_to_unix_sec(
        end_time_string, INPUT_TIME_FORMAT)
    first_date_string = time_conversion.time_to_spc_date_string(
        start_time_unix_sec)
    last_date_string = time_conversion.time_to_spc_date_string(
        end_time_unix_sec)

    file_dictionary = best_tracks_smart_io.find_files_for_smart_io(
        start_time_unix_sec=start_time_unix_sec,
        start_spc_date_string=first_date_string,
        end_time_unix_sec=end_time_unix_sec,
        end_spc_date_string=last_date_string,
        data_source=data_source,
        tracking_scale_metres2=tracking_scale_metres2,
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name)

    best_tracks_smart_io.run_best_track(smart_file_dict=file_dictionary)
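
# Illustrative sketch (not from the original source): SPC dates are "yyyymmdd"
# strings (see the help string in a later example).  The stand-in below simply
# takes the UTC calendar date and ignores any special SPC-date cutoff hour, so
# the real time_to_spc_date_string may differ in that detail.
import time


def _time_to_date_string_sketch(unix_time_sec):
    """Converts Unix time to a "yyyymmdd" date string (UTC calendar date)."""
    return time.strftime('%Y%m%d', time.gmtime(unix_time_sec))


print(_time_to_date_string_sketch(1294266600))  # '20110105'
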
Example #2
def _download_rap_analyses(first_init_time_string, last_init_time_string,
                           top_local_directory_name):
    """Downloads zero-hour analyses from the RAP (Rapid Refresh) model.

    :param first_init_time_string: See documentation at top of file.
    :param last_init_time_string: Same.
    :param top_local_directory_name: Same.
    """

    first_init_time_unix_sec = time_conversion.string_to_unix_sec(
        first_init_time_string, INPUT_TIME_FORMAT)
    last_init_time_unix_sec = time_conversion.string_to_unix_sec(
        last_init_time_string, INPUT_TIME_FORMAT)
    time_interval_sec = HOURS_TO_SECONDS * nwp_model_utils.get_time_steps(
        nwp_model_utils.RAP_MODEL_NAME)[1]

    init_times_unix_sec = time_periods.range_and_interval_to_list(
        start_time_unix_sec=first_init_time_unix_sec,
        end_time_unix_sec=last_init_time_unix_sec,
        time_interval_sec=time_interval_sec)
    init_time_strings = [
        time_conversion.unix_sec_to_string(t, DEFAULT_TIME_FORMAT)
        for t in init_times_unix_sec]

    num_init_times = len(init_times_unix_sec)
    local_file_names = [None] * num_init_times

    for i in range(num_init_times):
        local_file_names[i] = nwp_model_io.find_rap_file_any_grid(
            top_directory_name=top_local_directory_name,
            init_time_unix_sec=init_times_unix_sec[i], lead_time_hours=0,
            raise_error_if_missing=False)
        if local_file_names[i] is not None:
            continue

        local_file_names[i] = nwp_model_io.download_rap_file_any_grid(
            top_local_directory_name=top_local_directory_name,
            init_time_unix_sec=init_times_unix_sec[i], lead_time_hours=0,
            raise_error_if_fails=False)

        if local_file_names[i] is None:
            print('\nPROBLEM.  Download failed for {0:s}.\n\n'.format(
                init_time_strings[i]))
        else:
            print('\nSUCCESS.  File was downloaded to "{0:s}".\n\n'.format(
                local_file_names[i]))

        time.sleep(SECONDS_TO_PAUSE_BETWEEN_FILES)

    num_downloaded = numpy.sum(numpy.array(
        [f is not None for f in local_file_names]))
    print('{0:d} of {1:d} files were downloaded successfully!'.format(
        num_downloaded, num_init_times))
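
# Illustrative sketch (not from the original source): a minimal stand-in for
# time_periods.range_and_interval_to_list as used above, assuming it returns
# evenly spaced times from start to end, inclusive.  The real function also
# supports options (e.g., include_endpoint) not shown here.
import numpy


def _range_and_interval_to_list_sketch(start_time_unix_sec, end_time_unix_sec,
                                       time_interval_sec):
    """Returns evenly spaced times (Unix sec) from start to end, inclusive."""
    return numpy.arange(
        start_time_unix_sec, end_time_unix_sec + 1, time_interval_sec,
        dtype=int)


# Three-hourly init times over a 9-hour window -> [0 10800 21600 32400].
print(_range_and_interval_to_list_sketch(0, 32400, 10800))
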
Example #3
def _run(top_input_dir_name, top_output_dir_name, first_spc_date_string,
         last_spc_date_string, first_time_string, last_time_string,
         radar_source_name, for_storm_climatology):
    """Reanalyzes storm tracks across many SPC dates.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param top_output_dir_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param radar_source_name: Same.
    :param for_storm_climatology: Same.
    """

    if (for_storm_climatology
            and radar_source_name == radar_utils.MYRORSS_SOURCE_ID):

        myrorss_start_time_unix_sec = time_conversion.string_to_unix_sec(
            MYRORSS_START_TIME_STRING, TIME_FORMAT)
        myrorss_end_time_unix_sec = time_conversion.string_to_unix_sec(
            MYRORSS_END_TIME_STRING, TIME_FORMAT)

        tracking_start_time_unix_sec = myrorss_start_time_unix_sec + 0
        tracking_end_time_unix_sec = myrorss_end_time_unix_sec + 0
    else:
        tracking_start_time_unix_sec = None
        tracking_end_time_unix_sec = None

    if first_time_string in ['', 'None'] or last_time_string in ['', 'None']:
        first_time_unix_sec = None
        last_time_unix_sec = None
    else:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

    echo_top_tracking.reanalyze_tracks_across_spc_dates(
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name,
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string,
        first_time_unix_sec=first_time_unix_sec,
        last_time_unix_sec=last_time_unix_sec,
        tracking_start_time_unix_sec=tracking_start_time_unix_sec,
        tracking_end_time_unix_sec=tracking_end_time_unix_sec,
        min_track_duration_seconds=0)
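
# Illustrative sketch (not from the original source): the '' / 'None' sentinel
# handling above lets a command-line argument mean "no time restriction".
def _parse_optional_time_string(time_string):
    """Returns None for sentinel values, else the string unchanged."""
    if time_string in ['', 'None']:
        return None
    return time_string  # the script would then convert this to Unix seconds


print(_parse_optional_time_string('None'))           # None
print(_parse_optional_time_string('2011-04-05-21'))  # '2011-04-05-21'
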
def _convert_files(top_probsevere_dir_name, date_string, top_output_dir_name):
    """Converts probSevere tracking files for one day.

    :param top_probsevere_dir_name: See documentation at top of file.
    :param date_string: Same.
    :param top_output_dir_name: Same.
    """

    date_unix_sec = time_conversion.string_to_unix_sec(date_string,
                                                       DATE_FORMAT)
    raw_file_names = probsevere_io.find_raw_files_one_day(
        top_directory_name=top_probsevere_dir_name,
        unix_time_sec=date_unix_sec,
        file_extension=probsevere_io.ASCII_FILE_EXTENSION,
        raise_error_if_all_missing=True)

    for this_raw_file_name in raw_file_names:
        print('Reading data from "{0:s}"...'.format(this_raw_file_name))
        this_storm_object_table = probsevere_io.read_raw_file(
            this_raw_file_name)

        this_time_unix_sec = probsevere_io.raw_file_name_to_time(
            this_raw_file_name)
        this_new_file_name = tracking_io.find_processed_file(
            unix_time_sec=this_time_unix_sec,
            data_source=tracking_utils.PROBSEVERE_SOURCE_ID,
            top_processed_dir_name=top_output_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            raise_error_if_missing=False)

        print('Writing data to "{0:s}"...'.format(this_new_file_name))
        tracking_io.write_processed_file(
            storm_object_table=this_storm_object_table,
            pickle_file_name=this_new_file_name)
def _time_string_to_unix_sec(time_string):
    """Converts time from string to Unix format.

    :param time_string: Time string (format "dd-mmm-yy HH:MM:SS").
    :return: unix_time_sec: Time in Unix format.
    """

    time_string = _capitalize_months(time_string)
    return time_conversion.string_to_unix_sec(time_string, TIME_FORMAT)
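
# Illustrative sketch (not from the original source): a plain-stdlib equivalent
# of the conversion above for the "dd-mmm-yy HH:MM:SS" format in the docstring.
# TIME_FORMAT is not shown in this excerpt, so '%d-%b-%y %H:%M:%S' is an
# assumption.
import calendar
import time


def _string_to_unix_sec_sketch(time_string, time_format='%d-%b-%y %H:%M:%S'):
    """Converts a UTC time string to seconds since the Unix epoch."""
    return calendar.timegm(time.strptime(time_string, time_format))


print(_string_to_unix_sec_sketch('05-Jan-11 22:30:00'))  # 1294266600
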
Example #6
def _run(top_input_dir_name, first_spc_date_string, last_spc_date_string,
         first_time_string, last_time_string, max_velocity_diff_m_s01,
         max_link_distance_m_s01, max_join_time_seconds, max_join_error_m_s01,
         min_duration_seconds, top_output_dir_name):
    """Reanalyzes storm tracks (preferably over many SPC dates).

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param max_velocity_diff_m_s01: Same.
    :param max_link_distance_m_s01: Same.
    :param max_join_time_seconds: Same.
    :param max_join_error_m_s01: Same.
    :param min_duration_seconds: Same.
    :param top_output_dir_name: Same.
    """

    if first_time_string in ['', 'None'] or last_time_string in ['', 'None']:
        first_time_unix_sec = None
        last_time_unix_sec = None
    else:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

    echo_top_tracking.reanalyze_across_spc_dates(
        top_input_dir_name=top_input_dir_name,
        top_output_dir_name=top_output_dir_name,
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string,
        first_time_unix_sec=first_time_unix_sec,
        last_time_unix_sec=last_time_unix_sec,
        max_velocity_diff_m_s01=max_velocity_diff_m_s01,
        max_link_distance_m_s01=max_link_distance_m_s01,
        max_join_time_seconds=max_join_time_seconds,
        max_join_error_m_s01=max_join_error_m_s01,
        min_track_duration_seconds=min_duration_seconds)
def _date_string_to_unix_sec(date_string):
    """Converts date from string to Unix format.

    :param date_string: String in format "xx yyyy mm dd HH MM SS", where the xx
        is meaningless.
    :return: unix_date_sec: Date in Unix format.
    """

    words = date_string.split()
    date_string = words[1] + '-' + words[2] + '-' + words[3]
    return time_conversion.string_to_unix_sec(date_string, TIME_FORMAT_DATE)
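
# Illustrative walk-through (not from the original source): the leading "xx"
# token is discarded and words 1-3 are rejoined as "yyyy-mm-dd", so
# TIME_FORMAT_DATE is presumably '%Y-%m-%d'.
words = 'XX 2011 04 05 21 00 00'.split()
print(words[1] + '-' + words[2] + '-' + words[3])  # '2011-04-05'
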
def _local_time_string_to_unix_sec(local_time_string, utc_offset_hours):
    """Converts time from local string to Unix format.

    :param local_time_string: Local time (format "yyyymmddHHMM").
    :param utc_offset_hours: Local time minus UTC.
    :return: unix_time_sec: Time in Unix format.
    """

    return time_conversion.string_to_unix_sec(
        local_time_string,
        TIME_FORMAT_HOUR_MINUTE) - (utc_offset_hours * HOURS_TO_SECONDS)
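
# Illustrative sketch (not from the original source): a stdlib version of the
# conversion above, assuming TIME_FORMAT_HOUR_MINUTE is '%Y%m%d%H%M' (matching
# the "yyyymmddHHMM" format in the docstring).  Since local time = UTC +
# offset, subtracting the offset recovers UTC.
import calendar
import time

HOURS_TO_SECONDS = 3600


def _local_time_string_to_unix_sec_sketch(local_time_string, utc_offset_hours):
    """Converts local-time string to Unix seconds (UTC)."""
    local_as_utc_sec = calendar.timegm(
        time.strptime(local_time_string, '%Y%m%d%H%M'))
    return local_as_utc_sec - utc_offset_hours * HOURS_TO_SECONDS


# 1230 local time at UTC-6 is 1830 UTC.
print(_local_time_string_to_unix_sec_sketch('199506181230', -6))
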
Example #9
def read_gridrad_stats_from_thea(csv_file_name):
    """Reads radar statistics created by GridRad software (file format by Thea).

    :param csv_file_name: Path to input file.
    :return: gridrad_statistic_table: pandas DataFrame with mandatory columns
        listed below.  Other column names come from the list
        `GRIDRAD_STATISTIC_NAMES`.
    gridrad_statistic_table.storm_number: Numeric ID (integer) for storm cell.
    gridrad_statistic_table.unix_time_sec: Valid time of storm object.
    """

    error_checking.assert_file_exists(csv_file_name)
    gridrad_statistic_table = pandas.read_csv(csv_file_name, header=0, sep=',')

    # Convert times from Thea's format to Unix format.
    unix_times_sec = numpy.array([
        time_conversion.string_to_unix_sec(s, GRIDRAD_TIME_FORMAT)
        for s in gridrad_statistic_table[TIME_NAME_GRIDRAD_ORIG].values
    ])
    gridrad_statistic_table = gridrad_statistic_table.assign(
        **{tracking_utils.TIME_COLUMN: unix_times_sec})

    columns_to_keep = GRIDRAD_STATISTIC_NAMES_ORIG + [
        STORM_NUMBER_NAME_GRIDRAD_ORIG, tracking_utils.TIME_COLUMN
    ]
    gridrad_statistic_table = gridrad_statistic_table[columns_to_keep]

    # Rename columns.
    column_dict_old_to_new = {
        STORM_NUMBER_NAME_GRIDRAD_ORIG: STORM_NUMBER_NAME_GRIDRAD,
        ECHO_TOP_40DBZ_NAME_GRIDRAD_ORIG: ECHO_TOP_40DBZ_NAME_GRIDRAD,
        SPECTRUM_WIDTH_NAME_GRIDRAD_ORIG: SPECTRUM_WIDTH_NAME_GRIDRAD,
        MAX_DIVERGENCE_NAME_GRIDRAD_ORIG: MAX_DIVERGENCE_NAME_GRIDRAD,
        UPPER_LEVEL_DIVERGENCE_NAME_GRIDRAD_ORIG:
        UPPER_LEVEL_DIVERGENCE_NAME_GRIDRAD,
        LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD_ORIG:
        LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD,
        DIVERGENCE_AREA_NAME_GRIDRAD_ORIG: DIVERGENCE_AREA_NAME_GRIDRAD,
        MAX_ROTATION_NAME_GRIDRAD_ORIG: MAX_ROTATION_NAME_GRIDRAD,
        UPPER_LEVEL_ROTATION_NAME_GRIDRAD_ORIG:
        UPPER_LEVEL_ROTATION_NAME_GRIDRAD,
        LOW_LEVEL_ROTATION_NAME_GRIDRAD_ORIG: LOW_LEVEL_ROTATION_NAME_GRIDRAD
    }

    gridrad_statistic_table.rename(columns=column_dict_old_to_new,
                                   inplace=True)

    # Convert units of divergence/convergence.
    gridrad_statistic_table[LOW_LEVEL_CONVERGENCE_NAME_GRIDRAD] *= -1
    for this_name in GRIDRAD_DIVERGENCE_NAMES:
        gridrad_statistic_table[
            this_name] *= CONVERSION_RATIO_FOR_GRIDRAD_DIVERGENCE

    return gridrad_statistic_table
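
# Illustrative sketch (not from the original source): a toy DataFrame showing
# the assign / rename / sign-flip pattern used above.  Column names and values
# here are invented for the example.
import pandas

toy_table = pandas.DataFrame({'Conv_lowlevel': [2.0e-3], 'storm_number': [7]})
toy_table = toy_table.assign(unix_time_sec=[1302037200])
toy_table.rename(
    columns={'Conv_lowlevel': 'low_level_convergence_s01'}, inplace=True)

# Sign flip, mirroring the convergence-column conversion above.
toy_table['low_level_convergence_s01'] *= -1
print(toy_table)
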
Example #10
def file_name_to_time(gridrad_file_name):
    """Parses valid time from name of GridRad file.

    :param gridrad_file_name: Path to GridRad file.
    :return: unix_time_sec: Valid time.
    """

    _, pathless_file_name = os.path.split(gridrad_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)
    time_string = extensionless_file_name.split('_')[-1]
    return time_conversion.string_to_unix_sec(time_string,
                                              TIME_FORMAT_IN_FILE_NAMES)
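
# Illustrative walk-through (not from the original source): how the valid time
# is pulled from a file name.  The path and time format here are invented.
import os

example_file_name = '/data/gridrad/nexrad_3d_v3_1_20110405T210000Z.nc'
pathless_file_name = os.path.split(example_file_name)[-1]
extensionless_file_name = os.path.splitext(pathless_file_name)[0]
print(extensionless_file_name.split('_')[-1])  # '20110405T210000Z'
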
Example #11
def _file_name_to_valid_time(bulletin_file_name):
    """Parses valid time from file name.

    :param bulletin_file_name: Path to input file (text file in WPC format).
    :return: valid_time_unix_sec: Valid time.
    """

    _, pathless_file_name = os.path.split(bulletin_file_name)
    valid_time_string = pathless_file_name.replace(
        PATHLESS_FILE_NAME_PREFIX + '_', '')

    return time_conversion.string_to_unix_sec(valid_time_string,
                                              TIME_FORMAT_IN_FILE_NAME)
Example #12
def processed_file_name_to_time(processed_file_name):
    """Parses time from name of processed file.

    :param processed_file_name: Name of processed file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(processed_file_name)
    _, pathless_file_name = os.path.split(processed_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)

    extensionless_file_name_parts = extensionless_file_name.split('_')
    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], TIME_FORMAT)
Example #13
def raw_file_name_to_time(raw_file_name):
    """Parses time from file name.

    :param raw_file_name: Path to raw file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(raw_file_name)

    _, time_string = os.path.split(raw_file_name)
    time_string = time_string.replace(ZIPPED_FILE_EXTENSION,
                                      '').replace(UNZIPPED_FILE_EXTENSION, '')

    return time_conversion.string_to_unix_sec(time_string, TIME_FORMAT_SECONDS)
Example #14
def file_name_to_time(tracking_file_name):
    """Parses valid time from tracking file.

    :param tracking_file_name: Path to tracking file.
    :return: valid_time_unix_sec: Valid time.
    """

    error_checking.assert_is_string(tracking_file_name)
    _, pathless_file_name = os.path.split(tracking_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)

    extensionless_file_name_parts = extensionless_file_name.split('_')

    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], FILE_NAME_TIME_FORMAT)
Example #15
def find_file(top_directory_name,
              field_name,
              month_string,
              is_surface=False,
              raise_error_if_missing=True):
    """Finds NetCDF file on the local machine.

    :param top_directory_name: Name of top-level directory with NetCDF files
        containing NARR data.
    :param field_name: See doc for `_get_pathless_file_name`.
    :param month_string: Same.
    :param is_surface: Same.
    :param raise_error_if_missing: Boolean flag.  If file is missing and
        `raise_error_if_missing = True`, this method will error out.
    :return: netcdf_file_name: Path to NetCDF file.  If file is missing and
        `raise_error_if_missing = False`, will return the *expected* path.
    :raises: ValueError: if file is missing and `raise_error_if_missing = True`.
    """

    error_checking.assert_is_string(top_directory_name)
    time_conversion.string_to_unix_sec(month_string, TIME_FORMAT_MONTH)
    error_checking.assert_is_boolean(is_surface)
    error_checking.assert_is_boolean(raise_error_if_missing)

    pathless_file_name = _get_pathless_file_name(field_name=field_name,
                                                 month_string=month_string,
                                                 is_surface=is_surface)
    netcdf_file_name = '{0:s}/{1:s}'.format(top_directory_name,
                                            pathless_file_name)

    if raise_error_if_missing and not os.path.isfile(netcdf_file_name):
        error_string = 'Cannot find file.  Expected at: "{0:s}"'.format(
            netcdf_file_name)
        raise ValueError(error_string)

    return netcdf_file_name
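
# Illustrative sketch (not from the original source): the find-vs-expected-path
# pattern above in miniature.  Directory and file name are invented.
import os


def _find_file_sketch(top_directory_name, pathless_file_name,
                      raise_error_if_missing=True):
    """Returns file path; optionally errors out if the file does not exist."""
    netcdf_file_name = '{0:s}/{1:s}'.format(
        top_directory_name, pathless_file_name)

    if raise_error_if_missing and not os.path.isfile(netcdf_file_name):
        raise ValueError(
            'Cannot find file.  Expected at: "{0:s}"'.format(netcdf_file_name))

    return netcdf_file_name


print(_find_file_sketch('/tmp', 'air.201102.nc', raise_error_if_missing=False))
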
def raw_file_name_to_time(raw_file_name):
    """Parses valid time from name of raw (either ASCII or JSON) file.

    :param raw_file_name: Path to raw file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(raw_file_name)
    _, pathless_file_name = os.path.split(raw_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)

    time_string = extensionless_file_name.replace(RAW_FILE_NAME_PREFIX + '_',
                                                  '')
    time_string = time_string.replace(ALT_RAW_FILE_NAME_PREFIX + '_', '')
    return time_conversion.string_to_unix_sec(time_string,
                                              RAW_FILE_TIME_FORMAT)
def _get_num_days_in_month(month, year):
    """Returns number of days in month.

    :param month: Month (integer from 1...12).
    :param year: Year (integer).
    :return: num_days_in_month: Number of days in month.
    """

    time_string = '{0:04d}-{1:02d}'.format(year, month)
    unix_time_sec = time_conversion.string_to_unix_sec(time_string, '%Y-%m')
    (_, last_time_in_month_unix_sec
     ) = time_conversion.first_and_last_times_in_month(unix_time_sec)

    day_of_month_string = time_conversion.unix_sec_to_string(
        last_time_in_month_unix_sec, '%d')
    return int(day_of_month_string)
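
# Illustrative note (not from the original source): the standard library gives
# the same answer directly, since calendar.monthrange returns
# (weekday_of_first_day, number_of_days_in_month).
import calendar

print(calendar.monthrange(2016, 2)[1])  # 29 (leap year)
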
def processed_file_name_to_time(processed_file_name):
    """Parses time from name of processed tracking file.

    This file should contain storm objects (bounding polygons) and tracking
    statistics for one time step and one tracking scale.

    :param processed_file_name: Path to processed tracking file.
    :return: unix_time_sec: Valid time.
    """

    error_checking.assert_is_string(processed_file_name)
    _, pathless_file_name = os.path.split(processed_file_name)
    extensionless_file_name, _ = os.path.splitext(pathless_file_name)

    extensionless_file_name_parts = extensionless_file_name.split('_')
    return time_conversion.string_to_unix_sec(
        extensionless_file_name_parts[-1], TIME_FORMAT_IN_FILE_NAMES)
def _write_metadata_one_model(argument_dict):
    """Writes metadata for one upconvnet to file.

    :param argument_dict: See doc for `_train_one_upconvnet`.
    :return: metadata_dict: See doc for `upconvnet.write_model_metadata`.
    """

    from gewittergefahr.deep_learning import cnn
    from gewittergefahr.deep_learning import upconvnet
    from gewittergefahr.deep_learning import input_examples
    from gewittergefahr.scripts import train_upconvnet

    # Read input args.
    cnn_file_name = argument_dict[train_upconvnet.CNN_FILE_ARG_NAME]
    cnn_feature_layer_name = argument_dict[
        train_upconvnet.FEATURE_LAYER_ARG_NAME]

    top_training_dir_name = argument_dict[train_upconvnet.TRAINING_DIR_ARG_NAME]
    first_training_time_string = argument_dict[
        train_upconvnet.FIRST_TRAINING_TIME_ARG_NAME
    ]
    last_training_time_string = argument_dict[
        train_upconvnet.LAST_TRAINING_TIME_ARG_NAME]

    top_validation_dir_name = argument_dict[
        train_upconvnet.VALIDATION_DIR_ARG_NAME
    ]
    first_validation_time_string = argument_dict[
        train_upconvnet.FIRST_VALIDATION_TIME_ARG_NAME
    ]
    last_validation_time_string = argument_dict[
        train_upconvnet.LAST_VALIDATION_TIME_ARG_NAME]

    num_examples_per_batch = argument_dict[
        train_upconvnet.NUM_EX_PER_BATCH_ARG_NAME
    ]
    num_epochs = argument_dict[train_upconvnet.NUM_EPOCHS_ARG_NAME]
    num_training_batches_per_epoch = argument_dict[
        train_upconvnet.NUM_TRAINING_BATCHES_ARG_NAME
    ]
    num_validation_batches_per_epoch = argument_dict[
        train_upconvnet.NUM_VALIDATION_BATCHES_ARG_NAME
    ]
    output_dir_name = argument_dict[train_upconvnet.OUTPUT_DIR_ARG_NAME]

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER, raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER, raise_error_if_any_missing=False)

    # Write metadata.
    upconvnet_metafile_name = cnn.find_metafile(
        model_file_name='{0:s}/foo.h5'.format(output_dir_name),
        raise_error_if_missing=False
    )
    print('Writing upconvnet metadata to: "{0:s}"...'.format(
        upconvnet_metafile_name
    ))

    return upconvnet.write_model_metadata(
        cnn_file_name=cnn_file_name,
        cnn_feature_layer_name=cnn_feature_layer_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec,
        pickle_file_name=upconvnet_metafile_name)
from gewittergefahr.gg_utils import soundings
from gewittergefahr.gg_utils import echo_top_tracking
from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.gg_utils import number_rounding
from gewittergefahr.gg_utils import storm_tracking_utils as tracking_utils

HOURS_TO_SECONDS = 3600
STORM_TIME_FORMAT = '%Y-%m-%d-%H%M%S'
MODEL_INIT_TIME_FORMAT = '%Y-%m-%d-%H'
SEPARATOR_STRING = '\n\n' + '*' * 50 + '\n\n'

WGRIB_EXE_NAME = '/condo/swatwork/ralager/wgrib/wgrib'
WGRIB2_EXE_NAME = '/condo/swatwork/ralager/grib2/wgrib2/wgrib2'

FIRST_RAP_TIME_STRING = '2012-05-01-00'
FIRST_RAP_TIME_UNIX_SEC = time_conversion.string_to_unix_sec(
    FIRST_RAP_TIME_STRING, MODEL_INIT_TIME_FORMAT)

SPC_DATE_ARG_NAME = 'spc_date_string'
LEAD_TIMES_ARG_NAME = 'lead_times_seconds'
LAG_TIME_ARG_NAME = 'lag_time_for_convective_contamination_sec'
RUC_DIRECTORY_ARG_NAME = 'input_ruc_directory_name'
RAP_DIRECTORY_ARG_NAME = 'input_rap_directory_name'
TRACKING_DIR_ARG_NAME = 'input_tracking_dir_name'
TRACKING_SCALE_ARG_NAME = 'tracking_scale_metres2'
OUTPUT_DIR_ARG_NAME = 'output_sounding_dir_name'

SPC_DATE_HELP_STRING = (
    'SPC (Storm Prediction Center) date in format "yyyymmdd".  The RUC (Rapid '
    'Update Cycle) sounding will be interpolated to each storm object on this '
    'date, at each lead time in `{0:s}`.').format(LEAD_TIMES_ARG_NAME)
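
# Illustrative sketch (not from the original source): argument-name and
# help-string constants like those above are typically wired into an argparse
# parser roughly as follows.  The actual parser code is not part of this
# excerpt.
import argparse

parser_object = argparse.ArgumentParser()
parser_object.add_argument(
    '--' + SPC_DATE_ARG_NAME, type=str, required=True,
    help=SPC_DATE_HELP_STRING)
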
Example #21
def _run(net_type_string, training_dir_name, validation_dir_name,
         input_model_file_name, output_model_dir_name,
         use_generator_for_training, use_generator_for_validn, predictor_names,
         target_names, heights_m_agl, omit_heating_rate,
         first_training_time_string, last_training_time_string,
         first_validn_time_string, last_validn_time_string,
         normalization_file_name, predictor_norm_type_string,
         predictor_min_norm_value, predictor_max_norm_value,
         vector_target_norm_type_string, vector_target_min_norm_value,
         vector_target_max_norm_value, scalar_target_norm_type_string,
         scalar_target_min_norm_value, scalar_target_max_norm_value,
         num_examples_per_batch, num_epochs, num_training_batches_per_epoch,
         num_validn_batches_per_epoch, plateau_lr_multiplier):
    """Trains neural net

    :param net_type_string: See documentation at top of training_args.py.
    :param training_dir_name: Same.
    :param validation_dir_name: Same.
    :param input_model_file_name: Same.
    :param output_model_dir_name: Same.
    :param use_generator_for_training: Same.
    :param use_generator_for_validn: Same.
    :param predictor_names: Same.
    :param target_names: Same.
    :param heights_m_agl: Same.
    :param omit_heating_rate: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param first_validn_time_string: Same.
    :param last_validn_time_string: Same.
    :param normalization_file_name: Same.
    :param predictor_norm_type_string: Same.
    :param predictor_min_norm_value: Same.
    :param predictor_max_norm_value: Same.
    :param vector_target_norm_type_string: Same.
    :param vector_target_min_norm_value: Same.
    :param vector_target_max_norm_value: Same.
    :param scalar_target_norm_type_string: Same.
    :param scalar_target_min_norm_value: Same.
    :param scalar_target_max_norm_value: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validn_batches_per_epoch: Same.
    :param plateau_lr_multiplier: Same.
    """

    if predictor_norm_type_string in NONE_STRINGS:
        predictor_norm_type_string = None
    if vector_target_norm_type_string in NONE_STRINGS:
        vector_target_norm_type_string = None
    if scalar_target_norm_type_string in NONE_STRINGS:
        scalar_target_norm_type_string = None

    neural_net.check_net_type(net_type_string)

    if len(heights_m_agl) and heights_m_agl[0] <= 0:
        heights_m_agl = (
            training_args.NET_TYPE_TO_DEFAULT_HEIGHTS_M_AGL[net_type_string])

    for n in predictor_names:
        example_utils.check_field_name(n)

    scalar_predictor_names = [
        n for n in predictor_names
        if n in example_utils.ALL_SCALAR_PREDICTOR_NAMES
    ]
    vector_predictor_names = [
        n for n in predictor_names
        if n in example_utils.ALL_VECTOR_PREDICTOR_NAMES
    ]

    for n in target_names:
        example_utils.check_field_name(n)

    scalar_target_names = [
        n for n in target_names if n in example_utils.ALL_SCALAR_TARGET_NAMES
    ]
    vector_target_names = [
        n for n in target_names if n in example_utils.ALL_VECTOR_TARGET_NAMES
    ]

    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, training_args.TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, training_args.TIME_FORMAT)
    first_validn_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validn_time_string, training_args.TIME_FORMAT)
    last_validn_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validn_time_string, training_args.TIME_FORMAT)

    training_option_dict = {
        neural_net.EXAMPLE_DIRECTORY_KEY: training_dir_name,
        neural_net.BATCH_SIZE_KEY: num_examples_per_batch,
        neural_net.SCALAR_PREDICTOR_NAMES_KEY: scalar_predictor_names,
        neural_net.VECTOR_PREDICTOR_NAMES_KEY: vector_predictor_names,
        neural_net.SCALAR_TARGET_NAMES_KEY: scalar_target_names,
        neural_net.VECTOR_TARGET_NAMES_KEY: vector_target_names,
        neural_net.HEIGHTS_KEY: heights_m_agl,
        neural_net.OMIT_HEATING_RATE_KEY: omit_heating_rate,
        neural_net.NORMALIZATION_FILE_KEY: normalization_file_name,
        neural_net.PREDICTOR_NORM_TYPE_KEY: predictor_norm_type_string,
        neural_net.PREDICTOR_MIN_NORM_VALUE_KEY: predictor_min_norm_value,
        neural_net.PREDICTOR_MAX_NORM_VALUE_KEY: predictor_max_norm_value,
        neural_net.VECTOR_TARGET_NORM_TYPE_KEY: vector_target_norm_type_string,
        neural_net.VECTOR_TARGET_MIN_VALUE_KEY: vector_target_min_norm_value,
        neural_net.VECTOR_TARGET_MAX_VALUE_KEY: vector_target_max_norm_value,
        neural_net.SCALAR_TARGET_NORM_TYPE_KEY: scalar_target_norm_type_string,
        neural_net.SCALAR_TARGET_MIN_VALUE_KEY: scalar_target_min_norm_value,
        neural_net.SCALAR_TARGET_MAX_VALUE_KEY: scalar_target_max_norm_value,
        neural_net.FIRST_TIME_KEY: first_training_time_unix_sec,
        neural_net.LAST_TIME_KEY: last_training_time_unix_sec,
        # neural_net.MIN_COLUMN_LWP_KEY: 0.05,
        # neural_net.MAX_COLUMN_LWP_KEY: 1e12
    }

    validation_option_dict = {
        neural_net.EXAMPLE_DIRECTORY_KEY: validation_dir_name,
        neural_net.BATCH_SIZE_KEY: num_examples_per_batch,
        neural_net.FIRST_TIME_KEY: first_validn_time_unix_sec,
        neural_net.LAST_TIME_KEY: last_validn_time_unix_sec
    }

    if input_model_file_name in NONE_STRINGS:
        model_object = u_net_architecture.create_model(
            option_dict=DEFAULT_ARCHITECTURE_OPTION_DICT,
            vector_loss_function=DEFAULT_VECTOR_LOSS_FUNCTION,
            scalar_loss_function=DEFAULT_SCALAR_LOSS_FUNCTION,
            num_output_channels=1)

        loss_function_or_dict = {
            'conv_output': DEFAULT_VECTOR_LOSS_FUNCTION,
            'dense_output': DEFAULT_SCALAR_LOSS_FUNCTION
        }
    else:
        print('Reading untrained model from: "{0:s}"...'.format(
            input_model_file_name))
        model_object = neural_net.read_model(input_model_file_name)

        input_metafile_name = neural_net.find_metafile(
            model_dir_name=os.path.split(input_model_file_name)[0])

        print('Reading loss function(s) from: "{0:s}"...'.format(
            input_metafile_name))
        loss_function_or_dict = neural_net.read_metafile(input_metafile_name)[
            neural_net.LOSS_FUNCTION_OR_DICT_KEY]

    print(SEPARATOR_STRING)

    if use_generator_for_training:
        neural_net.train_model_with_generator(
            model_object=model_object,
            output_dir_name=output_model_dir_name,
            num_epochs=num_epochs,
            num_training_batches_per_epoch=num_training_batches_per_epoch,
            training_option_dict=training_option_dict,
            use_generator_for_validn=use_generator_for_validn,
            num_validation_batches_per_epoch=num_validn_batches_per_epoch,
            validation_option_dict=validation_option_dict,
            net_type_string=net_type_string,
            loss_function_or_dict=loss_function_or_dict,
            do_early_stopping=True,
            plateau_lr_multiplier=plateau_lr_multiplier)
    else:
        neural_net.train_model_sans_generator(
            model_object=model_object,
            output_dir_name=output_model_dir_name,
            num_epochs=num_epochs,
            training_option_dict=training_option_dict,
            validation_option_dict=validation_option_dict,
            net_type_string=net_type_string,
            loss_function_or_dict=loss_function_or_dict,
            do_early_stopping=True,
            plateau_lr_multiplier=plateau_lr_multiplier)
Example #22
Since the NARR is a reanalysis, valid time = initialization time always.  In
other words, all "forecasts" are zero-hour forecasts (analyses).
"""

import os.path
import numpy
from gewittergefahr.gg_io import netcdf_io
from gewittergefahr.gg_io import downloads
from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.gg_utils import error_checking
from generalexam.ge_io import processed_narr_io

SENTINEL_VALUE = -9e36
HOURS_TO_SECONDS = 3600
NARR_ZERO_TIME_UNIX_SEC = time_conversion.string_to_unix_sec(
    '1800-01-01-00', '%Y-%m-%d-%H')

TIME_FORMAT_MONTH = '%Y%m'
NETCDF_FILE_EXTENSION = '.nc'

ONLINE_SURFACE_DIR_NAME = 'ftp://ftp.cdc.noaa.gov/Datasets/NARR/monolevel'
ONLINE_PRESSURE_LEVEL_DIR_NAME = 'ftp://ftp.cdc.noaa.gov/Datasets/NARR/pressure'

TEMPERATURE_NAME_NETCDF = 'air'
HEIGHT_NAME_NETCDF = 'hgt'
VERTICAL_VELOCITY_NAME_NETCDF = 'omega'
SPECIFIC_HUMIDITY_NAME_NETCDF = 'shum'
U_WIND_NAME_NETCDF = 'uwnd'
V_WIND_NAME_NETCDF = 'vwnd'

VALID_FIELD_NAMES_NETCDF = [
Example #23
def _write_metadata_one_cnn(model_object, argument_dict):
    """Writes metadata for one CNN to file.

    :param model_object: Untrained CNN (instance of `keras.models.Model` or
        `keras.models.Sequential`).
    :param argument_dict: See doc for `_train_one_cnn`.
    :return: metadata_dict: See doc for `cnn.write_model_metadata`.
    :return: training_option_dict: Same.
    """

    from gewittergefahr.deep_learning import cnn
    from gewittergefahr.deep_learning import input_examples
    from gewittergefahr.deep_learning import \
        training_validation_io as trainval_io
    from gewittergefahr.scripts import deep_learning_helper as dl_helper

    # Read input args.
    sounding_field_names = argument_dict[dl_helper.SOUNDING_FIELDS_ARG_NAME]
    radar_field_name_by_channel = argument_dict[RADAR_FIELDS_KEY]
    layer_op_name_by_channel = argument_dict[LAYER_OPERATIONS_KEY]
    min_height_by_channel_m_agl = argument_dict[MIN_HEIGHTS_KEY]
    max_height_by_channel_m_agl = argument_dict[MAX_HEIGHTS_KEY]

    normalization_type_string = argument_dict[
        dl_helper.NORMALIZATION_TYPE_ARG_NAME]
    normalization_file_name = argument_dict[
        dl_helper.NORMALIZATION_FILE_ARG_NAME]
    min_normalized_value = argument_dict[dl_helper.MIN_NORM_VALUE_ARG_NAME]
    max_normalized_value = argument_dict[dl_helper.MAX_NORM_VALUE_ARG_NAME]

    target_name = argument_dict[dl_helper.TARGET_NAME_ARG_NAME]
    downsampling_classes = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_CLASSES_ARG_NAME], dtype=int)
    downsampling_fractions = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_FRACTIONS_ARG_NAME], dtype=float)

    monitor_string = argument_dict[dl_helper.MONITOR_ARG_NAME]
    weight_loss_function = bool(argument_dict[dl_helper.WEIGHT_LOSS_ARG_NAME])

    x_translations_pixels = numpy.array(
        argument_dict[dl_helper.X_TRANSLATIONS_ARG_NAME], dtype=int)
    y_translations_pixels = numpy.array(
        argument_dict[dl_helper.Y_TRANSLATIONS_ARG_NAME], dtype=int)
    ccw_rotation_angles_deg = numpy.array(
        argument_dict[dl_helper.ROTATION_ANGLES_ARG_NAME], dtype=float)
    noise_standard_deviation = argument_dict[dl_helper.NOISE_STDEV_ARG_NAME]
    num_noisings = argument_dict[dl_helper.NUM_NOISINGS_ARG_NAME]
    flip_in_x = bool(argument_dict[dl_helper.FLIP_X_ARG_NAME])
    flip_in_y = bool(argument_dict[dl_helper.FLIP_Y_ARG_NAME])

    top_training_dir_name = argument_dict[dl_helper.TRAINING_DIR_ARG_NAME]
    first_training_time_string = argument_dict[
        dl_helper.FIRST_TRAINING_TIME_ARG_NAME]
    last_training_time_string = argument_dict[
        dl_helper.LAST_TRAINING_TIME_ARG_NAME]
    num_examples_per_train_batch = argument_dict[
        dl_helper.NUM_EX_PER_TRAIN_ARG_NAME]

    top_validation_dir_name = argument_dict[dl_helper.VALIDATION_DIR_ARG_NAME]
    first_validation_time_string = argument_dict[
        dl_helper.FIRST_VALIDATION_TIME_ARG_NAME]
    last_validation_time_string = argument_dict[
        dl_helper.LAST_VALIDATION_TIME_ARG_NAME]
    num_examples_per_validn_batch = argument_dict[
        dl_helper.NUM_EX_PER_VALIDN_ARG_NAME]

    num_epochs = argument_dict[dl_helper.NUM_EPOCHS_ARG_NAME]
    num_training_batches_per_epoch = argument_dict[
        dl_helper.NUM_TRAINING_BATCHES_ARG_NAME]
    num_validation_batches_per_epoch = argument_dict[
        dl_helper.NUM_VALIDATION_BATCHES_ARG_NAME]
    output_dir_name = argument_dict[dl_helper.OUTPUT_DIR_ARG_NAME]

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    num_channels = len(radar_field_name_by_channel)
    layer_operation_dicts = [{}] * num_channels

    for k in range(num_channels):
        layer_operation_dicts[k] = {
            input_examples.RADAR_FIELD_KEY: radar_field_name_by_channel[k],
            input_examples.OPERATION_NAME_KEY: layer_op_name_by_channel[k],
            input_examples.MIN_HEIGHT_KEY: min_height_by_channel_m_agl[k],
            input_examples.MAX_HEIGHT_KEY: max_height_by_channel_m_agl[k]
        }

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        downsampling_dict = None

    translate_flag = (len(x_translations_pixels) > 1
                      or x_translations_pixels[0] != 0
                      or y_translations_pixels[0] != 0)

    if not translate_flag:
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.LAYER_OPERATIONS_KEY: layer_operation_dicts,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    input_tensor = model_object.input
    if isinstance(input_tensor, list):
        input_tensor = input_tensor[0]

    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y
    }

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)
    metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    print('Writing metadata to: "{0:s}"...'.format(metafile_name))
    cnn.write_model_metadata(pickle_file_name=metafile_name,
                             metadata_dict=metadata_dict,
                             training_option_dict=training_option_dict)

    return metadata_dict, training_option_dict
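
# Illustrative sketch (not from the original source): toy version of the
# downsampling-dict construction used above, mapping class labels to sampling
# fractions (values invented for the example).
import numpy

downsampling_classes = numpy.array([0, 1], dtype=int)
downsampling_fractions = numpy.array([0.9, 0.1])
print(dict(zip(downsampling_classes, downsampling_fractions)))
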
Example #24
def _run(input_cnn_file_name, input_upconvnet_file_name,
         cnn_feature_layer_name, top_training_dir_name,
         first_training_time_string, last_training_time_string,
         top_validation_dir_name, first_validation_time_string,
         last_validation_time_string, num_examples_per_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains upconvnet.

    This is effectively the main method.

    :param input_cnn_file_name: See documentation at top of file.
    :param input_upconvnet_file_name: Same.
    :param cnn_feature_layer_name: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)
    # argument_file_name = '{0:s}/input_args.p'.format(output_dir_name)
    # print('Writing input args to: "{0:s}"...'.format(argument_file_name))
    #
    # argument_file_handle = open(argument_file_name, 'wb')
    # pickle.dump(INPUT_ARG_OBJECT.__dict__, argument_file_handle)
    # argument_file_handle.close()
    #
    # return

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read trained CNN.
    print('Reading trained CNN from: "{0:s}"...'.format(input_cnn_file_name))
    cnn_model_object = cnn.read_model(input_cnn_file_name)
    cnn_model_object.summary()
    print(SEPARATOR_STRING)

    cnn_metafile_name = cnn.find_metafile(model_file_name=input_cnn_file_name,
                                          raise_error_if_missing=True)

    print('Reading CNN metadata from: "{0:s}"...'.format(cnn_metafile_name))
    cnn_metadata_dict = cnn.read_model_metadata(cnn_metafile_name)

    # Read architecture.
    print('Reading upconvnet architecture from: "{0:s}"...'.format(
        input_upconvnet_file_name))
    upconvnet_model_object = cnn.read_model(input_upconvnet_file_name)
    # upconvnet_model_object = keras.models.clone_model(upconvnet_model_object)

    # TODO(thunderhoser): This is a HACK.
    upconvnet_model_object.compile(loss=keras.losses.mean_squared_error,
                                   optimizer=keras.optimizers.Adam())

    upconvnet_model_object.summary()
    print(SEPARATOR_STRING)

    upconvnet_metafile_name = cnn.find_metafile(
        model_file_name='{0:s}/foo.h5'.format(output_dir_name),
        raise_error_if_missing=False)
    print('Writing upconvnet metadata to: "{0:s}"...'.format(
        upconvnet_metafile_name))

    upconvnet.write_model_metadata(
        cnn_file_name=input_cnn_file_name,
        cnn_feature_layer_name=cnn_feature_layer_name,
        num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec,
        pickle_file_name=upconvnet_metafile_name)

    print(SEPARATOR_STRING)

    upconvnet.train_upconvnet(
        upconvnet_model_object=upconvnet_model_object,
        output_dir_name=output_dir_name,
        cnn_model_object=cnn_model_object,
        cnn_metadata_dict=cnn_metadata_dict,
        cnn_feature_layer_name=cnn_feature_layer_name,
        num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_example_file_names=training_file_names,
        first_training_time_unix_sec=first_training_time_unix_sec,
        last_training_time_unix_sec=last_training_time_unix_sec,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_example_file_names=validation_file_names,
        first_validation_time_unix_sec=first_validation_time_unix_sec,
        last_validation_time_unix_sec=last_validation_time_unix_sec)
Example #25
def _run(input_model_file_name, narr_predictor_names, pressure_level_mb,
         dilation_distance_metres, num_lead_time_steps,
         predictor_time_step_offsets, num_examples_per_time,
         weight_loss_function, class_fractions, top_narr_directory_name,
         top_frontal_grid_dir_name, narr_mask_file_name,
         first_training_time_string, last_training_time_string,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_batch, num_epochs, num_training_batches_per_epoch,
         num_validation_batches_per_epoch, output_model_file_name):
    """Trains CNN from scratch.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param narr_predictor_names: Same.
    :param pressure_level_mb: Same.
    :param dilation_distance_metres: Same.
    :param num_lead_time_steps: Same.
    :param predictor_time_step_offsets: Same.
    :param num_examples_per_time: Same.
    :param weight_loss_function: Same.
    :param class_fractions: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param narr_mask_file_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_model_file_name: Same.
    :raises: ValueError: if `num_lead_time_steps > 1`.
    """

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if narr_mask_file_name == '':
        narr_mask_file_name = None
        narr_mask_matrix = None

    if num_lead_time_steps <= 1:
        num_lead_time_steps = None
        predictor_time_step_offsets = None
    else:
        error_string = (
            'This script cannot yet handle num_lead_time_steps > 1 '
            '(specifically {0:d}).'
        ).format(num_lead_time_steps)

        raise ValueError(error_string)

    # Read architecture.
    print('Reading architecture from: "{0:s}"...'.format(
        input_model_file_name))
    model_object = traditional_cnn.read_keras_model(input_model_file_name)
    model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer=keras.optimizers.Adam(),
        metrics=traditional_cnn.LIST_OF_METRIC_FUNCTIONS)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    input_tensor = model_object.input
    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    num_half_rows = int(numpy.round((num_grid_rows - 1) / 2))
    num_half_columns = int(numpy.round((num_grid_columns - 1) / 2))

    if narr_mask_file_name is not None:
        print('Reading NARR mask from: "{0:s}"...'.format(
            narr_mask_file_name))
        narr_mask_matrix = ml_utils.read_narr_mask(narr_mask_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=output_model_file_name, raise_error_if_missing=False)
    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))

    traditional_cnn.write_model_metadata(
        pickle_file_name=model_metafile_name, num_epochs=num_epochs,
        num_examples_per_batch=num_examples_per_batch,
        num_examples_per_target_time=num_examples_per_time,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        num_rows_in_half_grid=num_half_rows,
        num_columns_in_half_grid=num_half_columns,
        dilation_distance_metres=dilation_distance_metres,
        class_fractions=class_fractions,
        weight_loss_function=weight_loss_function,
        narr_predictor_names=narr_predictor_names,
        pressure_level_mb=pressure_level_mb,
        training_start_time_unix_sec=first_training_time_unix_sec,
        training_end_time_unix_sec=last_training_time_unix_sec,
        validation_start_time_unix_sec=first_validation_time_unix_sec,
        validation_end_time_unix_sec=last_validation_time_unix_sec,
        num_lead_time_steps=num_lead_time_steps,
        predictor_time_step_offsets=predictor_time_step_offsets,
        narr_mask_matrix=narr_mask_matrix)

    print(SEPARATOR_STRING)

    traditional_cnn.train_with_3d_examples(
        model_object=model_object, output_file_name=output_model_file_name,
        num_examples_per_batch=num_examples_per_batch, num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        num_examples_per_target_time=num_examples_per_time,
        training_start_time_unix_sec=first_training_time_unix_sec,
        training_end_time_unix_sec=last_training_time_unix_sec,
        top_narr_directory_name=top_narr_directory_name,
        top_frontal_grid_dir_name=top_frontal_grid_dir_name,
        narr_predictor_names=narr_predictor_names,
        pressure_level_mb=pressure_level_mb,
        dilation_distance_metres=dilation_distance_metres,
        class_fractions=class_fractions,
        num_rows_in_half_grid=num_half_rows,
        num_columns_in_half_grid=num_half_columns,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_start_time_unix_sec=first_validation_time_unix_sec,
        validation_end_time_unix_sec=last_validation_time_unix_sec,
        narr_mask_matrix=narr_mask_matrix)
def _run(model_file_name, first_time_string, last_time_string, randomize_times,
         num_target_times, use_isotonic_regression, top_narr_directory_name,
         top_frontal_grid_dir_name, output_dir_name):
    """Applies traditional CNN to full grids.

    This is effectively the main method.

    :param model_file_name: See documentation at top of file.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param randomize_times: Same.
    :param num_target_times: Same.
    :param use_isotonic_regression: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param output_dir_name: Same.
    """

    first_time_unix_sec = time_conversion.string_to_unix_sec(
        first_time_string, INPUT_TIME_FORMAT)
    last_time_unix_sec = time_conversion.string_to_unix_sec(
        last_time_string, INPUT_TIME_FORMAT)
    target_times_unix_sec = time_periods.range_and_interval_to_list(
        start_time_unix_sec=first_time_unix_sec,
        end_time_unix_sec=last_time_unix_sec,
        time_interval_sec=NARR_TIME_INTERVAL_SEC, include_endpoint=True)

    if randomize_times:
        error_checking.assert_is_leq(
            num_target_times, len(target_times_unix_sec))
        numpy.random.shuffle(target_times_unix_sec)
        target_times_unix_sec = target_times_unix_sec[:num_target_times]

    print('Reading model from: "{0:s}"...'.format(model_file_name))
    model_object = traditional_cnn.read_keras_model(model_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=model_file_name, raise_error_if_missing=True)

    print('Reading model metadata from: "{0:s}"...'.format(
        model_metafile_name))
    model_metadata_dict = traditional_cnn.read_model_metadata(
        model_metafile_name)

    if use_isotonic_regression:
        isotonic_file_name = isotonic_regression.find_model_file(
            base_model_file_name=model_file_name, raise_error_if_missing=True)

        print('Reading isotonic-regression models from: "{0:s}"...'.format(
            isotonic_file_name))
        isotonic_model_object_by_class = (
            isotonic_regression.read_model_for_each_class(isotonic_file_name)
        )
    else:
        isotonic_model_object_by_class = None

    if model_metadata_dict[traditional_cnn.NUM_LEAD_TIME_STEPS_KEY] is None:
        num_dimensions = 3
    else:
        num_dimensions = 4

    num_classes = len(model_metadata_dict[traditional_cnn.CLASS_FRACTIONS_KEY])
    num_target_times = len(target_times_unix_sec)
    print(SEPARATOR_STRING)

    for i in range(num_target_times):
        if num_dimensions == 3:
            (this_class_probability_matrix, this_target_matrix
            ) = traditional_cnn.apply_model_to_3d_example(
                model_object=model_object,
                target_time_unix_sec=target_times_unix_sec[i],
                top_narr_directory_name=top_narr_directory_name,
                top_frontal_grid_dir_name=top_frontal_grid_dir_name,
                narr_predictor_names=model_metadata_dict[
                    traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
                pressure_level_mb=model_metadata_dict[
                    traditional_cnn.PRESSURE_LEVEL_KEY],
                dilation_distance_metres=model_metadata_dict[
                    traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY],
                num_rows_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
                num_columns_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
                num_classes=num_classes,
                isotonic_model_object_by_class=isotonic_model_object_by_class,
                narr_mask_matrix=model_metadata_dict[
                    traditional_cnn.NARR_MASK_MATRIX_KEY])
        else:
            (this_class_probability_matrix, this_target_matrix
            ) = traditional_cnn.apply_model_to_4d_example(
                model_object=model_object,
                target_time_unix_sec=target_times_unix_sec[i],
                predictor_time_step_offsets=model_metadata_dict[
                    traditional_cnn.PREDICTOR_TIME_STEP_OFFSETS_KEY],
                num_lead_time_steps=model_metadata_dict[
                    traditional_cnn.NUM_LEAD_TIME_STEPS_KEY],
                top_narr_directory_name=top_narr_directory_name,
                top_frontal_grid_dir_name=top_frontal_grid_dir_name,
                narr_predictor_names=model_metadata_dict[
                    traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
                pressure_level_mb=model_metadata_dict[
                    traditional_cnn.PRESSURE_LEVEL_KEY],
                dilation_distance_metres=model_metadata_dict[
                    traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY],
                num_rows_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
                num_columns_in_half_grid=model_metadata_dict[
                    traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
                num_classes=num_classes,
                isotonic_model_object_by_class=isotonic_model_object_by_class,
                narr_mask_matrix=model_metadata_dict[
                    traditional_cnn.NARR_MASK_MATRIX_KEY])

        this_target_matrix[this_target_matrix == -1] = 0
        print MINOR_SEPARATOR_STRING

        this_prediction_file_name = ml_utils.find_gridded_prediction_file(
            directory_name=output_dir_name,
            first_target_time_unix_sec=target_times_unix_sec[i],
            last_target_time_unix_sec=target_times_unix_sec[i],
            raise_error_if_missing=False)

        print 'Writing gridded predictions to file: "{0:s}"...'.format(
            this_prediction_file_name)

        ml_utils.write_gridded_predictions(
            pickle_file_name=this_prediction_file_name,
            class_probability_matrix=this_class_probability_matrix,
            target_times_unix_sec=target_times_unix_sec[[i]],
            model_file_name=model_file_name,
            used_isotonic_regression=use_isotonic_regression,
            target_matrix=this_target_matrix)

        if i != num_target_times - 1:
            print SEPARATOR_STRING
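
The method above first builds a list of target times between the first and last times at NARR's time step, then optionally keeps a random subset. A standalone sketch of that bookkeeping follows; the 3-hour (10 800-second) interval and the time format are assumptions for illustration, since INPUT_TIME_FORMAT and NARR_TIME_INTERVAL_SEC are defined elsewhere in the original file:

import calendar
import time

import numpy


def times_in_period(first_time_string, last_time_string, interval_sec=10800,
                    time_format='%Y-%m-%d-%H'):
    """Returns evenly spaced Unix times from the first to the last, inclusive."""
    first_unix_sec = calendar.timegm(
        time.strptime(first_time_string, time_format))
    last_unix_sec = calendar.timegm(
        time.strptime(last_time_string, time_format))
    return numpy.arange(
        first_unix_sec, last_unix_sec + 1, interval_sec, dtype=int)


target_times_unix_sec = times_in_period('2017-01-01-00', '2017-01-02-00')
numpy.random.shuffle(target_times_unix_sec)    # randomize order in place
print(target_times_unix_sec[:3])               # keep only a few target times
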
def _run(input_model_file_name, sounding_field_names,
         normalization_type_string, normalization_param_file_name,
         min_normalized_value, max_normalized_value, target_name,
         downsampling_classes, downsampling_fractions, monitor_string,
         weight_loss_function, x_translations_pixels, y_translations_pixels,
         ccw_rotation_angles_deg, noise_standard_deviation, num_noisings,
         flip_in_x, flip_in_y, top_training_dir_name,
         first_training_time_string, last_training_time_string,
         num_examples_per_train_batch, top_validation_dir_name,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_validn_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains CNN with 2-D and 3-D MYRORSS images.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param target_name: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param x_translations_pixels: Same.
    :param y_translations_pixels: Same.
    :param ccw_rotation_angles_deg: Same.
    :param noise_standard_deviation: Same.
    :param num_noisings: Same.
    :param flip_in_x: Same.
    :param flip_in_y: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param num_examples_per_train_batch: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_validn_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # argument_file_name = '{0:s}/input_args.p'.format(output_dir_name)
    # print('Writing input args to: "{0:s}"...'.format(argument_file_name))
    #
    # argument_file_handle = open(argument_file_name, 'wb')
    # pickle.dump(INPUT_ARG_OBJECT.__dict__, argument_file_handle)
    # argument_file_handle.close()
    #
    # return

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        downsampling_dict = None

    if (len(x_translations_pixels) == 1 and x_translations_pixels[0] == 0
            and y_translations_pixels[0] == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Set output locations.
    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read architecture.
    print(
        'Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    # model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(loss=keras.losses.binary_crossentropy,
                         optimizer=keras.optimizers.Adam(),
                         metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: True,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    if isinstance(model_object.input, list):
        list_of_input_tensors = model_object.input
    else:
        list_of_input_tensors = [model_object.input]

    upsample_refl = len(list_of_input_tensors) == 2
    num_grid_rows = list_of_input_tensors[0].get_shape().as_list()[1]
    num_grid_columns = list_of_input_tensors[0].get_shape().as_list()[2]

    if upsample_refl:
        num_grid_rows = int(numpy.round(num_grid_rows / 2))
        num_grid_columns = int(numpy.round(num_grid_columns / 2))

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.RADAR_FIELDS_KEY:
        input_examples.AZIMUTHAL_SHEAR_FIELD_NAMES,
        trainval_io.RADAR_HEIGHTS_KEY: REFLECTIVITY_HEIGHTS_M_AGL,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y,
        trainval_io.UPSAMPLE_REFLECTIVITY_KEY: upsample_refl
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(pickle_file_name=model_metafile_name,
                             metadata_dict=metadata_dict,
                             training_option_dict=training_option_dict)

    cnn.train_cnn_2d3d_myrorss(
        model_object=model_object,
        model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name,
        num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec,
        num_examples_per_validn_batch=num_examples_per_validn_batch)
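
A minimal sketch (not from the repository) of the grid-size logic above: the number of grid rows and columns comes from the model's first input tensor and is halved when the model has two input tensors, i.e. when reflectivity must be upsampled. The toy two-input architecture below is an assumption, included only to make the sketch runnable:

import numpy
import keras


def infer_grid_dimensions(model_object):
    """Returns (num_grid_rows, num_grid_columns) expected by the data generator."""
    if isinstance(model_object.input, list):
        input_tensors = model_object.input
    else:
        input_tensors = [model_object.input]

    num_rows = int(input_tensors[0].shape[1])
    num_columns = int(input_tensors[0].shape[2])

    # Two input tensors => reflectivity will be upsampled, so the native grid
    # is half the size of the first input tensor.
    if len(input_tensors) == 2:
        num_rows = int(numpy.round(num_rows / 2))
        num_columns = int(numpy.round(num_columns / 2))

    return num_rows, num_columns


refl_input = keras.layers.Input(shape=(32, 32, 12, 1), name='reflectivity')
shear_input = keras.layers.Input(shape=(64, 64, 2), name='azimuthal_shear')
merged = keras.layers.Concatenate()(
    [keras.layers.Flatten()(refl_input), keras.layers.Flatten()(shear_input)])
output_layer = keras.layers.Dense(2, activation='softmax')(merged)
toy_model = keras.models.Model(
    inputs=[refl_input, shear_input], outputs=output_layer)

print(infer_grid_dimensions(toy_model))    # (16, 16) for the toy shapes above
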
Example #28
def _find_baseline_and_test_examples(top_example_dir_name, first_time_string,
                                     last_time_string, num_baseline_examples,
                                     num_test_examples, cnn_model_object,
                                     cnn_metadata_dict):
    """Finds examples for baseline and test sets.

    :param top_example_dir_name: See documentation at top of file.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param num_baseline_examples: Same.
    :param num_test_examples: Same.
    :param cnn_model_object: Trained CNN (Keras model object), used to compute
        cold-front probabilities for each example.
    :param cnn_metadata_dict: Dictionary with metadata for cnn_model_object, in
        the format returned by traditional_cnn.read_model_metadata.
    :return: baseline_image_matrix: B-by-M-by-N-by-C numpy array of baseline
        images (input examples for the CNN), where B = number of baseline
        examples.
    :return: test_image_matrix: T-by-M-by-N-by-C numpy array of test images
        (input examples for the CNN), where T = number of test examples.
    """

    first_time_unix_sec = time_conversion.string_to_unix_sec(
        first_time_string, TIME_FORMAT)
    last_time_unix_sec = time_conversion.string_to_unix_sec(
        last_time_string, TIME_FORMAT)

    example_file_names = trainval_io.find_downsized_3d_example_files(
        top_directory_name=top_example_dir_name,
        shuffled=False,
        first_target_time_unix_sec=first_time_unix_sec,
        last_target_time_unix_sec=last_time_unix_sec)

    file_indices = numpy.array([], dtype=int)
    file_position_indices = numpy.array([], dtype=int)
    cold_front_probabilities = numpy.array([], dtype=float)

    for k in range(len(example_file_names)):
        print 'Reading data from: "{0:s}"...'.format(example_file_names[k])
        this_example_dict = trainval_io.read_downsized_3d_examples(
            netcdf_file_name=example_file_names[k],
            metadata_only=False,
            predictor_names_to_keep=cnn_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            num_half_rows_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_half_columns_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            first_time_to_keep_unix_sec=first_time_unix_sec,
            last_time_to_keep_unix_sec=last_time_unix_sec)

        this_num_examples = len(
            this_example_dict[trainval_io.TARGET_TIMES_KEY])
        if this_num_examples == 0:
            continue

        these_file_indices = numpy.full(this_num_examples, k, dtype=int)
        these_position_indices = numpy.linspace(0,
                                                this_num_examples - 1,
                                                num=this_num_examples,
                                                dtype=int)

        these_cold_front_probs = _get_cnn_predictions(
            cnn_model_object=cnn_model_object,
            predictor_matrix=this_example_dict[
                trainval_io.PREDICTOR_MATRIX_KEY],
            target_class=front_utils.COLD_FRONT_INTEGER_ID,
            verbose=True)
        print '\n'

        file_indices = numpy.concatenate((file_indices, these_file_indices))
        file_position_indices = numpy.concatenate(
            (file_position_indices, these_position_indices))
        cold_front_probabilities = numpy.concatenate(
            (cold_front_probabilities, these_cold_front_probs))

    print SEPARATOR_STRING

    # Find test set.
    test_indices = numpy.argsort(-1 *
                                 cold_front_probabilities)[:num_test_examples]
    file_indices_for_test = file_indices[test_indices]
    file_position_indices_for_test = file_position_indices[test_indices]

    print 'Cold-front probabilities for the {0:d} test examples are:'.format(
        num_test_examples)
    for i in test_indices:
        print cold_front_probabilities[i]
    print SEPARATOR_STRING

    # Find baseline set.
    # Draw the baseline set at random from all examples outside the test set.
    all_indices = numpy.linspace(
        0, len(cold_front_probabilities) - 1,
        num=len(cold_front_probabilities), dtype=int)

    candidate_indices = numpy.setdiff1d(all_indices, test_indices)
    baseline_indices = numpy.random.choice(
        candidate_indices, size=num_baseline_examples, replace=False)

    file_indices_for_baseline = file_indices[baseline_indices]
    file_position_indices_for_baseline = file_position_indices[
        baseline_indices]

    print 'Cold-front probabilities for the {0:d} baseline examples are:'.format(
        num_baseline_examples)
    for i in baseline_indices:
        print cold_front_probabilities[i]
    print SEPARATOR_STRING

    # Read test and baseline sets.
    baseline_image_matrix = None
    test_image_matrix = None

    for k in range(len(example_file_names)):
        if not (k in file_indices_for_test or k in file_indices_for_baseline):
            continue

        print 'Reading data from: "{0:s}"...'.format(example_file_names[k])
        this_example_dict = trainval_io.read_downsized_3d_examples(
            netcdf_file_name=example_file_names[k],
            metadata_only=False,
            predictor_names_to_keep=cnn_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            num_half_rows_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_half_columns_to_keep=cnn_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            first_time_to_keep_unix_sec=first_time_unix_sec,
            last_time_to_keep_unix_sec=last_time_unix_sec)

        this_predictor_matrix = this_example_dict[
            trainval_io.PREDICTOR_MATRIX_KEY]

        if baseline_image_matrix is None:
            baseline_image_matrix = numpy.full(
                (num_baseline_examples, ) + this_predictor_matrix.shape[1:],
                numpy.nan)
            test_image_matrix = numpy.full(
                (num_test_examples, ) + this_predictor_matrix.shape[1:],
                numpy.nan)

        these_baseline_indices = numpy.where(file_indices_for_baseline == k)[0]
        if len(these_baseline_indices) > 0:
            baseline_image_matrix[these_baseline_indices, ...] = (
                this_predictor_matrix[
                    file_position_indices_for_baseline[these_baseline_indices],
                    ...])

        these_test_indices = numpy.where(file_indices_for_test == k)[0]
        if len(these_test_indices) > 0:
            test_image_matrix[these_test_indices, ...] = (
                this_predictor_matrix[
                    file_position_indices_for_test[these_test_indices], ...])

    return baseline_image_matrix, test_image_matrix
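
A standalone numpy sketch (not repository code) of the selection logic above: the test set is the examples with the highest cold-front probabilities, and the baseline set is drawn at random from the remaining examples.

import numpy

numpy.random.seed(0)
cold_front_probs = numpy.random.uniform(size=100)   # toy probabilities
num_test, num_baseline = 5, 20

# Highest-probability examples become the test set.
test_indices = numpy.argsort(-1 * cold_front_probs)[:num_test]

# Baseline set is sampled from everything that is not in the test set.
candidate_indices = numpy.setdiff1d(
    numpy.arange(len(cold_front_probs)), test_indices)
baseline_indices = numpy.random.choice(
    candidate_indices, size=num_baseline, replace=False)

print(cold_front_probs[test_indices])    # the five largest probabilities
print(len(baseline_indices))             # 20, none overlapping the test set
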
import unittest
import numpy
from gewittergefahr.gg_utils import time_conversion
from gewittergefahr.deep_learning import prediction_io

# The following constants are used to test subset_ungridded_predictions.
TARGET_NAME = 'foo'

THESE_ID_STRINGS = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
THESE_TIME_STRINGS = [
    '4001-01-01-01', '4002-02-02-02', '4003-03-03-03', '4004-04-04-04',
    '4005-05-05-05', '4006-06-06-06', '4007-07-07-07', '4008-08-08-08',
    '4009-09-09-09', '4010-10-10-10', '4011-11-11-11', '4012-12-12-12'
]
THESE_TIMES_UNIX_SEC = [
    time_conversion.string_to_unix_sec(t, '%Y-%m-%d-%H')
    for t in THESE_TIME_STRINGS
]
THIS_PROBABILITY_MATRIX = numpy.array([[1, 0], [0.9, 0.1], [0.8, 0.2],
                                       [0.7, 0.3], [0.6, 0.4], [0.5, 0.5],
                                       [0.4, 0.6], [0.3, 0.7], [0.2, 0.8],
                                       [0.1, 0.9], [0, 1], [0.75, 0.25]])
THESE_OBSERVED_LABELS = numpy.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0],
                                    dtype=int)

FULL_PREDICTION_DICT_SANS_OBS = {
    prediction_io.TARGET_NAME_KEY: TARGET_NAME,
    prediction_io.STORM_IDS_KEY: THESE_ID_STRINGS,
    prediction_io.STORM_TIMES_KEY: THESE_TIMES_UNIX_SEC,
    prediction_io.PROBABILITY_MATRIX_KEY: THIS_PROBABILITY_MATRIX,
    prediction_io.OBSERVED_LABELS_KEY: None,
}
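
The constants above are set up to test subset_ungridded_predictions. As a rough, hypothetical sketch of the operation being tested (prediction_io's real implementation may differ), subsetting by example index might look like this, using the dictionary defined above:

def _subset_by_index(prediction_dict, desired_indices):
    """Keeps only the examples at the given integer positions."""
    new_dict = dict(prediction_dict)
    new_dict[prediction_io.STORM_IDS_KEY] = [
        prediction_dict[prediction_io.STORM_IDS_KEY][i] for i in desired_indices
    ]
    new_dict[prediction_io.STORM_TIMES_KEY] = [
        prediction_dict[prediction_io.STORM_TIMES_KEY][i] for i in desired_indices
    ]
    new_dict[prediction_io.PROBABILITY_MATRIX_KEY] = prediction_dict[
        prediction_io.PROBABILITY_MATRIX_KEY][desired_indices, ...]

    if prediction_dict[prediction_io.OBSERVED_LABELS_KEY] is not None:
        new_dict[prediction_io.OBSERVED_LABELS_KEY] = prediction_dict[
            prediction_io.OBSERVED_LABELS_KEY][desired_indices]

    return new_dict


SMALL_PREDICTION_DICT = _subset_by_index(FULL_PREDICTION_DICT_SANS_OBS, [0, 2, 11])
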
def _run(model_file_name, first_eval_time_string, last_eval_time_string,
         num_times, num_examples_per_time, dilation_distance_metres,
         use_isotonic_regression, top_narr_directory_name,
         top_frontal_grid_dir_name, output_dir_name):
    """Evaluates CNN trained by patch classification.

    This is effectively the main method.

    :param model_file_name: See documentation at top of file.
    :param first_eval_time_string: Same.
    :param last_eval_time_string: Same.
    :param num_times: Same.
    :param num_examples_per_time: Same.
    :param dilation_distance_metres: Same.
    :param use_isotonic_regression: Same.
    :param top_narr_directory_name: Same.
    :param top_frontal_grid_dir_name: Same.
    :param output_dir_name: Same.
    """

    first_eval_time_unix_sec = time_conversion.string_to_unix_sec(
        first_eval_time_string, INPUT_TIME_FORMAT)
    last_eval_time_unix_sec = time_conversion.string_to_unix_sec(
        last_eval_time_string, INPUT_TIME_FORMAT)

    print 'Reading model from: "{0:s}"...'.format(model_file_name)
    model_object = traditional_cnn.read_keras_model(model_file_name)

    model_metafile_name = traditional_cnn.find_metafile(
        model_file_name=model_file_name, raise_error_if_missing=True)

    print 'Reading model metadata from: "{0:s}"...'.format(model_metafile_name)
    model_metadata_dict = traditional_cnn.read_model_metadata(
        model_metafile_name)

    if dilation_distance_metres < 0:
        dilation_distance_metres = model_metadata_dict[
            traditional_cnn.DILATION_DISTANCE_FOR_TARGET_KEY] + 0.

    if use_isotonic_regression:
        isotonic_file_name = isotonic_regression.find_model_file(
            base_model_file_name=model_file_name, raise_error_if_missing=True)

        print 'Reading isotonic-regression models from: "{0:s}"...'.format(
            isotonic_file_name)
        isotonic_model_object_by_class = (
            isotonic_regression.read_model_for_each_class(isotonic_file_name))
    else:
        isotonic_model_object_by_class = None

    num_classes = len(model_metadata_dict[traditional_cnn.CLASS_FRACTIONS_KEY])
    print SEPARATOR_STRING

    class_probability_matrix, observed_labels = (
        eval_utils.downsized_examples_to_eval_pairs(
            model_object=model_object,
            first_target_time_unix_sec=first_eval_time_unix_sec,
            last_target_time_unix_sec=last_eval_time_unix_sec,
            num_target_times_to_sample=num_times,
            num_examples_per_time=num_examples_per_time,
            top_narr_directory_name=top_narr_directory_name,
            top_frontal_grid_dir_name=top_frontal_grid_dir_name,
            narr_predictor_names=model_metadata_dict[
                traditional_cnn.NARR_PREDICTOR_NAMES_KEY],
            pressure_level_mb=model_metadata_dict[
                traditional_cnn.PRESSURE_LEVEL_KEY],
            dilation_distance_metres=dilation_distance_metres,
            num_rows_in_half_grid=model_metadata_dict[
                traditional_cnn.NUM_ROWS_IN_HALF_GRID_KEY],
            num_columns_in_half_grid=model_metadata_dict[
                traditional_cnn.NUM_COLUMNS_IN_HALF_GRID_KEY],
            num_classes=num_classes,
            predictor_time_step_offsets=model_metadata_dict[
                traditional_cnn.PREDICTOR_TIME_STEP_OFFSETS_KEY],
            num_lead_time_steps=model_metadata_dict[
                traditional_cnn.NUM_LEAD_TIME_STEPS_KEY],
            isotonic_model_object_by_class=isotonic_model_object_by_class,
            narr_mask_matrix=model_metadata_dict[
                traditional_cnn.NARR_MASK_MATRIX_KEY]))

    print SEPARATOR_STRING

    model_eval_helper.run_evaluation(
        class_probability_matrix=class_probability_matrix,
        observed_labels=observed_labels,
        output_dir_name=output_dir_name)
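
Both _run methods above can calibrate the CNN probabilities with one isotonic-regression model per class. A minimal sketch of that idea with scikit-learn follows; it illustrates the general technique under stated assumptions (per-class calibration followed by row renormalization) and is not the repository's isotonic_regression module:

import numpy
from sklearn.isotonic import IsotonicRegression


def calibrate_probabilities(class_probability_matrix, model_object_by_class):
    """Applies one isotonic model per class, then renormalizes each row."""
    num_classes = class_probability_matrix.shape[1]
    calibrated_matrix = numpy.column_stack([
        model_object_by_class[k].predict(class_probability_matrix[:, k])
        for k in range(num_classes)
    ])
    row_sums = numpy.maximum(
        numpy.sum(calibrated_matrix, axis=1, keepdims=True),
        numpy.finfo(float).eps)
    return calibrated_matrix / row_sums


# Toy fit: each class is calibrated against "was this class observed?" labels.
raw_probability_matrix = numpy.array(
    [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.2, 0.8]])
observed_labels = numpy.array([0, 0, 1, 1], dtype=int)

model_object_by_class = []
for k in range(raw_probability_matrix.shape[1]):
    this_model = IsotonicRegression(y_min=0., y_max=1., out_of_bounds='clip')
    this_model.fit(raw_probability_matrix[:, k],
                   (observed_labels == k).astype(float))
    model_object_by_class.append(this_model)

print(calibrate_probabilities(raw_probability_matrix, model_object_by_class))
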