def test_find_storm_objects_allow_missing_false(self):
        """Ensures correct output from find_storm_objects.

        In this case, one desired storm object is missing and
        `allow_missing = False`.
        """

        with self.assertRaises(ValueError):
            tracking_utils.find_storm_objects(
                all_storm_ids=ALL_STORM_IDS,
                all_times_unix_sec=ALL_TIMES_UNIX_SEC,
                storm_ids_to_keep=STORM_IDS_TO_KEEP_ONE_MISSING,
                times_to_keep_unix_sec=TIMES_TO_KEEP_UNIX_SEC_ONE_MISSING,
                allow_missing=False)
def _find_examples_in_prediction_dict(prediction_dict, full_storm_id_strings,
                                      storm_times_unix_sec, allow_missing):
    """Finds examples (with given ID-timem pairs) in dictionary w/ predictions.

    E = number of desired examples

    :param prediction_dict: Dictionary returned by
        `prediction_io.read_ungridded_predictions`.
    :param full_storm_id_strings: length-E list of storm IDs.
    :param storm_times_unix_sec: length-E numpy array of valid times.
    :param allow_missing: Boolean flag.  If True, will allow for missing storm
        objects.
    :return: indices_in_dict: length-E numpy array of indices.  If
        `indices_in_dict[i] == k` with `k >= 0`, the [k]th example in the
        dictionary is the [i]th desired example; -1 means that the desired
        example is missing.
    """

    num_desired_examples = len(full_storm_id_strings)

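    # Fast path: if every desired example shares one valid time, search only
    # the dictionary entries at that time before matching IDs.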
    if len(numpy.unique(storm_times_unix_sec)) == 1:
        indices_in_dict = numpy.where(prediction_dict[
            prediction_io.STORM_TIMES_KEY] == storm_times_unix_sec[0])[0]

        if len(indices_in_dict) == 0:
            return numpy.full(num_desired_examples, -1, dtype=int)

        subindices = tracking_utils.find_storm_objects(
            all_id_strings=[
                prediction_dict[prediction_io.STORM_IDS_KEY][k]
                for k in indices_in_dict
            ],
            all_times_unix_sec=prediction_dict[
                prediction_io.STORM_TIMES_KEY][indices_in_dict],
            id_strings_to_keep=full_storm_id_strings,
            times_to_keep_unix_sec=storm_times_unix_sec,
            allow_missing=allow_missing)

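    # Map matches in the time-filtered subset back to indices in the full
    # dictionary; -1 marks desired examples that were not found.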
        indices_in_dict = numpy.array(
            [indices_in_dict[k] if k >= 0 else -1 for k in subindices],
            dtype=int)

        return indices_in_dict

    return tracking_utils.find_storm_objects(
        all_id_strings=prediction_dict[prediction_io.STORM_IDS_KEY],
        all_times_unix_sec=prediction_dict[prediction_io.STORM_TIMES_KEY],
        id_strings_to_keep=full_storm_id_strings,
        times_to_keep_unix_sec=storm_times_unix_sec,
        allow_missing=allow_missing)
    def test_find_storm_objects_0missing(self):
        """Ensures correct output from find_storm_objects.

        In this case, no desired storm objects are missing.
        """

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=ALL_STORM_ID_STRINGS,
            all_times_unix_sec=ALL_TIMES_UNIX_SEC,
            id_strings_to_keep=KEPT_ID_STRINGS_0MISSING,
            times_to_keep_unix_sec=KEPT_TIMES_UNIX_SEC_0MISSING,
            allow_missing=False)

        self.assertTrue(
            numpy.array_equal(these_indices, RELEVANT_INDICES_0MISSING))
    def test_find_storm_objects_none_missing(self):
        """Ensures correct output from find_storm_objects.

        In this case, no desired storm objects are missing.
        """

        these_indices = tracking_utils.find_storm_objects(
            all_storm_ids=ALL_STORM_IDS,
            all_times_unix_sec=ALL_TIMES_UNIX_SEC,
            storm_ids_to_keep=STORM_IDS_TO_KEEP_NONE_MISSING,
            times_to_keep_unix_sec=TIMES_TO_KEEP_UNIX_SEC_NONE_MISSING,
            allow_missing=False)

        self.assertTrue(
            numpy.array_equal(these_indices, RELEVANT_INDICES_NONE_MISSING))
def _read_storm_locations_one_time(
        top_tracking_dir_name, valid_time_unix_sec, desired_full_id_strings):
    """Reads storm locations at one time.

    K = number of storm objects desired

    :param top_tracking_dir_name: See documentation at top of file.
    :param valid_time_unix_sec: Valid time.
    :param desired_full_id_strings: length-K list of full storm IDs.  Locations
        will be read for these storms only.
    :return: desired_latitudes_deg: length-K numpy array of latitudes (deg N).
    :return: desired_longitudes_deg: length-K numpy array of longitudes (deg E).
    """

    spc_date_string = time_conversion.time_to_spc_date_string(
        valid_time_unix_sec)
    desired_times_unix_sec = numpy.full(
        len(desired_full_id_strings), valid_time_unix_sec, dtype=int
    )

    tracking_file_name = tracking_io.find_file(
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
        source_name=tracking_utils.SEGMOTION_NAME,
        valid_time_unix_sec=valid_time_unix_sec,
        spc_date_string=spc_date_string, raise_error_if_missing=True)

    print('Reading storm locations from: "{0:s}"...'.format(tracking_file_name))
    storm_object_table = tracking_io.read_file(tracking_file_name)

    desired_indices = tracking_utils.find_storm_objects(
        all_id_strings=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        all_times_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        id_strings_to_keep=desired_full_id_strings,
        times_to_keep_unix_sec=desired_times_unix_sec, allow_missing=False)

    desired_latitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LATITUDE_COLUMN].values[desired_indices]
    desired_longitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LONGITUDE_COLUMN].values[desired_indices]

    return desired_latitudes_deg, desired_longitudes_deg
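
# Hypothetical call sketch (directory, time, and ID below are illustrative
# only, so the call is left commented out):
# latitudes_deg, longitudes_deg = _read_storm_locations_one_time(
#     top_tracking_dir_name='/data/tracking',
#     valid_time_unix_sec=1541884800,
#     desired_full_id_strings=['012345_20181110'])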
def _plot_one_example(full_id_string,
                      storm_time_unix_sec,
                      target_name,
                      forecast_probability,
                      tornado_dir_name,
                      top_tracking_dir_name,
                      top_myrorss_dir_name,
                      radar_field_name,
                      radar_height_m_asl,
                      latitude_buffer_deg,
                      longitude_buffer_deg,
                      top_output_dir_name,
                      aux_forecast_probabilities=None,
                      aux_activation_dict=None):
    """Plots one example with surrounding context at several times.

    N = number of storm objects read from auxiliary activation file

    :param full_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    :param target_name: Name of target variable.
    :param forecast_probability: Forecast tornado probability for this example.
    :param tornado_dir_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param top_myrorss_dir_name: Same.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param top_output_dir_name: Same.
    :param aux_forecast_probabilities: length-N numpy array of forecast
        probabilities.  If this is None, will not plot forecast probs in maps.
    :param aux_activation_dict: Dictionary returned by
        `model_activation.read_file` from auxiliary file.  If this is None, will
        not plot forecast probs in maps.
    """

    storm_time_string = time_conversion.unix_sec_to_string(
        storm_time_unix_sec, TIME_FORMAT)

    # Create output directory for this example.
    output_dir_name = '{0:s}/{1:s}_{2:s}'.format(top_output_dir_name,
                                                 full_id_string,
                                                 storm_time_string)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Find tracking files.
    tracking_file_names = _find_tracking_files_one_example(
        valid_time_unix_sec=storm_time_unix_sec,
        top_tracking_dir_name=top_tracking_dir_name,
        target_name=target_name)

    tracking_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    tracking_time_strings = [
        time_conversion.unix_sec_to_string(t, TIME_FORMAT)
        for t in tracking_times_unix_sec
    ]

    # Read tracking files.
    storm_object_table = tracking_io.read_many_files(tracking_file_names)
    print('\n')

    if aux_activation_dict is not None:
        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=aux_activation_dict[model_activation.FULL_IDS_KEY],
            all_times_unix_sec=aux_activation_dict[
                model_activation.STORM_TIMES_KEY],
            id_strings_to_keep=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            times_to_keep_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values,
            allow_missing=True)

        storm_object_probs = numpy.array([
            aux_forecast_probabilities[k] if k >= 0 else numpy.nan
            for k in these_indices
        ])

        storm_object_table = storm_object_table.assign(
            **{FORECAST_PROBABILITY_COLUMN: storm_object_probs})

    primary_id_string = temporal_tracking.full_to_partial_ids(
        [full_id_string]
    )[0][0]

    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.PRIMARY_ID_COLUMN] == primary_id_string]

    latitude_limits_deg, longitude_limits_deg = _get_plotting_limits(
        storm_object_table=this_storm_object_table,
        latitude_buffer_deg=latitude_buffer_deg,
        longitude_buffer_deg=longitude_buffer_deg)

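    # Keep only storm objects whose outline touches the plotting domain: a
    # storm object passes if its min or max latitude falls inside the latitude
    # limits and its min or max longitude falls inside the longitude limits.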
    storm_min_latitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_latitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_min_longitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_longitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    min_latitude_flags = numpy.logical_and(
        storm_min_latitudes_deg >= latitude_limits_deg[0],
        storm_min_latitudes_deg <= latitude_limits_deg[1])

    max_latitude_flags = numpy.logical_and(
        storm_max_latitudes_deg >= latitude_limits_deg[0],
        storm_max_latitudes_deg <= latitude_limits_deg[1])

    latitude_flags = numpy.logical_or(min_latitude_flags, max_latitude_flags)

    min_longitude_flags = numpy.logical_and(
        storm_min_longitudes_deg >= longitude_limits_deg[0],
        storm_min_longitudes_deg <= longitude_limits_deg[1])

    max_longitude_flags = numpy.logical_and(
        storm_max_longitudes_deg >= longitude_limits_deg[0],
        storm_max_longitudes_deg <= longitude_limits_deg[1])

    longitude_flags = numpy.logical_or(min_longitude_flags,
                                       max_longitude_flags)
    good_indices = numpy.where(
        numpy.logical_and(latitude_flags, longitude_flags))[0]

    storm_object_table = storm_object_table.iloc[good_indices]

    # Read tornado reports.
    target_param_dict = target_val_utils.target_name_to_params(target_name)
    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]
    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    tornado_table = linkage._read_input_tornado_reports(
        input_directory_name=tornado_dir_name,
        storm_times_unix_sec=numpy.array([storm_time_unix_sec], dtype=int),
        max_time_before_storm_start_sec=-1 * min_lead_time_seconds,
        max_time_after_storm_end_sec=max_lead_time_seconds,
        genesis_only=True)

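    # Keep only tornado reports inside the plotting domain.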
    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LATITUDE_COLUMN] >=
         latitude_limits_deg[0]) &
        (tornado_table[linkage.EVENT_LATITUDE_COLUMN] <=
         latitude_limits_deg[1])
    ]

    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] >=
         longitude_limits_deg[0]) &
        (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] <=
         longitude_limits_deg[1])
    ]

    for i in range(len(tracking_file_names)):
        this_storm_object_table = storm_object_table.loc[storm_object_table[
            tracking_utils.VALID_TIME_COLUMN] == tracking_times_unix_sec[i]]

        _plot_one_example_one_time(
            storm_object_table=this_storm_object_table,
            full_id_string=full_id_string,
            valid_time_unix_sec=tracking_times_unix_sec[i],
            tornado_table=copy.deepcopy(tornado_table),
            top_myrorss_dir_name=top_myrorss_dir_name,
            radar_field_name=radar_field_name,
            radar_height_m_asl=radar_height_m_asl,
            latitude_limits_deg=latitude_limits_deg,
            longitude_limits_deg=longitude_limits_deg)

        if aux_activation_dict is None:
            this_title_string = (
                'Valid time = {0:s} ... forecast prob at {1:s} = {2:.3f}'
            ).format(tracking_time_strings[i], storm_time_string,
                     forecast_probability)

            pyplot.title(this_title_string, fontsize=TITLE_FONT_SIZE)

        this_file_name = '{0:s}/{1:s}.jpg'.format(output_dir_name,
                                                  tracking_time_strings[i])

        print('Saving figure to file: "{0:s}"...\n'.format(this_file_name))
        pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

        imagemagick_utils.trim_whitespace(input_file_name=this_file_name,
                                          output_file_name=this_file_name)
Example 7
def _run(prediction_file_name, top_tracking_dir_name, prob_threshold,
         grid_spacing_metres, output_dir_name):
    """Plots spatial distribution of false alarms.

    This is effectively the main method.

    :param prediction_file_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param prob_threshold: Same.
    :param grid_spacing_metres: Same.
    :param output_dir_name: Same.
    """

    # Process input args.
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)
    error_checking.assert_is_greater(prob_threshold, 0.)
    error_checking.assert_is_less_than(prob_threshold, 1.)

    grid_metadata_dict = grids.create_equidistant_grid(
        min_latitude_deg=MIN_LATITUDE_DEG,
        max_latitude_deg=MAX_LATITUDE_DEG,
        min_longitude_deg=MIN_LONGITUDE_DEG,
        max_longitude_deg=MAX_LONGITUDE_DEG,
        x_spacing_metres=grid_spacing_metres,
        y_spacing_metres=grid_spacing_metres,
        azimuthal=False)

    # Read predictions and find positive forecasts and false alarms.
    print('Reading predictions from: "{0:s}"...'.format(prediction_file_name))
    prediction_dict = prediction_io.read_ungridded_predictions(
        prediction_file_name)

    observed_labels = prediction_dict[prediction_io.OBSERVED_LABELS_KEY]
    forecast_labels = (
        prediction_dict[prediction_io.PROBABILITY_MATRIX_KEY][:, -1] >=
        prob_threshold).astype(int)

    pos_forecast_indices = numpy.where(forecast_labels == 1)[0]
    false_alarm_indices = numpy.where(
        numpy.logical_and(observed_labels == 0, forecast_labels == 1))[0]

    num_examples = len(observed_labels)
    num_positive_forecasts = len(pos_forecast_indices)
    num_false_alarms = len(false_alarm_indices)

    print(('Probability threshold = {0:.3f} ... number of examples, positive '
           'forecasts, false alarms = {1:d}, {2:d}, {3:d}').format(
               prob_threshold, num_examples, num_positive_forecasts,
               num_false_alarms))

    # Find and read tracking files.
    pos_forecast_id_strings = [
        prediction_dict[prediction_io.STORM_IDS_KEY][k]
        for k in pos_forecast_indices
    ]
    pos_forecast_times_unix_sec = (
        prediction_dict[prediction_io.STORM_TIMES_KEY][pos_forecast_indices])

    file_times_unix_sec = numpy.unique(pos_forecast_times_unix_sec)
    num_files = len(file_times_unix_sec)
    storm_object_tables = [None] * num_files

    print(SEPARATOR_STRING)

    for i in range(num_files):
        this_tracking_file_name = tracking_io.find_file(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            valid_time_unix_sec=file_times_unix_sec[i],
            spc_date_string=time_conversion.time_to_spc_date_string(
                file_times_unix_sec[i]),
            raise_error_if_missing=True)

        print('Reading data from: "{0:s}"...'.format(this_tracking_file_name))
        this_table = tracking_io.read_file(this_tracking_file_name)
        storm_object_tables[i] = this_table.loc[this_table[
            tracking_utils.FULL_ID_COLUMN].isin(pos_forecast_id_strings)]

        if i == 0:
            continue

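        # Align columns with the first table, so that pandas.concat below can
        # stack the tables row-wise with a consistent column order.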
        storm_object_tables[i] = storm_object_tables[i].align(
            storm_object_tables[0], axis=1)[0]

    storm_object_table = pandas.concat(storm_object_tables,
                                       axis=0,
                                       ignore_index=True)
    print(SEPARATOR_STRING)

    # Find latitudes and longitudes of false alarms.
    all_id_strings = (
        storm_object_table[tracking_utils.FULL_ID_COLUMN].values.tolist())
    all_times_unix_sec = (
        storm_object_table[tracking_utils.VALID_TIME_COLUMN].values)
    good_indices = tracking_utils.find_storm_objects(
        all_id_strings=all_id_strings,
        all_times_unix_sec=all_times_unix_sec,
        id_strings_to_keep=pos_forecast_id_strings,
        times_to_keep_unix_sec=pos_forecast_times_unix_sec,
        allow_missing=False)

    pos_forecast_latitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LATITUDE_COLUMN].values[good_indices]

    pos_forecast_longitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LONGITUDE_COLUMN].values[good_indices]

    false_alarm_id_strings = [
        prediction_dict[prediction_io.STORM_IDS_KEY][k]
        for k in false_alarm_indices
    ]
    false_alarm_times_unix_sec = (
        prediction_dict[prediction_io.STORM_TIMES_KEY][false_alarm_indices])
    good_indices = tracking_utils.find_storm_objects(
        all_id_strings=all_id_strings,
        all_times_unix_sec=all_times_unix_sec,
        id_strings_to_keep=false_alarm_id_strings,
        times_to_keep_unix_sec=false_alarm_times_unix_sec,
        allow_missing=False)

    false_alarm_latitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LATITUDE_COLUMN].values[good_indices]

    false_alarm_longitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LONGITUDE_COLUMN].values[good_indices]

    pos_forecast_x_coords_metres, pos_forecast_y_coords_metres = (
        projections.project_latlng_to_xy(
            latitudes_deg=pos_forecast_latitudes_deg,
            longitudes_deg=pos_forecast_longitudes_deg,
            projection_object=grid_metadata_dict[grids.PROJECTION_KEY]))

    num_pos_forecasts_matrix = grids.count_events_on_equidistant_grid(
        event_x_coords_metres=pos_forecast_x_coords_metres,
        event_y_coords_metres=pos_forecast_y_coords_metres,
        grid_point_x_coords_metres=grid_metadata_dict[grids.X_COORDS_KEY],
        grid_point_y_coords_metres=grid_metadata_dict[grids.Y_COORDS_KEY])[0]
    print(SEPARATOR_STRING)

    false_alarm_x_coords_metres, false_alarm_y_coords_metres = (
        projections.project_latlng_to_xy(
            latitudes_deg=false_alarm_latitudes_deg,
            longitudes_deg=false_alarm_longitudes_deg,
            projection_object=grid_metadata_dict[grids.PROJECTION_KEY]))

    num_false_alarms_matrix = grids.count_events_on_equidistant_grid(
        event_x_coords_metres=false_alarm_x_coords_metres,
        event_y_coords_metres=false_alarm_y_coords_metres,
        grid_point_x_coords_metres=grid_metadata_dict[grids.X_COORDS_KEY],
        grid_point_y_coords_metres=grid_metadata_dict[grids.Y_COORDS_KEY])[0]
    print(SEPARATOR_STRING)

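    # Convert counts to float and mask zero-count grid cells with NaN, so that
    # FAR (false alarms / positive forecasts) is undefined where nothing was
    # forecast and such cells are left blank in the maps.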
    num_pos_forecasts_matrix = num_pos_forecasts_matrix.astype(float)
    num_pos_forecasts_matrix[num_pos_forecasts_matrix == 0] = numpy.nan
    num_false_alarms_matrix = num_false_alarms_matrix.astype(float)
    num_false_alarms_matrix[num_false_alarms_matrix == 0] = numpy.nan
    far_matrix = num_false_alarms_matrix / num_pos_forecasts_matrix

    this_max_value = numpy.nanpercentile(num_false_alarms_matrix,
                                         MAX_COUNT_PERCENTILE_TO_PLOT)
    if this_max_value < 10:
        this_max_value = numpy.nanmax(num_false_alarms_matrix)

    figure_object = plotter._plot_one_value(
        data_matrix=num_false_alarms_matrix,
        grid_metadata_dict=grid_metadata_dict,
        colour_map_object=CMAP_OBJECT_FOR_COUNTS,
        min_colour_value=0,
        max_colour_value=this_max_value,
        plot_cbar_min_arrow=False,
        plot_cbar_max_arrow=True)[0]

    num_false_alarms_file_name = '{0:s}/num_false_alarms.jpg'.format(
        output_dir_name)

    print('Saving figure to: "{0:s}"...'.format(num_false_alarms_file_name))
    figure_object.savefig(num_false_alarms_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)

    this_max_value = numpy.nanpercentile(num_pos_forecasts_matrix,
                                         MAX_COUNT_PERCENTILE_TO_PLOT)
    if this_max_value < 10:
        this_max_value = numpy.nanmax(num_pos_forecasts_matrix)

    figure_object = plotter._plot_one_value(
        data_matrix=num_pos_forecasts_matrix,
        grid_metadata_dict=grid_metadata_dict,
        colour_map_object=CMAP_OBJECT_FOR_COUNTS,
        min_colour_value=0,
        max_colour_value=this_max_value,
        plot_cbar_min_arrow=False,
        plot_cbar_max_arrow=True)[0]

    num_pos_forecasts_file_name = '{0:s}/num_positive_forecasts.jpg'.format(
        output_dir_name)

    print('Saving figure to: "{0:s}"...'.format(num_pos_forecasts_file_name))
    figure_object.savefig(num_pos_forecasts_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)

    this_max_value = numpy.nanpercentile(far_matrix,
                                         MAX_FAR_PERCENTILE_TO_PLOT)
    this_min_value = numpy.nanpercentile(far_matrix,
                                         100. - MAX_FAR_PERCENTILE_TO_PLOT)

    figure_object = plotter._plot_one_value(
        data_matrix=far_matrix,
        grid_metadata_dict=grid_metadata_dict,
        colour_map_object=CMAP_OBJECT_FOR_FAR,
        min_colour_value=this_min_value,
        max_colour_value=this_max_value,
        plot_cbar_min_arrow=this_min_value > 0.,
        plot_cbar_max_arrow=this_max_value < 1.)[0]

    far_file_name = '{0:s}/false_alarm_ratio.jpg'.format(output_dir_name)

    print('Saving figure to: "{0:s}"...'.format(far_file_name))
    figure_object.savefig(far_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)
Example 8
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, radar_heights_m_agl,
        spc_date_string, top_radar_dir_name, top_tracking_dir_name,
        elevation_dir_name, tracking_scale_metres2, target_name,
        top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered radar images from GridRad data.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param radar_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param top_radar_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )

        print('\n')

    # Find storm objects on the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    # Read storm objects on the given SPC date.
    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]

    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_gridrad(
        storm_object_table=storm_object_table,
        top_radar_dir_name=top_radar_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        radar_heights_m_agl=radar_heights_m_agl)
Example 9
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, refl_heights_m_agl,
        spc_date_string, first_time_string, last_time_string,
        tarred_myrorss_dir_name, untarred_myrorss_dir_name,
        top_tracking_dir_name, elevation_dir_name, tracking_scale_metres2,
        target_name, top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    :raises: ValueError: if `first_time_string` and `last_time_string` have
        different SPC dates.
    """

    if elevation_dir_name in ['', 'None']:
        elevation_dir_name = None

    if elevation_dir_name is None:
        host_name = socket.gethostname()

        if 'casper' in host_name:
            elevation_dir_name = '/glade/work/ryanlage/elevation'
        else:
            elevation_dir_name = '/condo/swatwork/ralager/elevation'

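    # If no SPC date is given, infer it from the first and last times, which
    # must fall on the same SPC date.  If an SPC date is given, keep the whole
    # date (the time bounds below are effectively unlimited).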
    if spc_date_string in ['', 'None']:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

        first_spc_date_string = time_conversion.time_to_spc_date_string(
            first_time_unix_sec)
        last_spc_date_string = time_conversion.time_to_spc_date_string(
            last_time_unix_sec)

        if first_spc_date_string != last_spc_date_string:
            error_string = (
                'First ({0:s}) and last ({1:s}) times have different SPC dates.'
                '  This script can handle only one SPC date.'
            ).format(first_time_string, last_time_string)

            raise ValueError(error_string)

        spc_date_string = first_spc_date_string
    else:
        first_time_unix_sec = 0
        last_time_unix_sec = int(1e12)

    if tarred_myrorss_dir_name in ['', 'None']:
        tarred_myrorss_dir_name = None
    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files.
    if tarred_myrorss_dir_name is not None:
        az_shear_field_names = list(
            set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(az_shear_field_names) > 0:
            az_shear_tar_file_name = (
                '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar'
            ).format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=az_shear_tar_file_name,
                field_names=az_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name)
            print(SEPARATOR_STRING)

        non_shear_field_names = list(
            set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(non_shear_field_names) > 0:
            non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=non_shear_tar_file_name,
                field_names=non_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name,
                refl_heights_m_asl=refl_heights_m_asl)
            print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    file_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int
    )

    good_indices = numpy.where(numpy.logical_and(
        file_times_unix_sec >= first_time_unix_sec,
        file_times_unix_sec <= last_time_unix_sec
    ))[0]

    tracking_file_names = [tracking_file_names[k] for k in good_indices]

    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    if tarred_myrorss_dir_name is not None:
        myrorss_io.remove_unzipped_data_1day(
            spc_date_string=spc_date_string,
            top_directory_name=untarred_myrorss_dir_name,
            field_names=radar_field_names,
            refl_heights_m_asl=refl_heights_m_asl)
Example 10
def _run(input_prediction_file_name, top_tracking_dir_name,
         tracking_scale_metres2, x_spacing_metres, y_spacing_metres,
         effective_radius_metres, smoothing_method_name,
         smoothing_cutoff_radius_metres, smoothing_efold_radius_metres,
         top_output_dir_name):
    """Projects CNN forecasts onto the RAP grid.

    This is effectively the main method.

    :param input_prediction_file_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param x_spacing_metres: Same.
    :param y_spacing_metres: Same.
    :param effective_radius_metres: Same.
    :param smoothing_method_name: Same.
    :param smoothing_cutoff_radius_metres: Same.
    :param smoothing_efold_radius_metres: Same.
    :param top_output_dir_name: Same.
    """

    print('Reading data from: "{0:s}"...'.format(input_prediction_file_name))
    ungridded_forecast_dict = prediction_io.read_ungridded_predictions(
        input_prediction_file_name)

    target_param_dict = target_val_utils.target_name_to_params(
        ungridded_forecast_dict[prediction_io.TARGET_NAME_KEY])

    min_buffer_dist_metres = target_param_dict[
        target_val_utils.MIN_LINKAGE_DISTANCE_KEY]

    # TODO(thunderhoser): This is HACKY.
    if min_buffer_dist_metres == 0:
        min_buffer_dist_metres = numpy.nan

    max_buffer_dist_metres = target_param_dict[
        target_val_utils.MAX_LINKAGE_DISTANCE_KEY]

    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]

    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    forecast_column_name = gridded_forecasts._buffer_to_column_name(
        min_buffer_dist_metres=min_buffer_dist_metres,
        max_buffer_dist_metres=max_buffer_dist_metres,
        column_type=gridded_forecasts.FORECAST_COLUMN_TYPE)

    init_times_unix_sec = numpy.unique(
        ungridded_forecast_dict[prediction_io.STORM_TIMES_KEY])

    tracking_file_names = []

    for this_time_unix_sec in init_times_unix_sec:
        this_tracking_file_name = tracking_io.find_file(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=tracking_scale_metres2,
            source_name=tracking_utils.SEGMOTION_NAME,
            valid_time_unix_sec=this_time_unix_sec,
            spc_date_string=time_conversion.time_to_spc_date_string(
                this_time_unix_sec),
            raise_error_if_missing=True)

        tracking_file_names.append(this_tracking_file_name)

    storm_object_table = tracking_io.read_many_files(tracking_file_names)
    print(SEPARATOR_STRING)

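    # Validation only: with allow_missing=False, this call raises a ValueError
    # if any tracked storm object lacks an ungridded forecast.  The returned
    # indices are unused.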
    tracking_utils.find_storm_objects(
        all_id_strings=ungridded_forecast_dict[prediction_io.STORM_IDS_KEY],
        all_times_unix_sec=ungridded_forecast_dict[
            prediction_io.STORM_TIMES_KEY],
        id_strings_to_keep=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        times_to_keep_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        allow_missing=False)

    sort_indices = tracking_utils.find_storm_objects(
        all_id_strings=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        all_times_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        id_strings_to_keep=ungridded_forecast_dict[
            prediction_io.STORM_IDS_KEY],
        times_to_keep_unix_sec=ungridded_forecast_dict[
            prediction_io.STORM_TIMES_KEY],
        allow_missing=False)

    forecast_probabilities = ungridded_forecast_dict[
        prediction_io.PROBABILITY_MATRIX_KEY][sort_indices, 1]

    storm_object_table = storm_object_table.assign(
        **{forecast_column_name: forecast_probabilities})

    gridded_forecast_dict = gridded_forecasts.create_forecast_grids(
        storm_object_table=storm_object_table,
        min_lead_time_sec=min_lead_time_seconds,
        max_lead_time_sec=max_lead_time_seconds,
        lead_time_resolution_sec=(
            gridded_forecasts.DEFAULT_LEAD_TIME_RES_SECONDS),
        grid_spacing_x_metres=x_spacing_metres,
        grid_spacing_y_metres=y_spacing_metres,
        interp_to_latlng_grid=False,
        prob_radius_for_grid_metres=effective_radius_metres,
        smoothing_method=smoothing_method_name,
        smoothing_e_folding_radius_metres=smoothing_efold_radius_metres,
        smoothing_cutoff_radius_metres=smoothing_cutoff_radius_metres)

    print(SEPARATOR_STRING)

    output_file_name = prediction_io.find_file(
        top_prediction_dir_name=top_output_dir_name,
        first_init_time_unix_sec=numpy.min(
            storm_object_table[tracking_utils.VALID_TIME_COLUMN].values),
        last_init_time_unix_sec=numpy.max(
            storm_object_table[tracking_utils.VALID_TIME_COLUMN].values),
        gridded=True,
        raise_error_if_missing=False)

    print(('Writing results (forecast grids for {0:d} initial times) to: '
           '"{1:s}"...').format(
               len(gridded_forecast_dict[prediction_io.INIT_TIMES_KEY]),
               output_file_name))

    prediction_io.write_gridded_predictions(
        gridded_forecast_dict=gridded_forecast_dict,
        pickle_file_name=output_file_name)
def _filter_examples(trial_full_id_strings, trial_times_unix_sec,
                     num_trial_examples, baseline_full_id_strings,
                     baseline_times_unix_sec, num_baseline_examples,
                     num_novel_examples):
    """Filters trial and baseline examples (storm objects).

    T = original num trial examples
    t = desired num trial examples
    B = original num baseline examples
    b = desired num baseline examples

    :param trial_full_id_strings: length-T list of storm IDs.
    :param trial_times_unix_sec: length-T numpy array of storm times.
    :param num_trial_examples: t in the above discussion.  To keep all trial
        examples, make this non-positive.
    :param baseline_full_id_strings: length-B list of storm IDs.
    :param baseline_times_unix_sec: length-B numpy array of storm times.
    :param num_baseline_examples: b in the above discussion.  To keep all
        baseline examples, make this non-positive.
    :param num_novel_examples: Number of novel examples to find.
    :return: metadata_dict: Dictionary with the following keys.
    metadata_dict["trial_full_id_strings"]: length-t list of storm IDs.
    metadata_dict["trial_times_unix_sec"]: length-t numpy array of storm times.
    metadata_dict["baseline_full_id_strings"]: length-b list of storm IDs.
    metadata_dict["baseline_times_unix_sec"]: length-b numpy array of storm
        times.
    metadata_dict["num_novel_examples"]: Number of novel examples to find.
    """

    if 0 < num_trial_examples < len(trial_full_id_strings):
        trial_full_id_strings = trial_full_id_strings[:num_trial_examples]
        trial_times_unix_sec = trial_times_unix_sec[:num_trial_examples]

    num_trial_examples = len(trial_full_id_strings)
    if num_novel_examples <= 0:
        num_novel_examples = num_trial_examples + 0

    num_novel_examples = min([num_novel_examples, num_trial_examples])
    print('Number of novel examples to find: {0:d}'.format(num_novel_examples))

    bad_baseline_indices = tracking_utils.find_storm_objects(
        all_id_strings=baseline_full_id_strings,
        all_times_unix_sec=baseline_times_unix_sec,
        id_strings_to_keep=trial_full_id_strings,
        times_to_keep_unix_sec=trial_times_unix_sec,
        allow_missing=True)

    # Drop -1 entries (trial examples absent from the baseline set); otherwise
    # numpy.delete would misread each -1 as the last array index.
    bad_baseline_indices = bad_baseline_indices[bad_baseline_indices >= 0]

    print('Removing {0:d} trial examples from baseline set...'.format(
        len(bad_baseline_indices)))

    baseline_times_unix_sec = numpy.delete(
        baseline_times_unix_sec, bad_baseline_indices)
    baseline_full_id_strings = numpy.delete(
        numpy.array(baseline_full_id_strings), bad_baseline_indices).tolist()

    if 0 < num_baseline_examples < len(baseline_full_id_strings):
        baseline_full_id_strings = (
            baseline_full_id_strings[:num_baseline_examples])
        baseline_times_unix_sec = (
            baseline_times_unix_sec[:num_baseline_examples])

    return {
        TRIAL_STORM_IDS_KEY: trial_full_id_strings,
        TRIAL_STORM_TIMES_KEY: trial_times_unix_sec,
        BASELINE_STORM_IDS_KEY: baseline_full_id_strings,
        BASELINE_STORM_TIMES_KEY: baseline_times_unix_sec,
        NUM_NOVEL_EXAMPLES_KEY: num_novel_examples
    }
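
# Hypothetical usage sketch (IDs, times, and sizes below are illustrative, so
# the call is left commented out):
# metadata_dict = _filter_examples(
#     trial_full_id_strings=trial_id_strings,
#     trial_times_unix_sec=trial_times_unix_sec,
#     num_trial_examples=100,
#     baseline_full_id_strings=baseline_id_strings,
#     baseline_times_unix_sec=baseline_times_unix_sec,
#     num_baseline_examples=500,
#     num_novel_examples=50)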
Example 12
def _run(storm_metafile_name, top_tracking_dir_name, lead_time_seconds,
         output_file_name):
    """Plots spatial distribution of examples (storm objects) in file.

    This is effectively the main method.

    :param storm_metafile_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param lead_time_seconds: Same.
    :param output_file_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    # Read storm metadata.
    print(
        'Reading storm metadata from: "{0:s}"...'.format(storm_metafile_name))
    orig_full_id_strings, orig_times_unix_sec = (
        tracking_io.read_ids_and_times(storm_metafile_name))
    orig_primary_id_strings = temporal_tracking.full_to_partial_ids(
        orig_full_id_strings)[0]

    # Find relevant tracking files.
    spc_date_strings = [
        time_conversion.time_to_spc_date_string(t) for t in orig_times_unix_sec
    ]
    spc_date_strings += [
        time_conversion.time_to_spc_date_string(t + lead_time_seconds)
        for t in orig_times_unix_sec
    ]
    spc_date_strings = list(set(spc_date_strings))

    tracking_file_names = []

    for this_spc_date_string in spc_date_strings:
        tracking_file_names += tracking_io.find_files_one_spc_date(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)[0]

    file_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    num_orig_storm_objects = len(orig_full_id_strings)
    num_files = len(file_times_unix_sec)
    keep_file_flags = numpy.full(num_files, 0, dtype=bool)

    for i in range(num_orig_storm_objects):
        these_flags = numpy.logical_and(
            file_times_unix_sec >= orig_times_unix_sec[i],
            file_times_unix_sec <= orig_times_unix_sec[i] + lead_time_seconds)
        keep_file_flags = numpy.logical_or(keep_file_flags, these_flags)

    del file_times_unix_sec
    keep_file_indices = numpy.where(keep_file_flags)[0]
    tracking_file_names = [tracking_file_names[k] for k in keep_file_indices]

    # Read relevant tracking files.
    num_files = len(tracking_file_names)
    storm_object_tables = [None] * num_files
    print(SEPARATOR_STRING)

    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(tracking_file_names[i]))
        this_table = tracking_io.read_file(tracking_file_names[i])

        storm_object_tables[i] = this_table.loc[this_table[
            tracking_utils.PRIMARY_ID_COLUMN].isin(
                numpy.array(orig_primary_id_strings))]

        if i == 0:
            continue

        storm_object_tables[i] = storm_object_tables[i].align(
            storm_object_tables[0], axis=1)[0]

    storm_object_table = pandas.concat(storm_object_tables,
                                       axis=0,
                                       ignore_index=True)
    print(SEPARATOR_STRING)

    # Find relevant storm objects.
    orig_object_rows = tracking_utils.find_storm_objects(
        all_id_strings=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        all_times_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        id_strings_to_keep=orig_full_id_strings,
        times_to_keep_unix_sec=orig_times_unix_sec)

    good_object_rows = numpy.array([], dtype=int)

    for i in range(num_orig_storm_objects):
        # Non-merging successors only: successors may involve up to one split
        # (first_rows) but no mergers (second_rows); keep the intersection.

        first_rows = temporal_tracking.find_successors(
            storm_object_table=storm_object_table,
            target_row=orig_object_rows[i],
            num_seconds_forward=lead_time_seconds,
            max_num_sec_id_changes=1,
            change_type_string=temporal_tracking.SPLIT_STRING,
            return_all_on_path=True)

        second_rows = temporal_tracking.find_successors(
            storm_object_table=storm_object_table,
            target_row=orig_object_rows[i],
            num_seconds_forward=lead_time_seconds,
            max_num_sec_id_changes=0,
            change_type_string=temporal_tracking.MERGER_STRING,
            return_all_on_path=True)

        first_rows = first_rows.tolist()
        second_rows = second_rows.tolist()
        these_rows = set(first_rows) & set(second_rows)
        these_rows = numpy.array(list(these_rows), dtype=int)

        good_object_rows = numpy.concatenate((good_object_rows, these_rows))

    good_object_rows = numpy.unique(good_object_rows)
    storm_object_table = storm_object_table.iloc[good_object_rows]

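    # Convert valid times to times of day (seconds past 0000 UTC), so that the
    # colour scheme below shows the diurnal cycle rather than absolute time.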
    times_of_day_sec = numpy.mod(
        storm_object_table[tracking_utils.VALID_TIME_COLUMN].values,
        NUM_SECONDS_IN_DAY)
    storm_object_table = storm_object_table.assign(
        **{tracking_utils.VALID_TIME_COLUMN: times_of_day_sec})

    min_plot_latitude_deg = -LATLNG_BUFFER_DEG + numpy.min(
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values)
    max_plot_latitude_deg = LATLNG_BUFFER_DEG + numpy.max(
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values)
    min_plot_longitude_deg = -LATLNG_BUFFER_DEG + numpy.min(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values)
    max_plot_longitude_deg = LATLNG_BUFFER_DEG + numpy.max(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values)

    _, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg,
            resolution_string='i'))

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR,
                                   line_width=BORDER_WIDTH * 2)
    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR,
                                  line_width=BORDER_WIDTH)
    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR,
                                             line_width=BORDER_WIDTH)
    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS,
                                  line_width=BORDER_WIDTH)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS,
                                  line_width=BORDER_WIDTH)

    # colour_bar_object = storm_plotting.plot_storm_tracks(
    #     storm_object_table=storm_object_table, axes_object=axes_object,
    #     basemap_object=basemap_object, colour_map_object=COLOUR_MAP_OBJECT,
    #     colour_min_unix_sec=0, colour_max_unix_sec=NUM_SECONDS_IN_DAY - 1,
    #     line_width=TRACK_LINE_WIDTH,
    #     start_marker_type=None, end_marker_type=None
    # )

    colour_bar_object = storm_plotting.plot_storm_centroids(
        storm_object_table=storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        colour_map_object=COLOUR_MAP_OBJECT,
        colour_min_unix_sec=0,
        colour_max_unix_sec=NUM_SECONDS_IN_DAY - 1)

    tick_times_unix_sec = numpy.linspace(0,
                                         NUM_SECONDS_IN_DAY,
                                         num=NUM_HOURS_IN_DAY + 1,
                                         dtype=int)
    tick_times_unix_sec = tick_times_unix_sec[:-1]
    tick_times_unix_sec = tick_times_unix_sec[::2]

    tick_time_strings = [
        time_conversion.unix_sec_to_string(t, COLOUR_BAR_TIME_FORMAT)
        for t in tick_times_unix_sec
    ]

    colour_bar_object.set_ticks(tick_times_unix_sec)
    colour_bar_object.set_ticklabels(tick_time_strings)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    pyplot.savefig(output_file_name,
                   dpi=FIGURE_RESOLUTION_DPI,
                   pad_inches=0,
                   bbox_inches='tight')
    pyplot.close()
Example 13
def _run(top_input_dir_name, target_name, first_spc_date_string,
         last_spc_date_string, class_fraction_keys, class_fraction_values,
         for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param class_fraction_keys: Same.
    :param class_fraction_values: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    class_fraction_dict = dict(zip(class_fraction_keys, class_fraction_values))
    target_param_dict = target_val_utils.target_name_to_params(target_name)

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)
        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    spc_date_strings = spc_date_string_by_file
    target_dict_by_file = [None] * num_files

    storm_ids = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    target_values = numpy.array([], dtype=int)

    for i in range(num_files):
        print('Reading "{0:s}" from: "{1:s}"...'.format(
            target_name, input_target_file_names[i]))
        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i],
            target_name=target_name)

        storm_ids += target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]
        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))
        target_values = numpy.concatenate(
            (target_values,
             target_dict_by_file[i][target_val_utils.TARGET_VALUES_KEY]))

    print(SEPARATOR_STRING)

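    # Keep only storm objects with valid target values.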
    good_indices = numpy.where(
        target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_ids = [storm_ids[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_values = target_values[good_indices]

    if for_training:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))
    else:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_non_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_indices = tracking_utils.find_storm_objects(
            all_storm_ids=target_dict_by_file[i][
                target_val_utils.STORM_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            storm_ids_to_keep=storm_ids,
            times_to_keep_unix_sec=storm_times_unix_sec,
            allow_missing=True)

        these_indices = these_indices[these_indices >= 0]
        if len(these_indices) == 0:
            continue

        this_output_dict = {
            tracking_utils.STORM_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.STORM_IDS_KEY][k]
                for k in these_indices
            ],
            tracking_utils.TIME_COLUMN:
            target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY][these_indices],
            target_name:
            target_dict_by_file[i][target_val_utils.TARGET_VALUES_KEY]
            [these_indices]
        }
        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_strings[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...'
        ).format(
            len(this_output_table.index),
            len(target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]),
            this_new_file_name
        ))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=[target_name],
            netcdf_file_name=this_new_file_name)
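# A minimal sketch of how the class-fraction arguments above combine, assuming
# a binary target (the keys and fractions here are hypothetical, not defaults
# from this script):
_sketch_fraction_keys = numpy.array([0, 1], dtype=int)
_sketch_fraction_values = numpy.array([0.9, 0.1])
_sketch_fraction_dict = dict(
    zip(_sketch_fraction_keys, _sketch_fraction_values))
# _sketch_fraction_dict == {0: 0.9, 1: 0.1}, i.e., after downsampling, 90% of
# retained storm objects should come from class 0 and 10% from class 1.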
def _run(input_example_dir_name, storm_metafile_name, num_examples_in_subset,
         subset_randomly, output_example_file_name):
    """Extracts desired examples and writes them to one file.

    This is effectively the main method.

    :param input_example_dir_name: See documentation at top of file.
    :param storm_metafile_name: Same.
    :param num_examples_in_subset: Same.
    :param subset_randomly: Same.
    :param output_example_file_name: Same.
    """

    print(
        'Reading storm metadata from: "{0:s}"...'.format(storm_metafile_name))
    example_id_strings, example_times_unix_sec = (
        tracking_io.read_ids_and_times(storm_metafile_name))

    if not 0 < num_examples_in_subset < len(example_id_strings):
        num_examples_in_subset = None

    if num_examples_in_subset is not None:
        if subset_randomly:
            these_indices = numpy.random.choice(
                numpy.arange(len(example_id_strings)),
                size=num_examples_in_subset, replace=False)

            example_id_strings = [example_id_strings[k] for k in these_indices]
            example_times_unix_sec = example_times_unix_sec[these_indices]
        else:
            example_id_strings = example_id_strings[:num_examples_in_subset]
            example_times_unix_sec = (
                example_times_unix_sec[:num_examples_in_subset])

    example_spc_date_strings = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in example_times_unix_sec
    ])
    spc_date_strings = numpy.unique(example_spc_date_strings)

    example_file_name_by_day = [
        input_examples.find_example_file(
            top_directory_name=input_example_dir_name,
            shuffled=False,
            spc_date_string=d,
            raise_error_if_missing=True) for d in spc_date_strings
    ]

    num_days = len(spc_date_strings)

    for i in range(num_days):
        print('Reading data from: "{0:s}"...'.format(
            example_file_name_by_day[i]))
        all_example_dict = input_examples.read_example_file(
            netcdf_file_name=example_file_name_by_day[i],
            read_all_target_vars=True)

        these_indices = numpy.where(
            example_spc_date_strings == spc_date_strings[i])[0]

        desired_indices = tracking_utils.find_storm_objects(
            all_id_strings=all_example_dict[input_examples.FULL_IDS_KEY],
            all_times_unix_sec=all_example_dict[
                input_examples.STORM_TIMES_KEY],
            id_strings_to_keep=[example_id_strings[k] for k in these_indices],
            times_to_keep_unix_sec=example_times_unix_sec[these_indices],
            allow_missing=False)

        desired_example_dict = input_examples.subset_examples(
            example_dict=all_example_dict, indices_to_keep=desired_indices)

        print('Writing {0:d} desired examples to: "{1:s}"...'.format(
            len(desired_indices), output_example_file_name))
        input_examples.write_example_file(
            netcdf_file_name=output_example_file_name,
            example_dict=desired_example_dict,
            append_to_file=i > 0)
def _read_new_target_values(top_target_dir_name, new_target_name,
                            full_storm_id_strings, storm_times_unix_sec,
                            orig_target_values):
    """Reads new target values (for upgraded minimum EF rating).

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param new_target_name: Name of new target variable (with upgraded minimum EF
        rating).
    :param full_storm_id_strings: length-E list of storm IDs.
    :param storm_times_unix_sec: length-E numpy array of valid times.
    :param orig_target_values: length-E numpy array of original target values
        (for original minimum EF rating), all integers in 0...1.
    :return: new_target_values: length-E numpy array of new target values
        (integers in -1...1).  -1 means that increasing minimum EF rating
        flipped the value from 1 to 0.
    """

    storm_spc_date_strings = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in storm_times_unix_sec
    ])
    unique_spc_date_strings = numpy.unique(storm_spc_date_strings)

    event_type_string = target_val_utils.target_name_to_params(
        new_target_name)[target_val_utils.EVENT_TYPE_KEY]

    num_spc_dates = len(unique_spc_date_strings)
    num_storm_objects = len(full_storm_id_strings)
    new_target_values = numpy.full(num_storm_objects, numpy.nan)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name,
            target_names=[new_target_name])

        these_storm_indices = numpy.where(
            storm_spc_date_strings == unique_spc_date_strings[i])[0]

        these_target_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_storm_id_strings[k] for k in these_storm_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_storm_indices],
            allow_missing=False)

        new_target_values[these_storm_indices] = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_target_indices, 0]

    assert not numpy.any(numpy.isnan(new_target_values))
    new_target_values = numpy.round(new_target_values).astype(int)

    bad_indices = numpy.where(new_target_values != orig_target_values)[0]
    print(('\n{0:d} of {1:d} new target values do not match the original '
           'values.').format(len(bad_indices), num_storm_objects))

    new_target_values[bad_indices] = -1
    return new_target_values
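# A minimal sketch of the return convention above, with hypothetical values.
# Wherever raising the minimum EF rating flips a label from 1 to 0, the new
# target value is overwritten with -1:
_sketch_orig_values = numpy.array([0, 1, 1, 0], dtype=int)
_sketch_new_values = numpy.array([0, 1, 0, 0], dtype=int)
_sketch_bad_indices = numpy.where(
    _sketch_new_values != _sketch_orig_values)[0]
_sketch_new_values[_sketch_bad_indices] = -1
# _sketch_new_values is now [0, 1, -1, 0].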
def myrorss_generator_2d3d(option_dict, num_examples_total):
    """Generates examples with both 2-D and 3-D radar images.

    Each example (storm object) consists of the following:

    - Storm-centered azimuthal shear (one 2-D image for each field)
    - Storm-centered reflectivity (one 3-D image)
    - Storm-centered sounding (optional)
    - Target value (class)

    :param option_dict: Dictionary with the following keys.
    option_dict['example_file_names']: See doc for
        `training_validation_io.myrorss_generator_2d3d`.
    option_dict['binarize_target']: Same.
    option_dict['radar_field_names']: Same.
    option_dict['radar_heights_m_agl']: Same.
    option_dict['sounding_field_names']: Same.
    option_dict['sounding_heights_m_agl']: Same.
    option_dict['first_storm_time_unix_sec']: Same.
    option_dict['last_storm_time_unix_sec']: Same.
    option_dict['num_grid_rows']: Same.
    option_dict['num_grid_columns']: Same.
    option_dict['normalization_type_string']: See doc for `generator_2d_or_3d`.
    option_dict['normalization_param_file_name']: Same.
    option_dict['min_normalized_value']: Same.
    option_dict['max_normalized_value']: Same.
    option_dict['class_to_sampling_fraction_dict']: Same.

    :param num_examples_total: Total number of examples to generate.

    :return: storm_object_dict: Dictionary with the following keys.
    storm_object_dict['list_of_input_matrices']: length-T list of numpy arrays,
        where T = number of input tensors to model.  The first axis of each
        array has length E.
    storm_object_dict['storm_ids']: length-E list of storm IDs.
    storm_object_dict['storm_times_unix_sec']: length-E numpy array of storm
        times.
    storm_object_dict['target_array']: See output doc for
        `training_validation_io.myrorss_generator_2d3d`.
    storm_object_dict['sounding_pressure_matrix_pascals']: numpy array (E x H_s)
        of pressures.  If soundings were not read, this is None.
    """

    storm_ids, storm_times_unix_sec = _find_examples_to_read(
        option_dict=option_dict, num_examples_total=num_examples_total)
    print('\n')

    example_file_names = option_dict[trainval_io.EXAMPLE_FILES_KEY]

    first_storm_time_unix_sec = option_dict[trainval_io.FIRST_STORM_TIME_KEY]
    last_storm_time_unix_sec = option_dict[trainval_io.LAST_STORM_TIME_KEY]
    num_grid_rows = option_dict[trainval_io.NUM_ROWS_KEY]
    num_grid_columns = option_dict[trainval_io.NUM_COLUMNS_KEY]
    azimuthal_shear_field_names = option_dict[trainval_io.RADAR_FIELDS_KEY]
    reflectivity_heights_m_agl = option_dict[trainval_io.RADAR_HEIGHTS_KEY]
    sounding_field_names = option_dict[trainval_io.SOUNDING_FIELDS_KEY]
    sounding_heights_m_agl = option_dict[trainval_io.SOUNDING_HEIGHTS_KEY]

    normalization_type_string = option_dict[trainval_io.NORMALIZATION_TYPE_KEY]
    normalization_param_file_name = option_dict[
        trainval_io.NORMALIZATION_FILE_KEY]
    min_normalized_value = option_dict[trainval_io.MIN_NORMALIZED_VALUE_KEY]
    max_normalized_value = option_dict[trainval_io.MAX_NORMALIZED_VALUE_KEY]

    binarize_target = option_dict[trainval_io.BINARIZE_TARGET_KEY]

    this_example_dict = input_examples.read_example_file(
        netcdf_file_name=example_file_names[0], metadata_only=True)
    target_name = this_example_dict[input_examples.TARGET_NAME_KEY]

    num_classes = target_val_utils.target_name_to_num_classes(
        target_name=target_name, include_dead_storms=False)

    if sounding_field_names is None:
        sounding_field_names_to_read = None
    else:
        if soundings.PRESSURE_NAME in sounding_field_names:
            # "+ []" makes a shallow copy of the list.
            sounding_field_names_to_read = sounding_field_names + []
        else:
            sounding_field_names_to_read = (
                sounding_field_names + [soundings.PRESSURE_NAME]
            )

    reflectivity_image_matrix_dbz = None
    az_shear_image_matrix_s01 = None
    sounding_matrix = None
    target_values = None
    sounding_pressure_matrix_pascals = None
    file_index = 0

    while True:
        if file_index >= len(example_file_names):
            # Raising StopIteration inside a generator is an error in
            # Python 3.7+ (PEP 479); returning ends the generator cleanly.
            return

        print('Reading data from: "{0:s}"...'.format(
            example_file_names[file_index]))

        this_example_dict = input_examples.read_example_file(
            netcdf_file_name=example_file_names[file_index],
            include_soundings=sounding_field_names is not None,
            radar_field_names_to_keep=azimuthal_shear_field_names,
            radar_heights_to_keep_m_agl=reflectivity_heights_m_agl,
            sounding_field_names_to_keep=sounding_field_names_to_read,
            sounding_heights_to_keep_m_agl=sounding_heights_m_agl,
            first_time_to_keep_unix_sec=first_storm_time_unix_sec,
            last_time_to_keep_unix_sec=last_storm_time_unix_sec,
            num_rows_to_keep=num_grid_rows,
            num_columns_to_keep=num_grid_columns)

        file_index += 1
        if this_example_dict is None:
            continue

        indices_to_keep = tracking_utils.find_storm_objects(
            all_storm_ids=this_example_dict[input_examples.STORM_IDS_KEY],
            all_times_unix_sec=this_example_dict[
                input_examples.STORM_TIMES_KEY],
            storm_ids_to_keep=storm_ids,
            times_to_keep_unix_sec=storm_times_unix_sec, allow_missing=True)

        indices_to_keep = indices_to_keep[indices_to_keep >= 0]
        if len(indices_to_keep) == 0:
            continue

        this_example_dict = input_examples.subset_examples(
            example_dict=this_example_dict, indices_to_keep=indices_to_keep)

        include_soundings = (
            input_examples.SOUNDING_MATRIX_KEY in this_example_dict)

        if include_soundings:
            pressure_index = this_example_dict[
                input_examples.SOUNDING_FIELDS_KEY
            ].index(soundings.PRESSURE_NAME)

            this_pressure_matrix_pascals = this_example_dict[
                input_examples.SOUNDING_MATRIX_KEY][..., pressure_index]

            this_sounding_matrix = this_example_dict[
                input_examples.SOUNDING_MATRIX_KEY]
            if soundings.PRESSURE_NAME not in sounding_field_names:
                # Drop the pressure channel, which was appended above only for
                # bookkeeping (note `[..., :-1]`, not `[..., -1]`).
                this_sounding_matrix = this_sounding_matrix[..., :-1]

        if target_values is None:
            reflectivity_image_matrix_dbz = (
                this_example_dict[input_examples.REFL_IMAGE_MATRIX_KEY] + 0.
            )
            az_shear_image_matrix_s01 = (
                this_example_dict[input_examples.AZ_SHEAR_IMAGE_MATRIX_KEY]
                + 0.
            )
            target_values = (
                this_example_dict[input_examples.TARGET_VALUES_KEY] + 0)

            if include_soundings:
                sounding_matrix = this_sounding_matrix + 0.
                sounding_pressure_matrix_pascals = (
                    this_pressure_matrix_pascals + 0.)
        else:
            reflectivity_image_matrix_dbz = numpy.concatenate(
                (reflectivity_image_matrix_dbz,
                 this_example_dict[input_examples.REFL_IMAGE_MATRIX_KEY]),
                axis=0)
            az_shear_image_matrix_s01 = numpy.concatenate((
                az_shear_image_matrix_s01,
                this_example_dict[input_examples.AZ_SHEAR_IMAGE_MATRIX_KEY]
            ), axis=0)
            target_values = numpy.concatenate((
                target_values,
                this_example_dict[input_examples.TARGET_VALUES_KEY]
            ))

            if include_soundings:
                sounding_matrix = numpy.concatenate(
                    (sounding_matrix, this_sounding_matrix), axis=0)
                sounding_pressure_matrix_pascals = numpy.concatenate(
                    (sounding_pressure_matrix_pascals,
                     this_pressure_matrix_pascals), axis=0)

        if normalization_type_string is not None:
            reflectivity_image_matrix_dbz = dl_utils.normalize_radar_images(
                radar_image_matrix=reflectivity_image_matrix_dbz,
                field_names=[radar_utils.REFL_NAME],
                normalization_type_string=normalization_type_string,
                normalization_param_file_name=normalization_param_file_name,
                min_normalized_value=min_normalized_value,
                max_normalized_value=max_normalized_value).astype('float32')

            az_shear_image_matrix_s01 = dl_utils.normalize_radar_images(
                radar_image_matrix=az_shear_image_matrix_s01,
                field_names=azimuthal_shear_field_names,
                normalization_type_string=normalization_type_string,
                normalization_param_file_name=normalization_param_file_name,
                min_normalized_value=min_normalized_value,
                max_normalized_value=max_normalized_value).astype('float32')

            if include_soundings:
                sounding_matrix = dl_utils.normalize_soundings(
                    sounding_matrix=sounding_matrix,
                    field_names=sounding_field_names,
                    normalization_type_string=normalization_type_string,
                    normalization_param_file_name=normalization_param_file_name,
                    min_normalized_value=min_normalized_value,
                    max_normalized_value=max_normalized_value).astype('float32')

        list_of_predictor_matrices = [
            reflectivity_image_matrix_dbz, az_shear_image_matrix_s01
        ]
        if include_soundings:
            list_of_predictor_matrices.append(sounding_matrix)

        target_array = _finalize_targets(
            target_values=target_values, binarize_target=binarize_target,
            num_classes=num_classes)

        storm_object_dict = {
            INPUT_MATRICES_KEY: list_of_predictor_matrices,
            TARGET_ARRAY_KEY: target_array,
            STORM_IDS_KEY: this_example_dict[input_examples.STORM_IDS_KEY],
            STORM_TIMES_KEY: this_example_dict[input_examples.STORM_TIMES_KEY],
            # `copy.deepcopy` handles the case where soundings were not read
            # and this matrix is None (adding 0. would raise a TypeError).
            SOUNDING_PRESSURES_KEY:
                copy.deepcopy(sounding_pressure_matrix_pascals)
        }

        reflectivity_image_matrix_dbz = None
        az_shear_image_matrix_s01 = None
        sounding_matrix = None
        target_values = None
        sounding_pressure_matrix_pascals = None

        yield storm_object_dict
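# A hedged usage sketch for the generator above (`my_option_dict` is a
# placeholder the caller must build; see the docstring for required keys):
#
#     generator_object = myrorss_generator_2d3d(
#         option_dict=my_option_dict, num_examples_total=1024)
#     storm_object_dict = next(generator_object)
#     predictor_matrices = storm_object_dict[INPUT_MATRICES_KEY]
#     target_array = storm_object_dict[TARGET_ARRAY_KEY]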
def gridrad_generator_2d_reduced(option_dict, list_of_operation_dicts,
                                 num_examples_total):
    """Generates examples with 2-D GridRad images.

    These 2-D images are produced by applying layer operations to the native 3-D
    images.  The layer operations are specified by `list_of_operation_dicts`.

    Each example (storm object) consists of the following:

    - Storm-centered radar images (one 2-D image for each layer operation)
    - Storm-centered sounding (optional)
    - Target value (class)

    :param option_dict: Dictionary with the following keys.
    option_dict['example_file_names']: See doc for
        `training_validation_io.gridrad_generator_2d_reduced`.
    option_dict['binarize_target']: Same.
    option_dict['sounding_field_names']: Same.
    option_dict['sounding_heights_m_agl']: Same.
    option_dict['first_storm_time_unix_sec']: Same.
    option_dict['last_storm_time_unix_sec']: Same.
    option_dict['num_grid_rows']: Same.
    option_dict['num_grid_columns']: Same.
    option_dict['normalization_type_string']: Same.
    option_dict['normalization_param_file_name']: Same.
    option_dict['min_normalized_value']: Same.
    option_dict['max_normalized_value']: Same.
    option_dict['class_to_sampling_fraction_dict']: Same.

    :param list_of_operation_dicts: See doc for
        `input_examples.reduce_examples_3d_to_2d`.
    :param num_examples_total: Number of examples to generate.

    :return: storm_object_dict: Dictionary with the following keys.
    storm_object_dict['list_of_input_matrices']: length-T list of numpy arrays,
        where T = number of input tensors to model.  The first axis of each
        array has length E.
    storm_object_dict['storm_ids']: length-E list of storm IDs.
    storm_object_dict['storm_times_unix_sec']: length-E numpy array of storm
        times.
    storm_object_dict['target_array']: See output doc for
        `training_validation_io.gridrad_generator_2d_reduced`.
    storm_object_dict['sounding_pressure_matrix_pascals']: numpy array (E x H_s)
        of pressures.  If soundings were not read, this is None.
    storm_object_dict['radar_field_names']: length-C list of field names, where
        the [j]th item corresponds to the [j]th channel of the 2-D radar images
        returned in "list_of_input_matrices".
    storm_object_dict['min_radar_heights_m_agl']: length-C numpy array with
        minimum height for each layer operation (used to reduce 3-D radar images
        to 2-D).
    storm_object_dict['max_radar_heights_m_agl']: Same but with max heights.
    storm_object_dict['radar_layer_operation_names']: length-C list with names
        of layer operations.  Each name must be accepted by
        `input_examples._check_layer_operation`.
    """

    unique_radar_field_names, unique_radar_heights_m_agl = (
        trainval_io.layer_ops_to_field_height_pairs(list_of_operation_dicts)
    )

    option_dict[trainval_io.RADAR_FIELDS_KEY] = unique_radar_field_names
    option_dict[trainval_io.RADAR_HEIGHTS_KEY] = unique_radar_heights_m_agl

    storm_ids, storm_times_unix_sec = _find_examples_to_read(
        option_dict=option_dict, num_examples_total=num_examples_total)
    print('\n')

    example_file_names = option_dict[trainval_io.EXAMPLE_FILES_KEY]

    first_storm_time_unix_sec = option_dict[trainval_io.FIRST_STORM_TIME_KEY]
    last_storm_time_unix_sec = option_dict[trainval_io.LAST_STORM_TIME_KEY]
    num_grid_rows = option_dict[trainval_io.NUM_ROWS_KEY]
    num_grid_columns = option_dict[trainval_io.NUM_COLUMNS_KEY]
    sounding_field_names = option_dict[trainval_io.SOUNDING_FIELDS_KEY]
    sounding_heights_m_agl = option_dict[trainval_io.SOUNDING_HEIGHTS_KEY]

    normalization_type_string = option_dict[trainval_io.NORMALIZATION_TYPE_KEY]
    normalization_param_file_name = option_dict[
        trainval_io.NORMALIZATION_FILE_KEY]
    min_normalized_value = option_dict[trainval_io.MIN_NORMALIZED_VALUE_KEY]
    max_normalized_value = option_dict[trainval_io.MAX_NORMALIZED_VALUE_KEY]

    binarize_target = option_dict[trainval_io.BINARIZE_TARGET_KEY]

    this_example_dict = input_examples.read_example_file(
        netcdf_file_name=example_file_names[0], metadata_only=True)
    target_name = this_example_dict[input_examples.TARGET_NAME_KEY]

    num_classes = target_val_utils.target_name_to_num_classes(
        target_name=target_name, include_dead_storms=False)

    if sounding_field_names is None:
        sounding_field_names_to_read = None
    else:
        if soundings.PRESSURE_NAME in sounding_field_names:
            # "+ []" makes a shallow copy of the list.
            sounding_field_names_to_read = sounding_field_names + []
        else:
            sounding_field_names_to_read = (
                sounding_field_names + [soundings.PRESSURE_NAME]
            )

    radar_image_matrix = None
    sounding_matrix = None
    target_values = None
    sounding_pressure_matrix_pascals = None

    reduction_metadata_dict = {}
    file_index = 0

    while True:
        if file_index >= len(example_file_names):
            # Raising StopIteration inside a generator is an error in
            # Python 3.7+ (PEP 479); returning ends the generator cleanly.
            return

        print('Reading data from: "{0:s}"...'.format(
            example_file_names[file_index]))

        this_example_dict = input_examples.read_example_file(
            netcdf_file_name=example_file_names[file_index],
            include_soundings=sounding_field_names is not None,
            radar_field_names_to_keep=unique_radar_field_names,
            radar_heights_to_keep_m_agl=unique_radar_heights_m_agl,
            sounding_field_names_to_keep=sounding_field_names_to_read,
            sounding_heights_to_keep_m_agl=sounding_heights_m_agl,
            first_time_to_keep_unix_sec=first_storm_time_unix_sec,
            last_time_to_keep_unix_sec=last_storm_time_unix_sec,
            num_rows_to_keep=num_grid_rows,
            num_columns_to_keep=num_grid_columns)

        file_index += 1
        if this_example_dict is None:
            continue

        indices_to_keep = tracking_utils.find_storm_objects(
            all_storm_ids=this_example_dict[input_examples.STORM_IDS_KEY],
            all_times_unix_sec=this_example_dict[
                input_examples.STORM_TIMES_KEY],
            storm_ids_to_keep=storm_ids,
            times_to_keep_unix_sec=storm_times_unix_sec, allow_missing=True)

        indices_to_keep = indices_to_keep[indices_to_keep >= 0]
        if len(indices_to_keep) == 0:
            continue

        this_example_dict = input_examples.subset_examples(
            example_dict=this_example_dict, indices_to_keep=indices_to_keep)

        this_example_dict = input_examples.reduce_examples_3d_to_2d(
            example_dict=this_example_dict,
            list_of_operation_dicts=list_of_operation_dicts)

        radar_field_names_2d = this_example_dict[
            input_examples.RADAR_FIELDS_KEY]
        for this_key in REDUCTION_METADATA_KEYS:
            reduction_metadata_dict[this_key] = this_example_dict[this_key]

        include_soundings = (
            input_examples.SOUNDING_MATRIX_KEY in this_example_dict)

        if include_soundings:
            pressure_index = this_example_dict[
                input_examples.SOUNDING_FIELDS_KEY
            ].index(soundings.PRESSURE_NAME)

            this_pressure_matrix_pascals = this_example_dict[
                input_examples.SOUNDING_MATRIX_KEY][..., pressure_index]

            this_sounding_matrix = this_example_dict[
                input_examples.SOUNDING_MATRIX_KEY]
            if soundings.PRESSURE_NAME not in sounding_field_names:
                this_sounding_matrix = this_sounding_matrix[..., :-1]

        if target_values is None:
            radar_image_matrix = (
                this_example_dict[input_examples.RADAR_IMAGE_MATRIX_KEY]
                + 0.
            )
            target_values = (
                this_example_dict[input_examples.TARGET_VALUES_KEY] + 0)

            if include_soundings:
                sounding_matrix = this_sounding_matrix + 0.
                sounding_pressure_matrix_pascals = (
                    this_pressure_matrix_pascals + 0.)
        else:
            radar_image_matrix = numpy.concatenate(
                (radar_image_matrix,
                 this_example_dict[input_examples.RADAR_IMAGE_MATRIX_KEY]),
                axis=0)
            target_values = numpy.concatenate((
                target_values,
                this_example_dict[input_examples.TARGET_VALUES_KEY]
            ))

            if include_soundings:
                sounding_matrix = numpy.concatenate(
                    (sounding_matrix, this_sounding_matrix), axis=0)
                sounding_pressure_matrix_pascals = numpy.concatenate(
                    (sounding_pressure_matrix_pascals,
                     this_pressure_matrix_pascals), axis=0)

        if normalization_type_string is not None:
            radar_image_matrix = dl_utils.normalize_radar_images(
                radar_image_matrix=radar_image_matrix,
                field_names=radar_field_names_2d,
                normalization_type_string=normalization_type_string,
                normalization_param_file_name=normalization_param_file_name,
                min_normalized_value=min_normalized_value,
                max_normalized_value=max_normalized_value).astype('float32')

            if include_soundings:
                sounding_matrix = dl_utils.normalize_soundings(
                    sounding_matrix=sounding_matrix,
                    field_names=sounding_field_names,
                    normalization_type_string=normalization_type_string,
                    normalization_param_file_name=normalization_param_file_name,
                    min_normalized_value=min_normalized_value,
                    max_normalized_value=max_normalized_value).astype('float32')

        list_of_predictor_matrices = [radar_image_matrix]
        if include_soundings:
            list_of_predictor_matrices.append(sounding_matrix)

        target_array = _finalize_targets(
            target_values=target_values, binarize_target=binarize_target,
            num_classes=num_classes)

        storm_object_dict = {
            INPUT_MATRICES_KEY: list_of_predictor_matrices,
            TARGET_ARRAY_KEY: target_array,
            STORM_IDS_KEY: this_example_dict[input_examples.STORM_IDS_KEY],
            STORM_TIMES_KEY: this_example_dict[input_examples.STORM_TIMES_KEY],
            SOUNDING_PRESSURES_KEY:
                copy.deepcopy(sounding_pressure_matrix_pascals)
        }

        for this_key in REDUCTION_METADATA_KEYS:
            storm_object_dict[this_key] = reduction_metadata_dict[this_key]

        radar_image_matrix = None
        sounding_matrix = None
        target_values = None
        sounding_pressure_matrix_pascals = None

        yield storm_object_dict
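# A hedged usage sketch for the reduced 2-D generator above.  Each layer
# operation collapses one slab of the 3-D volume into a 2-D channel.  The key
# names and values below are illustrative assumptions only; the authoritative
# schema is whatever `input_examples._check_layer_operation` accepts:
#
#     this_operation_dict = {
#         'radar_field_name': radar_utils.REFL_NAME,
#         'operation_name': 'max',
#         'min_height_m_agl': 1000,
#         'max_height_m_agl': 3000
#     }
#     generator_object = gridrad_generator_2d_reduced(
#         option_dict=my_option_dict,
#         list_of_operation_dicts=[this_operation_dict],
#         num_examples_total=1024)
#     storm_object_dict = next(generator_object)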
def _run(cnn_file_name, upconvnet_file_name, top_example_dir_name,
         baseline_storm_metafile_name, trial_storm_metafile_name,
         num_baseline_examples, num_trial_examples, num_novel_examples,
         cnn_feature_layer_name, percent_svd_variance_to_keep,
         output_file_name):
    """Runs novelty detection.

    This is effectively the main method.

    :param cnn_file_name: See documentation at top of file.
    :param upconvnet_file_name: Same.
    :param top_example_dir_name: Same.
    :param baseline_storm_metafile_name: Same.
    :param trial_storm_metafile_name: Same.
    :param num_baseline_examples: Same.
    :param num_trial_examples: Same.
    :param num_novel_examples: Same.
    :param cnn_feature_layer_name: Same.
    :param percent_svd_variance_to_keep: Same.
    :param output_file_name: Same.
    :raises: ValueError: if dimensions of first CNN input matrix != dimensions
        of upconvnet output.
    """

    print('Reading trained CNN from: "{0:s}"...'.format(cnn_file_name))
    cnn_model_object = cnn.read_model(cnn_file_name)

    cnn_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(cnn_file_name)[0]
    )

    print('Reading trained upconvnet from: "{0:s}"...'.format(
        upconvnet_file_name))
    upconvnet_model_object = cnn.read_model(upconvnet_file_name)

    # ucn_output_dimensions = numpy.array(
    #     upconvnet_model_object.output.get_shape().as_list()[1:], dtype=int
    # )

    if isinstance(cnn_model_object.input, list):
        first_cnn_input_tensor = cnn_model_object.input[0]
    else:
        first_cnn_input_tensor = cnn_model_object.input

    cnn_input_dimensions = numpy.array(
        first_cnn_input_tensor.get_shape().as_list()[1:], dtype=int
    )

    # if not numpy.array_equal(cnn_input_dimensions, ucn_output_dimensions):
    #     error_string = (
    #         'Dimensions of first CNN input matrix ({0:s}) should equal '
    #         'dimensions of upconvnet output ({1:s}).'
    #     ).format(str(cnn_input_dimensions), str(ucn_output_dimensions))
    #
    #     raise ValueError(error_string)

    print('Reading CNN metadata from: "{0:s}"...'.format(cnn_metafile_name))
    cnn_metadata_dict = cnn.read_model_metadata(cnn_metafile_name)

    print('Reading metadata for baseline examples from: "{0:s}"...'.format(
        baseline_storm_metafile_name))
    baseline_full_id_strings, baseline_times_unix_sec = (
        tracking_io.read_ids_and_times(baseline_storm_metafile_name)
    )

    print('Reading metadata for trial examples from: "{0:s}"...'.format(
        trial_storm_metafile_name))
    trial_full_id_strings, trial_times_unix_sec = (
        tracking_io.read_ids_and_times(trial_storm_metafile_name)
    )

    if 0 < num_baseline_examples < len(baseline_full_id_strings):
        baseline_full_id_strings = baseline_full_id_strings[
            :num_baseline_examples]
        baseline_times_unix_sec = baseline_times_unix_sec[
            :num_baseline_examples]

    if 0 < num_trial_examples < len(trial_full_id_strings):
        trial_full_id_strings = trial_full_id_strings[:num_trial_examples]
        trial_times_unix_sec = trial_times_unix_sec[:num_trial_examples]

    num_trial_examples = len(trial_full_id_strings)

    if num_novel_examples <= 0:
        num_novel_examples = num_trial_examples + 0

    num_novel_examples = min([num_novel_examples, num_trial_examples])
    print('Number of novel examples to find: {0:d}'.format(num_novel_examples))

    bad_baseline_indices = tracking_utils.find_storm_objects(
        all_id_strings=baseline_full_id_strings,
        all_times_unix_sec=baseline_times_unix_sec,
        id_strings_to_keep=trial_full_id_strings,
        times_to_keep_unix_sec=trial_times_unix_sec, allow_missing=True)

    # With `allow_missing=True`, trial examples absent from the baseline set
    # come back as -1.  Drop those before deleting, so that numpy.delete does
    # not remove the last baseline example by accident.
    bad_baseline_indices = bad_baseline_indices[bad_baseline_indices >= 0]

    print('Removing {0:d} trial examples from baseline set...'.format(
        len(bad_baseline_indices)
    ))

    baseline_times_unix_sec = numpy.delete(
        baseline_times_unix_sec, bad_baseline_indices
    )
    baseline_full_id_strings = numpy.delete(
        numpy.array(baseline_full_id_strings), bad_baseline_indices
    )
    baseline_full_id_strings = baseline_full_id_strings.tolist()

    # num_baseline_examples = len(baseline_full_id_strings)

    print(SEPARATOR_STRING)

    list_of_baseline_input_matrices, _ = testing_io.read_specific_examples(
        top_example_dir_name=top_example_dir_name,
        desired_full_id_strings=baseline_full_id_strings,
        desired_times_unix_sec=baseline_times_unix_sec,
        option_dict=cnn_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY],
        list_of_layer_operation_dicts=cnn_metadata_dict[
            cnn.LAYER_OPERATIONS_KEY]
    )

    print(SEPARATOR_STRING)

    list_of_trial_input_matrices, _ = testing_io.read_specific_examples(
        top_example_dir_name=top_example_dir_name,
        desired_full_id_strings=trial_full_id_strings,
        desired_times_unix_sec=trial_times_unix_sec,
        option_dict=cnn_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY],
        list_of_layer_operation_dicts=cnn_metadata_dict[
            cnn.LAYER_OPERATIONS_KEY]
    )

    print(SEPARATOR_STRING)

    novelty_dict = novelty_detection.do_novelty_detection(
        list_of_baseline_input_matrices=list_of_baseline_input_matrices,
        list_of_trial_input_matrices=list_of_trial_input_matrices,
        cnn_model_object=cnn_model_object,
        cnn_feature_layer_name=cnn_feature_layer_name,
        upconvnet_model_object=upconvnet_model_object,
        num_novel_examples=num_novel_examples, multipass=False,
        percent_svd_variance_to_keep=percent_svd_variance_to_keep)

    print(SEPARATOR_STRING)

    print('Adding metadata to novelty-detection results...')
    novelty_dict = novelty_detection.add_metadata(
        novelty_dict=novelty_dict,
        baseline_full_id_strings=baseline_full_id_strings,
        baseline_storm_times_unix_sec=baseline_times_unix_sec,
        trial_full_id_strings=trial_full_id_strings,
        trial_storm_times_unix_sec=trial_times_unix_sec,
        cnn_file_name=cnn_file_name, upconvnet_file_name=upconvnet_file_name)

    print('Denormalizing inputs and outputs of novelty detection...')

    novelty_dict[novelty_detection.BASELINE_INPUTS_KEY] = (
        model_interpretation.denormalize_data(
            list_of_input_matrices=novelty_dict[
                novelty_detection.BASELINE_INPUTS_KEY
            ],
            model_metadata_dict=cnn_metadata_dict)
    )

    novelty_dict[novelty_detection.TRIAL_INPUTS_KEY] = (
        model_interpretation.denormalize_data(
            list_of_input_matrices=novelty_dict[
                novelty_detection.TRIAL_INPUTS_KEY
            ],
            model_metadata_dict=cnn_metadata_dict)
    )

    cnn_metadata_dict[
        cnn.TRAINING_OPTION_DICT_KEY][trainval_io.SOUNDING_FIELDS_KEY] = None

    novelty_dict[novelty_detection.NOVEL_IMAGES_UPCONV_KEY] = (
        model_interpretation.denormalize_data(
            list_of_input_matrices=[
                novelty_dict[novelty_detection.NOVEL_IMAGES_UPCONV_KEY]
            ],
            model_metadata_dict=cnn_metadata_dict)
    )[0]

    novelty_dict[novelty_detection.NOVEL_IMAGES_UPCONV_SVD_KEY] = (
        model_interpretation.denormalize_data(
            list_of_input_matrices=[
                novelty_dict[novelty_detection.NOVEL_IMAGES_UPCONV_SVD_KEY]
            ],
            model_metadata_dict=cnn_metadata_dict)
    )[0]

    print('Writing results to: "{0:s}"...'.format(output_file_name))
    novelty_detection.write_standard_file(novelty_dict=novelty_dict,
                                          pickle_file_name=output_file_name)
def read_specific_examples(
        top_example_dir_name, desired_storm_ids, desired_times_unix_sec,
        option_dict, list_of_layer_operation_dicts=None):
    """Reads predictors for specific examples (storm objects).

    E = number of desired examples

    :param top_example_dir_name: Name of top-level directory with pre-processed
        examples.  Files therein will be found by
        `input_examples.find_example_file`.
    :param desired_storm_ids: length-E list of storm IDs (strings).
    :param desired_times_unix_sec: length-E numpy array of storm times.
    :param option_dict: See doc for any generator in this file.
    :param list_of_layer_operation_dicts: See doc for
        `gridrad_generator_2d_reduced`.  If you do not want to reduce radar
        images from 3-D to 2-D, leave this as None.
    :return: list_of_predictor_matrices: length-T list of numpy arrays, where
        T = number of input tensors to model.  The first dimension of each numpy
        array has length E.
    :return: sounding_pressure_matrix_pascals: numpy array (E x H_s) of
        pressures.  If soundings were not read, this is None.
    """

    option_dict[trainval_io.SAMPLING_FRACTIONS_KEY] = None

    desired_spc_date_strings = [
        time_conversion.time_to_spc_date_string(t)
        for t in desired_times_unix_sec
    ]
    unique_spc_date_strings = numpy.unique(
        numpy.array(desired_spc_date_strings)
    ).tolist()

    myrorss_2d3d = None

    storm_ids = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    list_of_predictor_matrices = None
    sounding_pressure_matrix_pascals = None

    for this_spc_date_string in unique_spc_date_strings:
        this_start_time_unix_sec = time_conversion.get_start_of_spc_date(
            this_spc_date_string)
        this_end_time_unix_sec = time_conversion.get_end_of_spc_date(
            this_spc_date_string)

        this_example_file_name = input_examples.find_example_file(
            top_directory_name=top_example_dir_name, shuffled=False,
            spc_date_string=this_spc_date_string)

        option_dict[trainval_io.EXAMPLE_FILES_KEY] = [this_example_file_name]
        option_dict[trainval_io.FIRST_STORM_TIME_KEY] = this_start_time_unix_sec
        option_dict[trainval_io.LAST_STORM_TIME_KEY] = this_end_time_unix_sec

        if myrorss_2d3d is None:
            netcdf_dataset = netCDF4.Dataset(this_example_file_name)
            myrorss_2d3d = (
                input_examples.REFL_IMAGE_MATRIX_KEY in netcdf_dataset.variables
            )
            netcdf_dataset.close()

        if list_of_layer_operation_dicts is not None:
            this_generator = gridrad_generator_2d_reduced(
                option_dict=option_dict,
                list_of_operation_dicts=list_of_layer_operation_dicts,
                num_examples_total=LARGE_INTEGER)
        elif myrorss_2d3d:
            this_generator = myrorss_generator_2d3d(
                option_dict=option_dict, num_examples_total=LARGE_INTEGER)
        else:
            this_generator = generator_2d_or_3d(
                option_dict=option_dict, num_examples_total=LARGE_INTEGER)

        this_storm_object_dict = next(this_generator)

        these_desired_indices = numpy.where(numpy.logical_and(
            desired_times_unix_sec >= this_start_time_unix_sec,
            desired_times_unix_sec <= this_end_time_unix_sec
        ))[0]

        these_indices = tracking_utils.find_storm_objects(
            all_storm_ids=this_storm_object_dict[STORM_IDS_KEY],
            all_times_unix_sec=this_storm_object_dict[STORM_TIMES_KEY],
            storm_ids_to_keep=[
                desired_storm_ids[k] for k in these_desired_indices
            ],
            times_to_keep_unix_sec=desired_times_unix_sec[
                these_desired_indices],
            allow_missing=False)

        storm_ids += [
            this_storm_object_dict[STORM_IDS_KEY][k] for k in these_indices
        ]
        storm_times_unix_sec = numpy.concatenate((
            storm_times_unix_sec,
            this_storm_object_dict[STORM_TIMES_KEY][these_indices]
        ))

        this_pressure_matrix_pascals = this_storm_object_dict[
            SOUNDING_PRESSURES_KEY]

        if this_pressure_matrix_pascals is not None:
            this_pressure_matrix_pascals = this_pressure_matrix_pascals[
                these_indices, ...]

            if sounding_pressure_matrix_pascals is None:
                sounding_pressure_matrix_pascals = (
                    this_pressure_matrix_pascals + 0.)
            else:
                sounding_pressure_matrix_pascals = numpy.concatenate(
                    (sounding_pressure_matrix_pascals,
                     this_pressure_matrix_pascals), axis=0)

        if list_of_predictor_matrices is None:
            num_matrices = len(this_storm_object_dict[INPUT_MATRICES_KEY])
            list_of_predictor_matrices = [None] * num_matrices

        for k in range(len(list_of_predictor_matrices)):
            this_new_matrix = this_storm_object_dict[INPUT_MATRICES_KEY][k][
                these_indices, ...]

            if list_of_predictor_matrices[k] is None:
                list_of_predictor_matrices[k] = this_new_matrix + 0.
            else:
                list_of_predictor_matrices[k] = numpy.concatenate(
                    (list_of_predictor_matrices[k], this_new_matrix), axis=0)

    sort_indices = tracking_utils.find_storm_objects(
        all_storm_ids=storm_ids, all_times_unix_sec=storm_times_unix_sec,
        storm_ids_to_keep=desired_storm_ids,
        times_to_keep_unix_sec=desired_times_unix_sec, allow_missing=False)

    for k in range(len(list_of_predictor_matrices)):
        list_of_predictor_matrices[k] = list_of_predictor_matrices[k][
            sort_indices, ...]

    if sounding_pressure_matrix_pascals is not None:
        sounding_pressure_matrix_pascals = sounding_pressure_matrix_pascals[
            sort_indices, ...]

    return list_of_predictor_matrices, sounding_pressure_matrix_pascals
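# A hedged usage sketch for `read_specific_examples` (directory, IDs, and
# times are hypothetical).  The returned matrices are sorted to match the
# order of the requested storm objects, as guaranteed by the final
# `find_storm_objects` call above:
#
#     predictor_matrices, pressure_matrix_pascals = read_specific_examples(
#         top_example_dir_name='/data/input_examples',
#         desired_storm_ids=['storm_0001', 'storm_0002'],
#         desired_times_unix_sec=numpy.array(
#             [1526231700, 1526232000], dtype=int),
#         option_dict=my_option_dict)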
def _run(input_human_file_name, input_machine_file_name, guided_gradcam_flag,
         abs_percentile_threshold, output_dir_name):
    """Compares human-generated vs. machine-generated interpretation map.

    This is effectively the main method.

    :param input_human_file_name: See documentation at top of file.
    :param input_machine_file_name: Same.
    :param guided_gradcam_flag: Same.
    :param abs_percentile_threshold: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    if abs_percentile_threshold < 0:
        abs_percentile_threshold = None
    if abs_percentile_threshold is not None:
        error_checking.assert_is_leq(abs_percentile_threshold, 100.)

    print('Reading data from: "{0:s}"...'.format(input_human_file_name))
    human_polygon_dict = human_polygons.read_polygons(input_human_file_name)

    human_positive_mask_matrix_4d = human_polygon_dict[
        human_polygons.POSITIVE_MASK_MATRIX_KEY]
    human_negative_mask_matrix_4d = human_polygon_dict[
        human_polygons.NEGATIVE_MASK_MATRIX_KEY]

    full_storm_id_string = human_polygon_dict[human_polygons.STORM_ID_KEY]
    storm_time_unix_sec = human_polygon_dict[human_polygons.STORM_TIME_KEY]
    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    print('Reading data from: "{0:s}"...'.format(input_machine_file_name))

    # TODO(thunderhoser): This is a HACK.
    machine_channel_indices = numpy.array([2, 8], dtype=int)

    if pmm_flag:
        try:
            saliency_dict = saliency_maps.read_pmm_file(input_machine_file_name)
            saliency_flag = True
            model_file_name = saliency_dict[saliency_maps.MODEL_FILE_KEY]

            predictor_matrix = saliency_dict.pop(
                saliency_maps.MEAN_INPUT_MATRICES_KEY
            )[0][..., machine_channel_indices]

            machine_interpretation_matrix_3d = saliency_dict.pop(
                saliency_maps.MEAN_SALIENCY_MATRICES_KEY
            )[0][..., machine_channel_indices]

        except ValueError:
            gradcam_dict = gradcam.read_pmm_file(input_machine_file_name)
            saliency_flag = False
            model_file_name = gradcam_dict[gradcam.MODEL_FILE_KEY]

            predictor_matrix = gradcam_dict.pop(
                gradcam.MEAN_INPUT_MATRICES_KEY
            )[0][..., machine_channel_indices]

            if guided_gradcam_flag:
                machine_interpretation_matrix_3d = gradcam_dict.pop(
                    gradcam.MEAN_GUIDED_GRADCAM_KEY
                )[..., machine_channel_indices]
            else:
                machine_interpretation_matrix_3d = gradcam_dict.pop(
                    gradcam.MEAN_CLASS_ACTIVATIONS_KEY)
    else:
        try:
            saliency_dict = saliency_maps.read_standard_file(
                input_machine_file_name)

            saliency_flag = True
            all_full_id_strings = saliency_dict[saliency_maps.FULL_IDS_KEY]
            all_times_unix_sec = saliency_dict[saliency_maps.STORM_TIMES_KEY]
            model_file_name = saliency_dict[saliency_maps.MODEL_FILE_KEY]

            predictor_matrix = saliency_dict.pop(
                saliency_maps.INPUT_MATRICES_KEY
            )[0][..., machine_channel_indices]

            machine_interpretation_matrix_3d = saliency_dict.pop(
                saliency_maps.SALIENCY_MATRICES_KEY
            )[0][..., machine_channel_indices]

        except ValueError:
            gradcam_dict = gradcam.read_standard_file(input_machine_file_name)

            saliency_flag = False
            all_full_id_strings = gradcam_dict[gradcam.FULL_IDS_KEY]
            all_times_unix_sec = gradcam_dict[gradcam.STORM_TIMES_KEY]
            model_file_name = gradcam_dict[gradcam.MODEL_FILE_KEY]

            predictor_matrix = gradcam_dict.pop(
                gradcam.INPUT_MATRICES_KEY
            )[0][..., machine_channel_indices]

            if guided_gradcam_flag:
                machine_interpretation_matrix_3d = gradcam_dict.pop(
                    gradcam.GUIDED_GRADCAM_KEY
                )[..., machine_channel_indices]
            else:
                machine_interpretation_matrix_3d = gradcam_dict.pop(
                    gradcam.CLASS_ACTIVATIONS_KEY)

        storm_object_index = tracking_utils.find_storm_objects(
            all_id_strings=all_full_id_strings,
            all_times_unix_sec=all_times_unix_sec,
            id_strings_to_keep=[full_storm_id_string],
            times_to_keep_unix_sec=numpy.array(
                [storm_time_unix_sec], dtype=int
            ),
            allow_missing=False
        )[0]

        predictor_matrix = predictor_matrix[storm_object_index, ...]
        machine_interpretation_matrix_3d = machine_interpretation_matrix_3d[
            storm_object_index, ...]

    if not (saliency_flag or guided_gradcam_flag):
        machine_interpretation_matrix_3d = numpy.expand_dims(
            machine_interpretation_matrix_3d, axis=-1)

        machine_interpretation_matrix_3d = numpy.repeat(
            a=machine_interpretation_matrix_3d,
            repeats=predictor_matrix.shape[-1], axis=-1)

        human_negative_mask_matrix_4d = None

    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0]
    )

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    model_metadata_dict[cnn.LAYER_OPERATIONS_KEY] = [
        model_metadata_dict[cnn.LAYER_OPERATIONS_KEY][k]
        for k in machine_channel_indices
    ]

    human_positive_mask_matrix_3d, human_negative_mask_matrix_3d = (
        _reshape_human_maps(
            model_metadata_dict=model_metadata_dict,
            positive_mask_matrix_4d=human_positive_mask_matrix_4d,
            negative_mask_matrix_4d=human_negative_mask_matrix_4d)
    )

    num_channels = human_positive_mask_matrix_3d.shape[-1]
    machine_positive_mask_matrix_3d = numpy.full(
        human_positive_mask_matrix_3d.shape, False, dtype=bool)
    positive_iou_by_channel = numpy.full(num_channels, numpy.nan)

    if human_negative_mask_matrix_3d is None:
        machine_negative_mask_matrix_3d = None
        negative_iou_by_channel = None
    else:
        machine_negative_mask_matrix_3d = numpy.full(
            human_negative_mask_matrix_3d.shape, False, dtype=bool)
        negative_iou_by_channel = numpy.full(num_channels, numpy.nan)

    for k in range(num_channels):
        this_negative_matrix = (
            None if human_negative_mask_matrix_3d is None
            else human_negative_mask_matrix_3d[..., k]
        )

        this_comparison_dict = _do_comparison_one_channel(
            machine_interpretation_matrix_2d=machine_interpretation_matrix_3d[
                ..., k],
            abs_percentile_threshold=abs_percentile_threshold,
            human_positive_mask_matrix_2d=human_positive_mask_matrix_3d[..., k],
            human_negative_mask_matrix_2d=this_negative_matrix)

        machine_positive_mask_matrix_3d[..., k] = this_comparison_dict[
            MACHINE_POSITIVE_MASK_KEY]
        positive_iou_by_channel[k] = this_comparison_dict[POSITIVE_IOU_KEY]

        if human_negative_mask_matrix_3d is None:
            continue

        machine_negative_mask_matrix_3d[..., k] = this_comparison_dict[
            MACHINE_NEGATIVE_MASK_KEY]
        negative_iou_by_channel[k] = this_comparison_dict[NEGATIVE_IOU_KEY]

    this_file_name = '{0:s}/positive_comparison.jpg'.format(output_dir_name)
    _plot_comparison(
        predictor_matrix=predictor_matrix,
        model_metadata_dict=model_metadata_dict,
        machine_mask_matrix_3d=machine_positive_mask_matrix_3d,
        human_mask_matrix_3d=human_positive_mask_matrix_3d,
        iou_by_channel=positive_iou_by_channel,
        positive_flag=True, output_file_name=this_file_name)

    if human_negative_mask_matrix_3d is None:
        return

    this_file_name = '{0:s}/negative_comparison.jpg'.format(output_dir_name)
    _plot_comparison(
        predictor_matrix=predictor_matrix,
        model_metadata_dict=model_metadata_dict,
        machine_mask_matrix_3d=machine_negative_mask_matrix_3d,
        human_mask_matrix_3d=human_negative_mask_matrix_3d,
        iou_by_channel=negative_iou_by_channel,
        positive_flag=False, output_file_name=this_file_name)
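# A minimal sketch of the per-channel comparison that
# `_do_comparison_one_channel` (not shown in this file) presumably reduces to:
# intersection over union between two Boolean masks of the same shape.
def _iou_sketch(first_mask_matrix, second_mask_matrix):
    """Computes IoU between two Boolean masks (hypothetical helper)."""
    intersection_count = numpy.sum(
        numpy.logical_and(first_mask_matrix, second_mask_matrix))
    union_count = numpy.sum(
        numpy.logical_or(first_mask_matrix, second_mask_matrix))
    return float(intersection_count) / union_count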
def _extract_storm_images(num_image_rows, num_image_columns, rotate_grids,
                          rotated_grid_spacing_metres, radar_field_names,
                          refl_heights_m_agl, spc_date_string,
                          tarred_myrorss_dir_name, untarred_myrorss_dir_name,
                          top_tracking_dir_name, elevation_dir_name,
                          tracking_scale_metres2, target_name,
                          top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name])
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files with azimuthal shear.
    az_shear_field_names = list(
        set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(az_shear_field_names) > 0:
        az_shear_tar_file_name = (
            '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar').format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=az_shear_tar_file_name,
            field_names=az_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name)
        print(SEPARATOR_STRING)

    # Untar files with other radar fields.
    non_shear_field_names = list(
        set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(non_shear_field_names) > 0:
        non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
            tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=non_shear_tar_file_name,
            field_names=non_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name,
            refl_heights_m_asl=refl_heights_m_asl)
        print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2)[0]

    storm_object_table = tracking_io.read_many_files(tracking_file_names)[
        storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print(('Removing storm objects without target values (variable = '
               '"{0:s}")...').format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects,
            num_storm_objects_orig))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns,
        rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    myrorss_io.remove_unzipped_data_1day(
        spc_date_string=spc_date_string,
        top_directory_name=untarred_myrorss_dir_name,
        field_names=radar_field_names,
        refl_heights_m_asl=refl_heights_m_asl)
    print(SEPARATOR_STRING)
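
# A toy illustration (not from the original source) of the
# tracking_utils.find_storm_objects semantics relied on above: for each
# desired (ID, time) pair, it returns the index of the matching pair in the
# full lists.  With allow_missing=True a missing pair yields -1; with
# allow_missing=False it raises a ValueError.  The lists below are
# hypothetical, and the module-level numpy import is assumed.
example_all_id_strings = ['A', 'A', 'B', 'C']
example_all_times_unix_sec = numpy.array([0, 300, 0, 300], dtype=int)

example_indices = tracking_utils.find_storm_objects(
    all_id_strings=example_all_id_strings,
    all_times_unix_sec=example_all_times_unix_sec,
    id_strings_to_keep=['B', 'A'],
    times_to_keep_unix_sec=numpy.array([0, 300], dtype=int),
    allow_missing=False)
# example_indices -> numpy.array([2, 1])
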
def _read_target_values(top_target_dir_name, storm_activations,
                        activation_metadata_dict):
    """Reads target value for each storm object.

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param storm_activations: length-E numpy array of activations.
    :param activation_metadata_dict: Dictionary returned by
        `model_activation.read_file`.
    :return: target_dict: Dictionary with the following keys.
    target_dict['full_id_strings']: length-E list of full storm IDs.
    target_dict['storm_times_unix_sec']: length-E numpy array of storm times.
    target_dict['storm_activations']: length-E numpy array of model activations.
    target_dict['storm_target_values']: length-E numpy array of target values.

    :raises: ValueError: if the target variable is multiclass and not binarized.
    """

    # Convert input args.
    full_id_strings = activation_metadata_dict[model_activation.FULL_IDS_KEY]
    storm_times_unix_sec = activation_metadata_dict[
        model_activation.STORM_TIMES_KEY]

    storm_spc_date_strings_numpy = numpy.array(
        [time_conversion.time_to_spc_date_string(t)
         for t in storm_times_unix_sec],
        dtype=object)

    unique_spc_date_strings_numpy = numpy.unique(storm_spc_date_strings_numpy)

    # Read metadata for machine-learning model.
    model_file_name = activation_metadata_dict[
        model_activation.MODEL_FILE_NAME_KEY]
    model_metadata_file_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print('Reading metadata from: "{0:s}"...'.format(model_metadata_file_name))
    model_metadata_dict = cnn.read_model_metadata(model_metadata_file_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    target_name = training_option_dict[trainval_io.TARGET_NAME_KEY]
    num_classes = target_val_utils.target_name_to_num_classes(
        target_name=target_name, include_dead_storms=False)

    binarize_target = (training_option_dict[trainval_io.BINARIZE_TARGET_KEY]
                       and num_classes > 2)

    if num_classes > 2 and not binarize_target:
        error_string = (
            'The target variable ("{0:s}") is multiclass and not binarized, '
            'which this script cannot handle.').format(target_name)

        raise ValueError(error_string)

    event_type_string = target_val_utils.target_name_to_params(target_name)[
        target_val_utils.EVENT_TYPE_KEY]

    # Read target values.
    storm_target_values = numpy.array([], dtype=int)
    id_sort_indices = numpy.array([], dtype=int)
    num_spc_dates = len(unique_spc_date_strings_numpy)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings_numpy[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name, target_names=[target_name])

        these_indices = numpy.where(storm_spc_date_strings_numpy ==
                                    unique_spc_date_strings_numpy[i])[0]

        if len(these_indices) == 0:
            continue

        id_sort_indices = numpy.concatenate((id_sort_indices, these_indices))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[full_id_strings[k] for k in these_indices],
            times_to_keep_unix_sec=storm_times_unix_sec[these_indices],
            allow_missing=False)

        these_target_values = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_indices, :]

        these_target_values = numpy.reshape(these_target_values,
                                            these_target_values.size)

        storm_target_values = numpy.concatenate(
            (storm_target_values, these_target_values))

    good_indices = numpy.where(
        storm_target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_target_values = storm_target_values[good_indices]
    id_sort_indices = id_sort_indices[good_indices]

    if binarize_target:
        # Binarize: the highest class becomes 1; all other classes become 0.
        storm_target_values = (
            storm_target_values == num_classes - 1
        ).astype(int)

    return {
        FULL_IDS_KEY: [full_id_strings[k] for k in id_sort_indices],
        STORM_TIMES_KEY: storm_times_unix_sec[id_sort_indices],
        STORM_ACTIVATIONS_KEY: storm_activations[id_sort_indices],
        TARGET_VALUES_KEY: storm_target_values
    }
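
# A toy illustration (not from the original source) of the binarization step
# in _read_target_values: when the target variable is multiclass but
# binarized, the highest class maps to 1 and all other classes map to 0.
# Assumes the module-level numpy import.
example_num_classes = 3
example_target_values = numpy.array([0, 1, 2, 2, 0], dtype=int)

example_binarized_values = (
    example_target_values == example_num_classes - 1
).astype(int)
# example_binarized_values -> numpy.array([0, 0, 1, 1, 0])
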
def _run(input_gradcam_file_name, input_human_file_name,
         output_gradcam_file_name):
    """Compares human regions of interest with class-activation maps.

    This is effectively the main method.

    :param input_gradcam_file_name: See documentation at top of file.
    :param input_human_file_name: Same.
    :param output_gradcam_file_name: Same.
    :raises: TypeError: if `input_gradcam_file_name` was created by a net that
        does both 2-D and 3-D convolution.
    :raises: TypeError: if class-activation maps are not 2-D.
    """

    print('Reading data from: "{0:s}"...'.format(input_human_file_name))
    human_point_dict = human_polygons.read_points(input_human_file_name)

    # TODO(thunderhoser): Deal with no human points.
    human_grid_rows = human_point_dict[human_polygons.GRID_ROW_BY_POINT_KEY]
    human_grid_columns = human_point_dict[
        human_polygons.GRID_COLUMN_BY_POINT_KEY]
    human_panel_rows = human_point_dict[human_polygons.PANEL_ROW_BY_POINT_KEY]
    human_panel_columns = human_point_dict[
        human_polygons.PANEL_COLUMN_BY_POINT_KEY]

    full_storm_id_string = human_point_dict[human_polygons.STORM_ID_KEY]
    storm_time_unix_sec = human_point_dict[human_polygons.STORM_TIME_KEY]
    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    print('Reading data from: "{0:s}"...'.format(input_gradcam_file_name))

    if pmm_flag:
        gradcam_dict = gradcam.read_pmm_file(input_gradcam_file_name)
    else:
        gradcam_dict = gradcam.read_standard_file(input_gradcam_file_name)

    machine_region_dict = gradcam_dict[gradcam.REGION_DICT_KEY]
    list_of_mask_matrices = machine_region_dict[gradcam.MASK_MATRICES_KEY]

    if len(list_of_mask_matrices) == 3:
        raise TypeError('This script cannot handle nets that do both 2-D and '
                        '3-D convolution.')

    machine_mask_matrix = list_of_mask_matrices[0]

    if pmm_flag:
        storm_object_index = -1
    else:
        storm_object_index = tracking_utils.find_storm_objects(
            all_id_strings=gradcam_dict[gradcam.FULL_IDS_KEY],
            all_times_unix_sec=gradcam_dict[gradcam.STORM_TIMES_KEY],
            id_strings_to_keep=[full_storm_id_string],
            times_to_keep_unix_sec=numpy.array([storm_time_unix_sec],
                                               dtype=int),
            allow_missing=False)[0]

        machine_mask_matrix = machine_mask_matrix[storm_object_index, ...]

    num_spatial_dimensions = len(machine_mask_matrix.shape) - 1
    if num_spatial_dimensions != 2:
        raise TypeError(
            'This script can compare only with 2-D class-activation maps.')

    if pmm_flag:
        machine_polygon_objects = machine_region_dict[
            gradcam.POLYGON_OBJECTS_KEY][0][0]
    else:
        machine_polygon_objects = machine_region_dict[
            gradcam.POLYGON_OBJECTS_KEY][storm_object_index][0]
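
# A toy illustration (not from the original source) of the pmm_flag logic
# above: a probability-matched-mean (PMM) composite is not tied to one storm
# object, so both the storm ID and the valid time are None in the dictionary
# returned by human_polygons.read_points.  The helper and the example ID
# below are hypothetical.
def _is_pmm_composite(full_storm_id_string, storm_time_unix_sec):
    """Hypothetical helper mirroring the pmm_flag check above."""
    return full_storm_id_string is None and storm_time_unix_sec is None

# _is_pmm_composite(None, None) -> True (PMM composite)
# _is_pmm_composite('segmotion_2011-04-05_1', 1302004800) -> False
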
def _run(top_orig_tracking_dir_name, top_new_tracking_dir_name,
         first_spc_date_string, last_spc_date_string, output_file_name):
    """Plots storms that were removed by remove_storms_outside_conus.py.

    This is effectively the main method.

    :param top_orig_tracking_dir_name: See documentation at top of file.
    :param top_new_tracking_dir_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param output_file_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    orig_tracking_file_names = []

    for d in spc_date_strings:
        orig_tracking_file_names += tracking_io.find_files_one_spc_date(
            top_tracking_dir_name=top_orig_tracking_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            spc_date_string=d,
            raise_error_if_missing=False)[0]

    valid_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in orig_tracking_file_names],
        dtype=int)

    new_tracking_file_names = [
        tracking_io.find_file(
            top_tracking_dir_name=top_new_tracking_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            valid_time_unix_sec=t,
            spc_date_string=time_conversion.time_to_spc_date_string(t),
            raise_error_if_missing=True) for t in valid_times_unix_sec
    ]

    orig_storm_object_table = tracking_io.read_many_files(
        orig_tracking_file_names)
    print(SEPARATOR_STRING)

    new_storm_object_table = tracking_io.read_many_files(
        new_tracking_file_names)
    print(SEPARATOR_STRING)

    orig_storm_id_strings = (
        orig_storm_object_table[tracking_utils.FULL_ID_COLUMN].values.tolist())
    orig_storm_times_unix_sec = (
        orig_storm_object_table[tracking_utils.VALID_TIME_COLUMN].values)
    new_storm_id_strings = (
        new_storm_object_table[tracking_utils.FULL_ID_COLUMN].values.tolist())
    new_storm_times_unix_sec = (
        new_storm_object_table[tracking_utils.VALID_TIME_COLUMN].values)

    num_orig_storm_objects = len(orig_storm_object_table.index)
    # Initially assume that no original storm objects were kept.
    orig_kept_flags = numpy.full(num_orig_storm_objects, False, dtype=bool)

    these_indices = tracking_utils.find_storm_objects(
        all_id_strings=orig_storm_id_strings,
        all_times_unix_sec=orig_storm_times_unix_sec,
        id_strings_to_keep=new_storm_id_strings,
        times_to_keep_unix_sec=new_storm_times_unix_sec,
        allow_missing=False)

    orig_kept_flags[these_indices] = True
    orig_removed_indices = numpy.where(numpy.invert(orig_kept_flags))[0]
    print('{0:d} of {1:d} storm objects were outside CONUS.'.format(
        len(orig_removed_indices), num_orig_storm_objects))

    removed_storm_object_table = orig_storm_object_table.iloc[
        orig_removed_indices]
    removed_latitudes_deg = removed_storm_object_table[
        tracking_utils.CENTROID_LATITUDE_COLUMN].values

    removed_longitudes_deg = removed_storm_object_table[
        tracking_utils.CENTROID_LONGITUDE_COLUMN].values

    figure_object, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=numpy.min(removed_latitudes_deg) - 1.,
            max_latitude_deg=numpy.max(removed_latitudes_deg) + 1.,
            min_longitude_deg=numpy.min(removed_longitudes_deg) - 1.,
            max_longitude_deg=numpy.max(removed_longitudes_deg) + 1.,
            resolution_string='i'))

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR)
    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR)
    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR)
    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    conus_latitudes_deg, conus_longitudes_deg = (
        conus_boundary.read_from_netcdf())
    conus_latitudes_deg, conus_longitudes_deg = conus_boundary.erode_boundary(
        latitudes_deg=conus_latitudes_deg,
        longitudes_deg=conus_longitudes_deg,
        erosion_distance_metres=EROSION_DISTANCE_METRES)

    axes_object.plot(conus_longitudes_deg,
                     conus_latitudes_deg,
                     color=LINE_COLOUR,
                     linestyle='solid',
                     linewidth=LINE_WIDTH)
    axes_object.plot(removed_longitudes_deg,
                     removed_latitudes_deg,
                     linestyle='None',
                     marker=MARKER_TYPE,
                     markersize=MARKER_SIZE,
                     markeredgewidth=0,
                     markerfacecolor=MARKER_COLOUR,
                     markeredgecolor=MARKER_COLOUR)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)
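
# A toy illustration (not from the original source) of the kept-flag pattern
# used above to find storm objects removed by the CONUS filter: flag every
# original object that still appears in the new tracking data, then take the
# unflagged ones as removed.  Assumes the module-level numpy import.
example_num_objects = 5
example_kept_flags = numpy.full(example_num_objects, False, dtype=bool)

# Pretend find_storm_objects matched original objects 0, 2, and 3.
example_kept_flags[numpy.array([0, 2, 3], dtype=int)] = True

example_removed_indices = numpy.where(numpy.invert(example_kept_flags))[0]
# example_removed_indices -> numpy.array([1, 4])
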
def _run(top_input_dir_name, target_name_for_downsampling,
         first_spc_date_string, last_spc_date_string, downsampling_classes,
         downsampling_fractions, for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name_for_downsampling: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    all_spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    downsampling_dict = dict(
        zip(downsampling_classes, downsampling_fractions))

    target_param_dict = target_val_utils.target_name_to_params(
        target_name_for_downsampling)
    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in all_spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=event_type_string,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)

        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    target_dict_by_file = [None] * num_files

    full_id_strings = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    storm_to_file_indices = numpy.array([], dtype=int)

    target_names = []
    target_matrix = None

    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(
            input_target_file_names[i]))

        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i])

        if i == 0:
            target_names = (
                target_dict_by_file[i][target_val_utils.TARGET_NAMES_KEY])

        these_full_id_strings = (
            target_dict_by_file[i][target_val_utils.FULL_IDS_KEY])

        full_id_strings += these_full_id_strings
        this_num_storm_objects = len(these_full_id_strings)

        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))

        storm_to_file_indices = numpy.concatenate(
            (storm_to_file_indices,
             numpy.full(this_num_storm_objects, i, dtype=int)))

        this_target_matrix = (
            target_dict_by_file[i][target_val_utils.TARGET_MATRIX_KEY])

        if target_matrix is None:
            # "+ 0" forces a copy, so the array read from file is not mutated.
            target_matrix = this_target_matrix + 0
        else:
            target_matrix = numpy.concatenate(
                (target_matrix, this_target_matrix), axis=0)

    print(SEPARATOR_STRING)

    downsampling_index = target_names.index(target_name_for_downsampling)
    good_indices = numpy.where(target_matrix[:, downsampling_index] !=
                               target_val_utils.INVALID_STORM_INTEGER)[0]

    full_id_strings = [full_id_strings[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_matrix = target_matrix[good_indices, :]
    storm_to_file_indices = storm_to_file_indices[good_indices]

    primary_id_strings = temporal_tracking.full_to_partial_ids(
        full_id_strings)[0]

    if for_training:
        indices_to_keep = fancy_downsampling.downsample_for_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)
    else:
        indices_to_keep = fancy_downsampling.downsample_for_non_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_object_subindices = numpy.where(
            storm_to_file_indices[indices_to_keep] == i)[0]

        these_object_indices = indices_to_keep[these_object_subindices]
        if len(these_object_indices) == 0:
            continue

        these_indices_in_file = tracking_utils.find_storm_objects(
            all_id_strings=target_dict_by_file[i][
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_id_strings[k] for k in these_object_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_object_indices],
            allow_missing=False)

        this_output_dict = {
            tracking_utils.FULL_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.FULL_IDS_KEY][k]
                for k in these_indices_in_file
            ],
            tracking_utils.VALID_TIME_COLUMN:
                target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY][
                    these_indices_in_file]
        }

        for j in range(len(target_names)):
            this_output_dict[target_names[j]] = (target_dict_by_file[i][
                target_val_utils.TARGET_MATRIX_KEY][these_indices_in_file, j])

        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=event_type_string,
            spc_date_string=spc_date_string_by_file[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...').format(
                len(this_output_table.index),
                len(target_dict_by_file[i][target_val_utils.FULL_IDS_KEY]),
                this_new_file_name))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=target_names,
            netcdf_file_name=this_new_file_name)
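
# A minimal sketch (not the fancy_downsampling implementation, which is
# track-aware) of what the class-fraction dictionary expresses: each class
# label maps to the desired fraction of examples in the downsampled set.
# The helper below is hypothetical and assumes the module-level numpy import.
def _downsample_by_class_fraction(target_values, class_fraction_dict,
                                  num_examples_desired):
    """Randomly samples indices so that classes hit the given fractions."""
    index_arrays = []

    for this_class, this_fraction in class_fraction_dict.items():
        these_class_indices = numpy.where(target_values == this_class)[0]
        this_num_to_keep = min(
            int(numpy.round(this_fraction * num_examples_desired)),
            len(these_class_indices))

        index_arrays.append(numpy.random.choice(
            these_class_indices, size=this_num_to_keep, replace=False))

    return numpy.concatenate(index_arrays)

# Example: draw a 50/50 sample of 10 storm objects from imbalanced labels.
# toy_labels = numpy.array([0] * 90 + [1] * 10, dtype=int)
# kept = _downsample_by_class_fraction(toy_labels, {0: 0.5, 1: 0.5}, 10)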