Example #1
    def test_target_name_to_params_tornado(self):
        """Ensures correct output from target_name_to_params.

        In this case, target variable is based on tornado occurrence.
        """

        this_dict = target_val_utils.target_name_to_params(TORNADO_TARGET_NAME)
        self.assertTrue(
            _compare_target_param_dicts(this_dict, TORNADO_PARAM_DICT))
Example #2
    def test_target_name_to_params_wind_regression(self):
        """Ensures correct output from target_name_to_params.

        In this case, target variable is based on wind-speed regression.
        """

        this_dict = target_val_utils.target_name_to_params(
            WIND_REGRESSION_NAME)
        self.assertTrue(
            _compare_target_param_dicts(this_dict, WIND_REGRESSION_PARAM_DICT))
Example #3
    def test_target_name_to_params_wind_classifn_0lead(self):
        """Ensures correct output from target_name_to_params.

        In this case, target variable is based on wind-speed classification and
        minimum lead time is zero.
        """

        this_dict = target_val_utils.target_name_to_params(
            WIND_CLASSIFICATION_NAME_0LEAD)
        self.assertTrue(
            _compare_target_param_dicts(this_dict,
                                        WIND_CLASSIFICATION_PARAM_DICT_0LEAD))
Example #4
def _read_target_values(top_target_dir_name, storm_activations,
                        activation_metadata_dict):
    """Reads target value for each storm object.

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param storm_activations: length-E numpy array of activations.
    :param activation_metadata_dict: Dictionary returned by
        `model_activation.read_file`.
    :return: target_dict: Dictionary with the following keys.
    target_dict['full_id_strings']: length-E list of full storm IDs.
    target_dict['storm_times_unix_sec']: length-E numpy array of storm times.
    target_dict['storm_activations']: length-E numpy array of model activations.
    target_dict['storm_target_values']: length-E numpy array of target values.

    :raises: ValueError: if the target variable is multiclass and not binarized.
    """

    # Convert input args.
    full_id_strings = activation_metadata_dict[model_activation.FULL_IDS_KEY]
    storm_times_unix_sec = activation_metadata_dict[
        model_activation.STORM_TIMES_KEY]

    storm_spc_date_strings_numpy = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in storm_times_unix_sec
    ], dtype=object)

    unique_spc_date_strings_numpy = numpy.unique(storm_spc_date_strings_numpy)

    # Read metadata for machine-learning model.
    model_file_name = activation_metadata_dict[
        model_activation.MODEL_FILE_NAME_KEY]
    model_metadata_file_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print('Reading metadata from: "{0:s}"...'.format(model_metadata_file_name))
    model_metadata_dict = cnn.read_model_metadata(model_metadata_file_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    target_name = training_option_dict[trainval_io.TARGET_NAME_KEY]
    num_classes = target_val_utils.target_name_to_num_classes(
        target_name=target_name, include_dead_storms=False)

    binarize_target = (training_option_dict[trainval_io.BINARIZE_TARGET_KEY]
                       and num_classes > 2)

    if num_classes > 2 and not binarize_target:
        error_string = (
            'The target variable ("{0:s}") is multiclass, which this script '
            'cannot handle.').format(target_name)

        raise ValueError(error_string)

    event_type_string = target_val_utils.target_name_to_params(target_name)[
        target_val_utils.EVENT_TYPE_KEY]

    # Read target values.
    storm_target_values = numpy.array([], dtype=int)
    id_sort_indices = numpy.array([], dtype=int)
    num_spc_dates = len(unique_spc_date_strings_numpy)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings_numpy[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name, target_names=[target_name])

        these_indices = numpy.where(storm_spc_date_strings_numpy ==
                                    unique_spc_date_strings_numpy[i])[0]
        id_sort_indices = numpy.concatenate((id_sort_indices, these_indices))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[full_id_strings[k] for k in these_indices],
            times_to_keep_unix_sec=storm_times_unix_sec[these_indices])

        if len(these_indices) == 0:
            continue

        these_target_values = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_indices, :]

        these_target_values = numpy.reshape(these_target_values,
                                            these_target_values.size)

        storm_target_values = numpy.concatenate(
            (storm_target_values, these_target_values))

    good_indices = numpy.where(
        storm_target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_target_values = storm_target_values[good_indices]
    id_sort_indices = id_sort_indices[good_indices]

    if binarize_target:
        storm_target_values = (
            storm_target_values == num_classes - 1
        ).astype(int)

    return {
        FULL_IDS_KEY: [full_id_strings[k] for k in id_sort_indices],
        STORM_TIMES_KEY: storm_times_unix_sec[id_sort_indices],
        STORM_ACTIVATIONS_KEY: storm_activations[id_sort_indices],
        TARGET_VALUES_KEY: storm_target_values
    }
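
A hedged usage sketch of the function above: the directory name is a placeholder, and `storm_activations` / `activation_metadata_dict` are assumed to have been loaded already (e.g., via `model_activation.read_file`).

target_dict = _read_target_values(
    top_target_dir_name='/data/target_values',
    storm_activations=storm_activations,
    activation_metadata_dict=activation_metadata_dict)

# Rank storm objects by model activation (highest first) and print the top 10.
high_to_low_indices = numpy.argsort(-1 * target_dict[STORM_ACTIVATIONS_KEY])

for k in high_to_low_indices[:10]:
    print('{0:s}: activation = {1:.3f} ... target value = {2:d}'.format(
        target_dict[FULL_IDS_KEY][k],
        target_dict[STORM_ACTIVATIONS_KEY][k],
        int(target_dict[TARGET_VALUES_KEY][k])
    ))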
Example #5
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, radar_heights_m_agl,
        spc_date_string, top_radar_dir_name, top_tracking_dir_name,
        elevation_dir_name, tracking_scale_metres2, target_name,
        top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered radar images from GridRad data.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param radar_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param top_radar_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )

        print('\n')

    # Find storm objects on the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    # Read storm objects on the given SPC date.
    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]

    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_gridrad(
        storm_object_table=storm_object_table,
        top_radar_dir_name=top_radar_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        radar_heights_m_agl=radar_heights_m_agl)
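
A hedged invocation sketch for the function above; every argument value here is a placeholder chosen for illustration, not taken from the original script or its input arguments.

_extract_storm_images(
    num_image_rows=32, num_image_columns=32, rotate_grids=True,
    rotated_grid_spacing_metres=1500.,
    radar_field_names=['reflectivity_dbz'],
    radar_heights_m_agl=numpy.array([1000, 2000, 3000], dtype=int),
    spc_date_string='20110427',
    top_radar_dir_name='/data/gridrad',
    top_tracking_dir_name='/data/tracking',
    elevation_dir_name='/data/elevation',
    tracking_scale_metres2=314159265,
    target_name='tornado_lead-time=0000-3600sec_distance=00000-10000m',
    top_target_dir_name='/data/target_values',
    top_output_dir_name='/data/storm_images_gridrad')
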
def _read_new_target_values(top_target_dir_name, new_target_name,
                            full_storm_id_strings, storm_times_unix_sec,
                            orig_target_values):
    """Reads new target values (for upgraded minimum EF rating).

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param new_target_name: Name of new target variable (with upgraded minimum EF
        rating).
    :param full_storm_id_strings: length-E list of storm IDs.
    :param storm_times_unix_sec: length-E numpy array of valid times.
    :param orig_target_values: length-E numpy array of original target values
        (for original minimum EF rating), all integers in 0...1.
    :return: new_target_values: length-E numpy array of new target values
        (integers in -1...1).  -1 means that increasing minimum EF rating
        flipped the value from 1 to 0.
    """

    storm_spc_date_strings = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in storm_times_unix_sec
    ])
    unique_spc_date_strings = numpy.unique(storm_spc_date_strings)

    event_type_string = target_val_utils.target_name_to_params(
        new_target_name)[target_val_utils.EVENT_TYPE_KEY]

    num_spc_dates = len(unique_spc_date_strings)
    num_storm_objects = len(full_storm_id_strings)
    new_target_values = numpy.full(num_storm_objects, numpy.nan)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name,
            target_names=[new_target_name])

        these_storm_indices = numpy.where(
            storm_spc_date_strings == unique_spc_date_strings[i])[0]

        these_target_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_storm_id_strings[k] for k in these_storm_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_storm_indices],
            allow_missing=False)

        new_target_values[these_storm_indices] = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_target_indices, 0]

    assert not numpy.any(numpy.isnan(new_target_values))
    new_target_values = numpy.round(new_target_values).astype(int)

    bad_indices = numpy.where(new_target_values != orig_target_values)[0]
    print(('\n{0:d} of {1:d} new target values do not match original value.'
           ).format(len(bad_indices), num_storm_objects))

    new_target_values[bad_indices] = -1
    return new_target_values
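
A hedged usage sketch (the variable names are assumptions): a value of -1 flags storm objects whose label flipped from 1 to 0 under the stricter EF rating, and one simple option is to drop those before evaluation.

new_target_values = _read_new_target_values(
    top_target_dir_name='/data/target_values',
    new_target_name=new_target_name,
    full_storm_id_strings=full_storm_id_strings,
    storm_times_unix_sec=storm_times_unix_sec,
    orig_target_values=orig_target_values)

# Keep only storm objects whose label survived the upgraded EF threshold.
keep_indices = numpy.where(new_target_values >= 0)[0]
full_storm_id_strings = [full_storm_id_strings[k] for k in keep_indices]
storm_times_unix_sec = storm_times_unix_sec[keep_indices]
new_target_values = new_target_values[keep_indices]
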
def _run(main_activation_file_name, aux_activation_file_name, tornado_dir_name,
         top_tracking_dir_name, top_myrorss_dir_name, radar_field_name,
         radar_height_m_asl, latitude_buffer_deg, longitude_buffer_deg,
         top_output_dir_name):
    """Plots examples (storm objects) with surrounding context.

    This is effectively the main method.

    :param main_activation_file_name: See documentation at top of file.
    :param aux_activation_file_name: Same.
    :param tornado_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param top_myrorss_dir_name: Same.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param top_output_dir_name: Same.
    :raises: ValueError: if activation file contains activations of some
        intermediate model component, rather than final predictions.
    :raises: ValueError: if target variable is not related to tornadogenesis.
    """

    if aux_activation_file_name in ['', 'None']:
        aux_activation_file_name = None

    print('Reading data from: "{0:s}"...'.format(main_activation_file_name))
    activation_matrix, activation_dict = model_activation.read_file(
        main_activation_file_name)

    component_type_string = activation_dict[
        model_activation.COMPONENT_TYPE_KEY]

    if (component_type_string !=
            model_interpretation.CLASS_COMPONENT_TYPE_STRING):
        error_string = (
            'Activation file should contain final predictions (component type '
            '"{0:s}").  Instead, component type is "{1:s}".').format(
                model_interpretation.CLASS_COMPONENT_TYPE_STRING,
                component_type_string)

        raise ValueError(error_string)

    forecast_probabilities = numpy.squeeze(activation_matrix)
    num_storm_objects = len(forecast_probabilities)

    model_file_name = activation_dict[model_activation.MODEL_FILE_NAME_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print(
        'Reading model metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    target_name = training_option_dict[trainval_io.TARGET_NAME_KEY]
    target_param_dict = target_val_utils.target_name_to_params(target_name)
    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]

    if event_type_string != linkage.TORNADO_EVENT_STRING:
        error_string = (
            'Target variable should be related to tornadogenesis.  Instead, got'
            ' "{0:s}".').format(target_name)

        raise ValueError(error_string)

    if aux_activation_file_name is None:
        aux_forecast_probabilities = None
        aux_activation_dict = None
    else:
        print('Reading data from: "{0:s}"...'.format(aux_activation_file_name))
        this_matrix, aux_activation_dict = model_activation.read_file(
            aux_activation_file_name)

        aux_forecast_probabilities = numpy.squeeze(this_matrix)

    print(SEPARATOR_STRING)

    for i in range(num_storm_objects):
        _plot_one_example(
            full_id_string=activation_dict[model_activation.FULL_IDS_KEY][i],
            storm_time_unix_sec=activation_dict[
                model_activation.STORM_TIMES_KEY][i],
            target_name=target_name,
            forecast_probability=forecast_probabilities[i],
            tornado_dir_name=tornado_dir_name,
            top_tracking_dir_name=top_tracking_dir_name,
            top_myrorss_dir_name=top_myrorss_dir_name,
            radar_field_name=radar_field_name,
            radar_height_m_asl=radar_height_m_asl,
            latitude_buffer_deg=latitude_buffer_deg,
            longitude_buffer_deg=longitude_buffer_deg,
            top_output_dir_name=top_output_dir_name,
            aux_forecast_probabilities=aux_forecast_probabilities,
            aux_activation_dict=aux_activation_dict)

        if i != num_storm_objects - 1:
            print(SEPARATOR_STRING)
Example #8
def _run(top_input_dir_name, target_name, first_spc_date_string,
         last_spc_date_string, class_fraction_keys, class_fraction_values,
         for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param class_fraction_keys: Same.
    :param class_fraction_values: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    class_fraction_dict = dict(zip(class_fraction_keys, class_fraction_values))
    target_param_dict = target_val_utils.target_name_to_params(target_name)

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)
        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    spc_date_strings = spc_date_string_by_file
    target_dict_by_file = [None] * num_files

    storm_ids = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    target_values = numpy.array([], dtype=int)

    for i in range(num_files):
        print('Reading "{0:s}" from: "{1:s}"...'.format(
            target_name, input_target_file_names[i]))
        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i],
            target_name=target_name)

        storm_ids += target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]
        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))
        target_values = numpy.concatenate(
            (target_values,
             target_dict_by_file[i][target_val_utils.TARGET_VALUES_KEY]))

    print(SEPARATOR_STRING)

    good_indices = numpy.where(
        target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_ids = [storm_ids[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_values = target_values[good_indices]

    if for_training:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))
    else:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_non_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_indices = tracking_utils.find_storm_objects(
            all_storm_ids=target_dict_by_file[i][
                target_val_utils.STORM_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            storm_ids_to_keep=storm_ids,
            times_to_keep_unix_sec=storm_times_unix_sec,
            allow_missing=True)

        these_indices = these_indices[these_indices >= 0]
        if len(these_indices) == 0:
            continue

        this_output_dict = {
            tracking_utils.STORM_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.STORM_IDS_KEY][k]
                for k in these_indices
            ],
            tracking_utils.TIME_COLUMN: target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY][these_indices],
            target_name: target_dict_by_file[i][
                target_val_utils.TARGET_VALUES_KEY][these_indices]
        }
        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_strings[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...'
        ).format(
            len(this_output_table.index),
            len(target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]),
            this_new_file_name
        ))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=[target_name],
            netcdf_file_name=this_new_file_name)
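
For reference, a hedged illustration of the class-fraction arguments consumed by the script above (the numbers are made up): keys are class integers and values are the desired fraction of each class after downsampling.

class_fraction_keys = numpy.array([0, 1], dtype=int)
class_fraction_values = numpy.array([0.9, 0.1])

class_fraction_dict = dict(zip(class_fraction_keys, class_fraction_values))
# class_fraction_dict == {0: 0.9, 1: 0.1}: keep roughly nine negative storm
# objects for every positive one.
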
def _find_tracking_files_one_example(top_tracking_dir_name,
                                     valid_time_unix_sec, target_name):
    """Finds tracking files needed to make plots for one example.

    :param top_tracking_dir_name: See documentation at top of file.
    :param valid_time_unix_sec: Valid time for example.
    :param target_name: Name of target variable.
    :return: tracking_file_names: 1-D list of paths to tracking files.
    :raises: ValueError: if no tracking files are found.
    """

    target_param_dict = target_val_utils.target_name_to_params(target_name)
    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]
    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    first_time_unix_sec = valid_time_unix_sec + min_lead_time_seconds
    last_time_unix_sec = valid_time_unix_sec + max_lead_time_seconds

    first_spc_date_string = time_conversion.time_to_spc_date_string(
        first_time_unix_sec - TIME_INTERVAL_SECONDS)
    last_spc_date_string = time_conversion.time_to_spc_date_string(
        last_time_unix_sec + TIME_INTERVAL_SECONDS)
    spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    tracking_file_names = []

    for this_spc_date_string in spc_date_strings:
        these_file_names = tracking_io.find_files_one_spc_date(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=echo_top_tracking.
            DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)[0]

        tracking_file_names += these_file_names

    if len(tracking_file_names) == 0:
        error_string = (
            'Cannot find any tracking files for SPC dates "{0:s}" to "{1:s}".'
        ).format(first_spc_date_string, last_spc_date_string)

        raise ValueError(error_string)

    tracking_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    sort_indices = numpy.argsort(tracking_times_unix_sec)
    tracking_times_unix_sec = tracking_times_unix_sec[sort_indices]
    tracking_file_names = [tracking_file_names[k] for k in sort_indices]

    these_indices = numpy.where(
        tracking_times_unix_sec <= first_time_unix_sec)[0]

    if len(these_indices) == 0:
        first_index = 0
    else:
        first_index = these_indices[-1]

    these_indices = numpy.where(
        tracking_times_unix_sec >= last_time_unix_sec)[0]

    if len(these_indices) == 0:
        last_index = len(tracking_file_names) - 1
    else:
        last_index = these_indices[0]

    return tracking_file_names[first_index:(last_index + 1)]
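
A minimal usage sketch (the path and target name are placeholders): the returned list is sorted by valid time and covers the target variable's full lead-time window around the storm's valid time.

tracking_file_names = _find_tracking_files_one_example(
    top_tracking_dir_name='/data/tracking',
    valid_time_unix_sec=storm_time_unix_sec,
    target_name='tornado_lead-time=0000-3600sec_distance=00000-10000m')
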
def _plot_one_example(full_id_string,
                      storm_time_unix_sec,
                      target_name,
                      forecast_probability,
                      tornado_dir_name,
                      top_tracking_dir_name,
                      top_myrorss_dir_name,
                      radar_field_name,
                      radar_height_m_asl,
                      latitude_buffer_deg,
                      longitude_buffer_deg,
                      top_output_dir_name,
                      aux_forecast_probabilities=None,
                      aux_activation_dict=None):
    """Plots one example with surrounding context at several times.

    N = number of storm objects read from auxiliary activation file

    :param full_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    :param target_name: Name of target variable.
    :param forecast_probability: Forecast tornado probability for this example.
    :param tornado_dir_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param top_myrorss_dir_name: Same.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param top_output_dir_name: Same.
    :param aux_forecast_probabilities: length-N numpy array of forecast
        probabilities.  If this is None, will not plot forecast probs in maps.
    :param aux_activation_dict: Dictionary returned by
        `model_activation.read_file` from auxiliary file.  If this is None, will
        not plot forecast probs in maps.
    """

    storm_time_string = time_conversion.unix_sec_to_string(
        storm_time_unix_sec, TIME_FORMAT)

    # Create output directory for this example.
    output_dir_name = '{0:s}/{1:s}_{2:s}'.format(top_output_dir_name,
                                                 full_id_string,
                                                 storm_time_string)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Find tracking files.
    tracking_file_names = _find_tracking_files_one_example(
        valid_time_unix_sec=storm_time_unix_sec,
        top_tracking_dir_name=top_tracking_dir_name,
        target_name=target_name)

    tracking_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    tracking_time_strings = [
        time_conversion.unix_sec_to_string(t, TIME_FORMAT)
        for t in tracking_times_unix_sec
    ]

    # Read tracking files.
    storm_object_table = tracking_io.read_many_files(tracking_file_names)
    print('\n')

    if aux_activation_dict is not None:
        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=aux_activation_dict[model_activation.FULL_IDS_KEY],
            all_times_unix_sec=aux_activation_dict[
                model_activation.STORM_TIMES_KEY],
            id_strings_to_keep=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            times_to_keep_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values,
            allow_missing=True)

        storm_object_probs = numpy.array([
            aux_forecast_probabilities[k] if k >= 0 else numpy.nan
            for k in these_indices
        ])

        storm_object_table = storm_object_table.assign(
            **{FORECAST_PROBABILITY_COLUMN: storm_object_probs})

    primary_id_string = temporal_tracking.full_to_partial_ids(
        [full_id_string]
    )[0][0]

    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.PRIMARY_ID_COLUMN] == primary_id_string]

    latitude_limits_deg, longitude_limits_deg = _get_plotting_limits(
        storm_object_table=this_storm_object_table,
        latitude_buffer_deg=latitude_buffer_deg,
        longitude_buffer_deg=longitude_buffer_deg)

    storm_min_latitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_latitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_min_longitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_longitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    min_latitude_flags = numpy.logical_and(
        storm_min_latitudes_deg >= latitude_limits_deg[0],
        storm_min_latitudes_deg <= latitude_limits_deg[1])

    max_latitude_flags = numpy.logical_and(
        storm_max_latitudes_deg >= latitude_limits_deg[0],
        storm_max_latitudes_deg <= latitude_limits_deg[1])

    latitude_flags = numpy.logical_or(min_latitude_flags, max_latitude_flags)

    min_longitude_flags = numpy.logical_and(
        storm_min_longitudes_deg >= longitude_limits_deg[0],
        storm_min_longitudes_deg <= longitude_limits_deg[1])

    max_longitude_flags = numpy.logical_and(
        storm_max_longitudes_deg >= longitude_limits_deg[0],
        storm_max_longitudes_deg <= longitude_limits_deg[1])

    longitude_flags = numpy.logical_or(min_longitude_flags,
                                       max_longitude_flags)
    good_indices = numpy.where(
        numpy.logical_and(latitude_flags, longitude_flags))[0]

    storm_object_table = storm_object_table.iloc[good_indices]

    # Read tornado reports.
    target_param_dict = target_val_utils.target_name_to_params(target_name)
    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]
    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    tornado_table = linkage._read_input_tornado_reports(
        input_directory_name=tornado_dir_name,
        storm_times_unix_sec=numpy.array([storm_time_unix_sec], dtype=int),
        max_time_before_storm_start_sec=-1 * min_lead_time_seconds,
        max_time_after_storm_end_sec=max_lead_time_seconds,
        genesis_only=True)

    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LATITUDE_COLUMN] >= latitude_limits_deg[0]
         )
        & (tornado_table[linkage.EVENT_LATITUDE_COLUMN] <=
           latitude_limits_deg[1])]

    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] >=
         longitude_limits_deg[0])
        & (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] <=
           longitude_limits_deg[1])]

    for i in range(len(tracking_file_names)):
        this_storm_object_table = storm_object_table.loc[storm_object_table[
            tracking_utils.VALID_TIME_COLUMN] == tracking_times_unix_sec[i]]

        _plot_one_example_one_time(
            storm_object_table=this_storm_object_table,
            full_id_string=full_id_string,
            valid_time_unix_sec=tracking_times_unix_sec[i],
            tornado_table=copy.deepcopy(tornado_table),
            top_myrorss_dir_name=top_myrorss_dir_name,
            radar_field_name=radar_field_name,
            radar_height_m_asl=radar_height_m_asl,
            latitude_limits_deg=latitude_limits_deg,
            longitude_limits_deg=longitude_limits_deg)

        if aux_activation_dict is None:
            this_title_string = (
                'Valid time = {0:s} ... forecast prob at {1:s} = {2:.3f}'
            ).format(tracking_time_strings[i], storm_time_string,
                     forecast_probability)

            pyplot.title(this_title_string, fontsize=TITLE_FONT_SIZE)

        this_file_name = '{0:s}/{1:s}.jpg'.format(output_dir_name,
                                                  tracking_time_strings[i])

        print('Saving figure to file: "{0:s}"...\n'.format(this_file_name))
        pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

        imagemagick_utils.trim_whitespace(input_file_name=this_file_name,
                                          output_file_name=this_file_name)


def _extract_storm_images(num_image_rows, num_image_columns, rotate_grids,
                          rotated_grid_spacing_metres, radar_field_names,
                          refl_heights_m_agl, spc_date_string,
                          tarred_myrorss_dir_name, untarred_myrorss_dir_name,
                          top_tracking_dir_name, elevation_dir_name,
                          tracking_scale_metres2, target_name,
                          top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name])
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files with azimuthal shear.
    az_shear_field_names = list(
        set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(az_shear_field_names):
        az_shear_tar_file_name = (
            '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar').format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=az_shear_tar_file_name,
            field_names=az_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name)
        print(SEPARATOR_STRING)

    # Untar files with other radar fields.
    non_shear_field_names = list(
        set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(non_shear_field_names):
        non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
            tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=non_shear_tar_file_name,
            field_names=non_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name,
            refl_heights_m_asl=refl_heights_m_asl)
        print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2)[0]

    storm_object_table = tracking_io.read_many_files(tracking_file_names)[
        storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print(('Removing storm objects without target values (variable = '
               '"{0:s}")...').format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects,
            num_storm_objects_orig))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns,
        rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    myrorss_io.remove_unzipped_data_1day(
        spc_date_string=spc_date_string,
        top_directory_name=untarred_myrorss_dir_name,
        field_names=radar_field_names,
        refl_heights_m_asl=refl_heights_m_asl)
    print(SEPARATOR_STRING)
Example #12
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, refl_heights_m_agl,
        spc_date_string, first_time_string, last_time_string,
        tarred_myrorss_dir_name, untarred_myrorss_dir_name,
        top_tracking_dir_name, elevation_dir_name, tracking_scale_metres2,
        target_name, top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    :raises: ValueError: if `first_time_string` and `last_time_string` have
        different SPC dates.
    """

    if elevation_dir_name in ['', 'None']:
        elevation_dir_name = None

    if elevation_dir_name is None:
        host_name = socket.gethostname()

        if 'casper' in host_name:
            elevation_dir_name = '/glade/work/ryanlage/elevation'
        else:
            elevation_dir_name = '/condo/swatwork/ralager/elevation'

    if spc_date_string in ['', 'None']:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

        first_spc_date_string = time_conversion.time_to_spc_date_string(
            first_time_unix_sec)
        last_spc_date_string = time_conversion.time_to_spc_date_string(
            last_time_unix_sec)

        if first_spc_date_string != last_spc_date_string:
            error_string = (
                'First ({0:s}) and last ({1:s}) times have different SPC dates.'
                '  This script can handle only one SPC date.'
            ).format(first_time_string, last_time_string)

            raise ValueError(error_string)

        spc_date_string = first_spc_date_string
    else:
        first_time_unix_sec = 0
        last_time_unix_sec = int(1e12)

    if tarred_myrorss_dir_name in ['', 'None']:
        tarred_myrorss_dir_name = None
    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files.
    if tarred_myrorss_dir_name is not None:
        az_shear_field_names = list(
            set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(az_shear_field_names) > 0:
            az_shear_tar_file_name = (
                '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar'
            ).format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=az_shear_tar_file_name,
                field_names=az_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name)
            print(SEPARATOR_STRING)

        non_shear_field_names = list(
            set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(non_shear_field_names) > 0:
            non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=non_shear_tar_file_name,
                field_names=non_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name,
                refl_heights_m_asl=refl_heights_m_asl)
            print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    file_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int
    )

    good_indices = numpy.where(numpy.logical_and(
        file_times_unix_sec >= first_time_unix_sec,
        file_times_unix_sec <= last_time_unix_sec
    ))[0]

    tracking_file_names = [tracking_file_names[k] for k in good_indices]

    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    if tarred_myrorss_dir_name is not None:
        myrorss_io.remove_unzipped_data_1day(
            spc_date_string=spc_date_string,
            top_directory_name=untarred_myrorss_dir_name,
            field_names=radar_field_names,
            refl_heights_m_asl=refl_heights_m_asl)
Example #13
def _run(input_prediction_file_name, top_tracking_dir_name,
         tracking_scale_metres2, x_spacing_metres, y_spacing_metres,
         effective_radius_metres, smoothing_method_name,
         smoothing_cutoff_radius_metres, smoothing_efold_radius_metres,
         top_output_dir_name):
    """Projects CNN forecasts onto the RAP grid.

    This is effectively the main method.

    :param input_prediction_file_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param x_spacing_metres: Same.
    :param y_spacing_metres: Same.
    :param effective_radius_metres: Same.
    :param smoothing_method_name: Same.
    :param smoothing_cutoff_radius_metres: Same.
    :param smoothing_efold_radius_metres: Same.
    :param top_output_dir_name: Same.
    """

    print('Reading data from: "{0:s}"...'.format(input_prediction_file_name))
    ungridded_forecast_dict = prediction_io.read_ungridded_predictions(
        input_prediction_file_name)

    target_param_dict = target_val_utils.target_name_to_params(
        ungridded_forecast_dict[prediction_io.TARGET_NAME_KEY])

    min_buffer_dist_metres = target_param_dict[
        target_val_utils.MIN_LINKAGE_DISTANCE_KEY]

    # TODO(thunderhoser): This is HACKY.
    if min_buffer_dist_metres == 0:
        min_buffer_dist_metres = numpy.nan

    max_buffer_dist_metres = target_param_dict[
        target_val_utils.MAX_LINKAGE_DISTANCE_KEY]

    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]

    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    forecast_column_name = gridded_forecasts._buffer_to_column_name(
        min_buffer_dist_metres=min_buffer_dist_metres,
        max_buffer_dist_metres=max_buffer_dist_metres,
        column_type=gridded_forecasts.FORECAST_COLUMN_TYPE)

    init_times_unix_sec = numpy.unique(
        ungridded_forecast_dict[prediction_io.STORM_TIMES_KEY])

    tracking_file_names = []

    for this_time_unix_sec in init_times_unix_sec:
        this_tracking_file_name = tracking_io.find_file(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=tracking_scale_metres2,
            source_name=tracking_utils.SEGMOTION_NAME,
            valid_time_unix_sec=this_time_unix_sec,
            spc_date_string=time_conversion.time_to_spc_date_string(
                this_time_unix_sec),
            raise_error_if_missing=True)

        tracking_file_names.append(this_tracking_file_name)

    storm_object_table = tracking_io.read_many_files(tracking_file_names)
    print(SEPARATOR_STRING)

    tracking_utils.find_storm_objects(
        all_id_strings=ungridded_forecast_dict[prediction_io.STORM_IDS_KEY],
        all_times_unix_sec=ungridded_forecast_dict[
            prediction_io.STORM_TIMES_KEY],
        id_strings_to_keep=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        times_to_keep_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        allow_missing=False)

    sort_indices = tracking_utils.find_storm_objects(
        all_id_strings=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        all_times_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        id_strings_to_keep=ungridded_forecast_dict[
            prediction_io.STORM_IDS_KEY],
        times_to_keep_unix_sec=ungridded_forecast_dict[
            prediction_io.STORM_TIMES_KEY],
        allow_missing=False)

    forecast_probabilities = ungridded_forecast_dict[
        prediction_io.PROBABILITY_MATRIX_KEY][sort_indices, 1]

    storm_object_table = storm_object_table.assign(
        **{forecast_column_name: forecast_probabilities})

    gridded_forecast_dict = gridded_forecasts.create_forecast_grids(
        storm_object_table=storm_object_table,
        min_lead_time_sec=min_lead_time_seconds,
        max_lead_time_sec=max_lead_time_seconds,
        lead_time_resolution_sec=gridded_forecasts.
        DEFAULT_LEAD_TIME_RES_SECONDS,
        grid_spacing_x_metres=x_spacing_metres,
        grid_spacing_y_metres=y_spacing_metres,
        interp_to_latlng_grid=False,
        prob_radius_for_grid_metres=effective_radius_metres,
        smoothing_method=smoothing_method_name,
        smoothing_e_folding_radius_metres=smoothing_efold_radius_metres,
        smoothing_cutoff_radius_metres=smoothing_cutoff_radius_metres)

    print(SEPARATOR_STRING)

    output_file_name = prediction_io.find_file(
        top_prediction_dir_name=top_output_dir_name,
        first_init_time_unix_sec=numpy.min(
            storm_object_table[tracking_utils.VALID_TIME_COLUMN].values),
        last_init_time_unix_sec=numpy.max(
            storm_object_table[tracking_utils.VALID_TIME_COLUMN].values),
        gridded=True,
        raise_error_if_missing=False)

    print(('Writing results (forecast grids for {0:d} initial times) to: '
           '"{1:s}"...').format(
               len(gridded_forecast_dict[prediction_io.INIT_TIMES_KEY]),
               output_file_name))

    prediction_io.write_gridded_predictions(
        gridded_forecast_dict=gridded_forecast_dict,
        pickle_file_name=output_file_name)
Example #14
def _run(model_file_name, example_file_name, first_time_string,
         last_time_string, top_output_dir_name):
    """Applies CNN to one example file.

    This is effectively the main method.

    :param model_file_name: See documentation at top of file.
    :param example_file_name: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param top_output_dir_name: Same.
    """

    print('Reading model from: "{0:s}"...'.format(model_file_name))
    model_object = cnn.read_model(model_file_name)

    model_directory_name, _ = os.path.split(model_file_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(model_directory_name)

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    first_time_unix_sec = time_conversion.string_to_unix_sec(
        first_time_string, INPUT_TIME_FORMAT)
    last_time_unix_sec = time_conversion.string_to_unix_sec(
        last_time_string, INPUT_TIME_FORMAT)

    training_option_dict[trainval_io.SAMPLING_FRACTIONS_KEY] = None
    training_option_dict[trainval_io.EXAMPLE_FILES_KEY] = [example_file_name]
    training_option_dict[
        trainval_io.FIRST_STORM_TIME_KEY] = first_time_unix_sec
    training_option_dict[trainval_io.LAST_STORM_TIME_KEY] = last_time_unix_sec

    if model_metadata_dict[cnn.LAYER_OPERATIONS_KEY] is not None:
        generator_object = testing_io.gridrad_generator_2d_reduced(
            option_dict=training_option_dict,
            list_of_operation_dicts=model_metadata_dict[
                cnn.LAYER_OPERATIONS_KEY],
            num_examples_total=LARGE_INTEGER)

    elif model_metadata_dict[cnn.CONV_2D3D_KEY]:
        generator_object = testing_io.myrorss_generator_2d3d(
            option_dict=training_option_dict, num_examples_total=LARGE_INTEGER)
    else:
        generator_object = testing_io.generator_2d_or_3d(
            option_dict=training_option_dict, num_examples_total=LARGE_INTEGER)

    include_soundings = (training_option_dict[trainval_io.SOUNDING_FIELDS_KEY]
                         is not None)

    try:
        storm_object_dict = next(generator_object)
    except StopIteration:
        storm_object_dict = None

    print(SEPARATOR_STRING)

    if storm_object_dict is not None:
        observed_labels = storm_object_dict[testing_io.TARGET_ARRAY_KEY]
        list_of_predictor_matrices = storm_object_dict[
            testing_io.INPUT_MATRICES_KEY]

        if include_soundings:
            sounding_matrix = list_of_predictor_matrices[-1]
        else:
            sounding_matrix = None

        if model_metadata_dict[cnn.CONV_2D3D_KEY]:
            if training_option_dict[trainval_io.UPSAMPLE_REFLECTIVITY_KEY]:
                class_probability_matrix = cnn.apply_2d_or_3d_cnn(
                    model_object=model_object,
                    radar_image_matrix=list_of_predictor_matrices[0],
                    sounding_matrix=sounding_matrix,
                    verbose=True)
            else:
                class_probability_matrix = cnn.apply_2d3d_cnn(
                    model_object=model_object,
                    reflectivity_matrix_dbz=list_of_predictor_matrices[0],
                    azimuthal_shear_matrix_s01=list_of_predictor_matrices[1],
                    sounding_matrix=sounding_matrix,
                    verbose=True)
        else:
            class_probability_matrix = cnn.apply_2d_or_3d_cnn(
                model_object=model_object,
                radar_image_matrix=list_of_predictor_matrices[0],
                sounding_matrix=sounding_matrix,
                verbose=True)

        print(SEPARATOR_STRING)
        num_examples = class_probability_matrix.shape[0]

        for k in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
            print(
                '{0:d}th percentile of {1:d} forecast probs = {2:.4f}'.format(
                    k, num_examples,
                    numpy.percentile(class_probability_matrix[:, 1], k)))

        print('\n')

    target_param_dict = target_val_utils.target_name_to_params(
        training_option_dict[trainval_io.TARGET_NAME_KEY])

    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]
    if event_type_string == linkage.TORNADO_EVENT_STRING:
        genesis_only = False
    elif event_type_string == linkage.TORNADOGENESIS_EVENT_STRING:
        genesis_only = True
    else:
        genesis_only = None

    target_name = target_val_utils.target_params_to_name(
        min_lead_time_sec=target_param_dict[
            target_val_utils.MIN_LEAD_TIME_KEY],
        max_lead_time_sec=target_param_dict[
            target_val_utils.MAX_LEAD_TIME_KEY],
        min_link_distance_metres=target_param_dict[
            target_val_utils.MIN_LINKAGE_DISTANCE_KEY],
        max_link_distance_metres=10000.,
        genesis_only=genesis_only)

    output_file_name = prediction_io.find_file(
        top_prediction_dir_name=top_output_dir_name,
        first_init_time_unix_sec=first_time_unix_sec,
        last_init_time_unix_sec=last_time_unix_sec,
        gridded=False,
        raise_error_if_missing=False)

    print('Writing "{0:s}" predictions to: "{1:s}"...'.format(
        target_name, output_file_name))

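    # If the generator yielded no storm objects, write an empty prediction file
    # (zero examples) and return early.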
    if storm_object_dict is None:
        num_output_neurons = (
            model_object.layers[-1].output.get_shape().as_list()[-1])

        num_classes = max([num_output_neurons, 2])
        class_probability_matrix = numpy.full((0, num_classes), numpy.nan)

        prediction_io.write_ungridded_predictions(
            netcdf_file_name=output_file_name,
            class_probability_matrix=class_probability_matrix,
            storm_ids=[],
            storm_times_unix_sec=numpy.array([], dtype=int),
            target_name=target_name,
            observed_labels=numpy.array([], dtype=int))

        return

    prediction_io.write_ungridded_predictions(
        netcdf_file_name=output_file_name,
        class_probability_matrix=class_probability_matrix,
        storm_ids=storm_object_dict[testing_io.FULL_IDS_KEY],
        storm_times_unix_sec=storm_object_dict[testing_io.STORM_TIMES_KEY],
        target_name=target_name,
        observed_labels=observed_labels)
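
# Read-back sketch (illustration only, not part of the original script): the
# file written above can be re-opened with prediction_io, whose reader returns
# a dictionary with the keys shown in the next example.
# prediction_dict = prediction_io.read_ungridded_predictions(output_file_name)
# forecast_probs = prediction_dict[prediction_io.PROBABILITY_MATRIX_KEY][:, -1]
# observed_labels = prediction_dict[prediction_io.OBSERVED_LABELS_KEY]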
Example #15
def _run(prediction_file_name, best_prob_threshold, upgraded_min_ef_rating,
         top_target_dir_name, num_bootstrap_reps, downsampling_fractions,
         output_dir_name):
    """Evaluates CNN predictions.

    This is effectively the main method.

    :param prediction_file_name: See documentation at top of file.
    :param best_prob_threshold: Same.
    :param upgraded_min_ef_rating: Same.
    :param top_target_dir_name: Same.
    :param num_bootstrap_reps: Same.
    :param downsampling_fractions: Same.
    :param output_dir_name: Same.
    :raises: ValueError: if file contains no examples (storm objects).
    :raises: ValueError: if file contains multi-class predictions.
    :raises: ValueError: if you try to upgrade minimum EF rating but the
        original is non-zero.
    """

    # Verify and process input args.
    if upgraded_min_ef_rating <= 0:
        upgraded_min_ef_rating = None

    num_bootstrap_reps = max([num_bootstrap_reps, 1])
    if best_prob_threshold < 0:
        best_prob_threshold = None

    # Read predictions.
    print('Reading data from: "{0:s}"...'.format(prediction_file_name))
    prediction_dict = prediction_io.read_ungridded_predictions(
        prediction_file_name)

    observed_labels = prediction_dict[prediction_io.OBSERVED_LABELS_KEY]
    class_probability_matrix = (
        prediction_dict[prediction_io.PROBABILITY_MATRIX_KEY])

    num_examples = len(observed_labels)
    num_classes = class_probability_matrix.shape[1]

    if num_examples == 0:
        raise ValueError('File contains no examples (storm objects).')

    if num_classes > 2:
        error_string = (
            'This script handles only binary, not {0:d}-class, classification.'
        ).format(num_classes)

        raise ValueError(error_string)

    forecast_probabilities = class_probability_matrix[:, -1]

    # If necessary, upgrade minimum EF rating.
    if upgraded_min_ef_rating is not None:
        target_param_dict = target_val_utils.target_name_to_params(
            prediction_dict[prediction_io.TARGET_NAME_KEY])
        orig_min_ef_rating = (
            target_param_dict[target_val_utils.MIN_FUJITA_RATING_KEY])

        if orig_min_ef_rating != 0:
            error_string = (
                'Cannot upgrade minimum EF rating when original min rating is '
                'non-zero (in this case it is {0:d}).'
            ).format(orig_min_ef_rating)

            raise ValueError(error_string)

        new_target_name = target_val_utils.target_params_to_name(
            min_lead_time_sec=target_param_dict[
                target_val_utils.MIN_LEAD_TIME_KEY],
            max_lead_time_sec=target_param_dict[
                target_val_utils.MAX_LEAD_TIME_KEY],
            min_link_distance_metres=target_param_dict[
                target_val_utils.MIN_LINKAGE_DISTANCE_KEY],
            max_link_distance_metres=target_param_dict[
                target_val_utils.MAX_LINKAGE_DISTANCE_KEY],
            tornadogenesis_only=(
                target_param_dict[target_val_utils.EVENT_TYPE_KEY] ==
                linkage.TORNADOGENESIS_EVENT_STRING),
            min_fujita_rating=upgraded_min_ef_rating)

        print(SEPARATOR_STRING)

        observed_labels = _read_new_target_values(
            top_target_dir_name=top_target_dir_name,
            new_target_name=new_target_name,
            full_storm_id_strings=prediction_dict[prediction_io.STORM_IDS_KEY],
            storm_times_unix_sec=prediction_dict[
                prediction_io.STORM_TIMES_KEY],
            orig_target_values=observed_labels)

        print(SEPARATOR_STRING)

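        # Negative labels mark storm objects with no valid value for the new
        # target; these are removed before evaluation.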
        good_indices = numpy.where(observed_labels >= 0)[0]
        observed_labels = observed_labels[good_indices]
        forecast_probabilities = forecast_probabilities[good_indices]

    # Do calculations.
    output_file_name = model_eval.find_file_from_prediction_file(
        input_prediction_file_name=prediction_file_name,
        output_dir_name=output_dir_name,
        raise_error_if_missing=False)
    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    if numpy.any(downsampling_fractions <= 0):
        downsampling_dict = None
    else:
        downsampling_dict = {
            0: downsampling_fractions[0],
            1: downsampling_fractions[1]
        }

    _compute_scores(forecast_probabilities=forecast_probabilities,
                    observed_labels=observed_labels,
                    num_bootstrap_reps=num_bootstrap_reps,
                    best_prob_threshold=best_prob_threshold,
                    downsampling_dict=downsampling_dict,
                    output_file_name=output_file_name)
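
# Hypothetical call sketch (argument values are illustrative only):
# _run(prediction_file_name='ungridded_predictions.nc',
#      best_prob_threshold=-1.,           # negative => converted to None
#      upgraded_min_ef_rating=-1,         # non-positive => keep original target
#      top_target_dir_name='targets',
#      num_bootstrap_reps=1000,
#      downsampling_fractions=numpy.array([-1.]),  # non-positive => no downsampling
#      output_dir_name='evaluation')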
Example #16
def write_ungridded_predictions(
        netcdf_file_name, class_probability_matrix, storm_ids,
        storm_times_unix_sec, target_name, observed_labels=None):
    """Writes predictions to NetCDF file.

    K = number of classes
    E = number of examples (storm objects)

    :param netcdf_file_name: Path to output file.
    :param class_probability_matrix: E-by-K numpy array of forecast
        probabilities.
    :param storm_ids: length-E list of storm IDs (strings).
    :param storm_times_unix_sec: length-E numpy array of valid times.
    :param target_name: Name of target variable.
    :param observed_labels: [this may be None]
        length-E numpy array of observed labels (integers in 0...[K - 1]).
    """

    # Check input args.
    error_checking.assert_is_numpy_array(
        class_probability_matrix, num_dimensions=2)
    error_checking.assert_is_geq_numpy_array(class_probability_matrix, 0.)
    error_checking.assert_is_leq_numpy_array(class_probability_matrix, 1.)

    num_examples = class_probability_matrix.shape[0]
    these_expected_dim = numpy.array([num_examples], dtype=int)

    error_checking.assert_is_string_list(storm_ids)
    error_checking.assert_is_numpy_array(
        numpy.array(storm_ids), exact_dimensions=these_expected_dim)

    error_checking.assert_is_integer_numpy_array(storm_times_unix_sec)
    error_checking.assert_is_numpy_array(
        storm_times_unix_sec, exact_dimensions=these_expected_dim)

    target_val_utils.target_name_to_params(target_name)
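    # (The call above serves only to validate target_name; its return value is
    # unused.)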

    if observed_labels is not None:
        error_checking.assert_is_integer_numpy_array(observed_labels)
        error_checking.assert_is_numpy_array(
            observed_labels, exact_dimensions=these_expected_dim)

    # Write to NetCDF file.
    file_system_utils.mkdir_recursive_if_necessary(file_name=netcdf_file_name)
    dataset_object = netCDF4.Dataset(
        netcdf_file_name, 'w', format='NETCDF3_64BIT_OFFSET')

    dataset_object.setncattr(TARGET_NAME_KEY, target_name)
    dataset_object.createDimension(
        EXAMPLE_DIMENSION_KEY, class_probability_matrix.shape[0]
    )
    dataset_object.createDimension(
        CLASS_DIMENSION_KEY, class_probability_matrix.shape[1]
    )

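    # NETCDF3 has no variable-length string type, so storm IDs are stored as a
    # fixed-width character array.  The character dimension must be at least 1,
    # even when there are no examples.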
    if num_examples == 0:
        num_id_characters = 1
    else:
        num_id_characters = 1 + numpy.max(numpy.array([
            len(s) for s in storm_ids
        ]))

    dataset_object.createDimension(STORM_ID_CHAR_DIM_KEY, num_id_characters)

    # Add storm IDs.
    this_string_format = 'S{0:d}'.format(num_id_characters)
    storm_ids_char_array = netCDF4.stringtochar(numpy.array(
        storm_ids, dtype=this_string_format
    ))

    dataset_object.createVariable(
        STORM_IDS_KEY, datatype='S1',
        dimensions=(EXAMPLE_DIMENSION_KEY, STORM_ID_CHAR_DIM_KEY)
    )
    dataset_object.variables[STORM_IDS_KEY][:] = numpy.array(
        storm_ids_char_array)

    # Add storm times.
    dataset_object.createVariable(
        STORM_TIMES_KEY, datatype=numpy.int32, dimensions=EXAMPLE_DIMENSION_KEY
    )
    dataset_object.variables[STORM_TIMES_KEY][:] = storm_times_unix_sec

    # Add probabilities.
    dataset_object.createVariable(
        PROBABILITY_MATRIX_KEY, datatype=numpy.float32,
        dimensions=(EXAMPLE_DIMENSION_KEY, CLASS_DIMENSION_KEY)
    )
    dataset_object.variables[PROBABILITY_MATRIX_KEY][:] = (
        class_probability_matrix
    )

    if observed_labels is not None:
        dataset_object.createVariable(
            OBSERVED_LABELS_KEY, datatype=numpy.int32,
            dimensions=EXAMPLE_DIMENSION_KEY
        )
        dataset_object.variables[OBSERVED_LABELS_KEY][:] = observed_labels

    dataset_object.close()
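
# Usage sketch (illustration only; SOME_VALID_TARGET_NAME is a placeholder and
# must be parseable by target_val_utils.target_name_to_params):
# write_ungridded_predictions(
#     netcdf_file_name='dummy_predictions.nc',
#     class_probability_matrix=numpy.array([[0.8, 0.2], [0.3, 0.7]]),
#     storm_ids=['storm_A', 'storm_B'],
#     storm_times_unix_sec=numpy.array([0, 300], dtype=int),
#     target_name=SOME_VALID_TARGET_NAME,
#     observed_labels=numpy.array([0, 1], dtype=int))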
Example #17
def _run(top_input_dir_name, target_name_for_downsampling,
         first_spc_date_string, last_spc_date_string, downsampling_classes,
         downsampling_fractions, for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name_for_downsampling: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    all_spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    downsampling_dict = dict(
        list(zip(downsampling_classes, downsampling_fractions)))

    target_param_dict = target_val_utils.target_name_to_params(
        target_name_for_downsampling)
    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in all_spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=event_type_string,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)

        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    target_dict_by_file = [None] * num_files

    full_id_strings = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    storm_to_file_indices = numpy.array([], dtype=int)

    target_names = []
    target_matrix = None

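    # Concatenate storm IDs, valid times, and target values across all SPC
    # dates, remembering which file each storm object came from.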
    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(
            input_target_file_names[i]))

        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i])

        if i == 0:
            target_names = (
                target_dict_by_file[i][target_val_utils.TARGET_NAMES_KEY])

        these_full_id_strings = (
            target_dict_by_file[i][target_val_utils.FULL_IDS_KEY])

        full_id_strings += these_full_id_strings
        this_num_storm_objects = len(these_full_id_strings)

        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))

        storm_to_file_indices = numpy.concatenate(
            (storm_to_file_indices,
             numpy.full(this_num_storm_objects, i, dtype=int)))

        this_target_matrix = (
            target_dict_by_file[i][target_val_utils.TARGET_MATRIX_KEY])

        if target_matrix is None:
            target_matrix = this_target_matrix.copy()
        else:
            target_matrix = numpy.concatenate(
                (target_matrix, this_target_matrix), axis=0)

    print(SEPARATOR_STRING)

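    # Remove storm objects whose value for the downsampling target is invalid.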
    downsampling_index = target_names.index(target_name_for_downsampling)
    good_indices = numpy.where(target_matrix[:, downsampling_index] !=
                               target_val_utils.INVALID_STORM_INTEGER)[0]

    full_id_strings = [full_id_strings[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_matrix = target_matrix[good_indices, :]
    storm_to_file_indices = storm_to_file_indices[good_indices]

    primary_id_strings = temporal_tracking.full_to_partial_ids(
        full_id_strings)[0]

    if for_training:
        indices_to_keep = fancy_downsampling.downsample_for_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)
    else:
        indices_to_keep = fancy_downsampling.downsample_for_non_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_object_subindices = numpy.where(
            storm_to_file_indices[indices_to_keep] == i)[0]

        these_object_indices = indices_to_keep[these_object_subindices]
        if len(these_object_indices) == 0:
            continue

        these_indices_in_file = tracking_utils.find_storm_objects(
            all_id_strings=target_dict_by_file[i][
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_id_strings[k] for k in these_object_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_object_indices],
            allow_missing=False)

        this_output_dict = {
            tracking_utils.FULL_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.FULL_IDS_KEY][k]
                for k in these_indices_in_file
            ],
            tracking_utils.VALID_TIME_COLUMN:
            target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]
            [these_indices_in_file]
        }

        for j in range(len(target_names)):
            this_output_dict[target_names[j]] = (target_dict_by_file[i][
                target_val_utils.TARGET_MATRIX_KEY][these_indices_in_file, j])

        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=event_type_string,
            spc_date_string=spc_date_string_by_file[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...').format(
                len(this_output_table.index),
                len(target_dict_by_file[i][target_val_utils.FULL_IDS_KEY]),
                this_new_file_name))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=target_names,
            netcdf_file_name=this_new_file_name)
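
# Worked illustration of the class-fraction dictionary built near the top of
# _run (not part of the original script): with downsampling_classes = [0, 1]
# and downsampling_fractions = [0.5, 0.5], the dictionary becomes
# {0: 0.5, 1: 0.5}, i.e. classes 0 and 1 are kept in equal proportions.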