Example #1
def _run(storm_metafile_name, warning_dir_name):
    """Finds which storms are linked to an NWS tornado warning.

    This is effectively the main method.

    :param storm_metafile_name: See documentation at top of file.
    :param warning_dir_name: Same.
    """

    print(
        'Reading storm metadata from: "{0:s}"...'.format(storm_metafile_name))
    full_storm_id_strings, valid_times_unix_sec = (
        tracking_io.read_ids_and_times(storm_metafile_name))
    secondary_id_strings = (
        temporal_tracking.full_to_partial_ids(full_storm_id_strings)[-1])

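    # Warnings for a given storm may be filed under the previous or next SPC
    # date, so read warning files for all three dates.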
    these_times_unix_sec = numpy.concatenate(
        (valid_times_unix_sec, valid_times_unix_sec - NUM_SECONDS_PER_DAY,
         valid_times_unix_sec + NUM_SECONDS_PER_DAY))

    spc_date_strings = [
        time_conversion.time_to_spc_date_string(t)
        for t in these_times_unix_sec
    ]
    spc_date_strings = numpy.unique(numpy.array(spc_date_strings))

    linked_secondary_id_strings = []

    for this_spc_date_string in spc_date_strings:
        this_file_name = '{0:s}/tornado_warnings_{1:s}.p'.format(
            warning_dir_name, this_spc_date_string)
        print('Reading warnings from: "{0:s}"...'.format(this_file_name))

        with open(this_file_name, 'rb') as this_file_handle:
            this_warning_table = pickle.load(this_file_handle)

        this_num_warnings = len(this_warning_table.index)

        for k in range(this_num_warnings):
            linked_secondary_id_strings += (
                this_warning_table[LINKED_SECONDARY_IDS_KEY].values[k])

    print(SEPARATOR_STRING)

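    # Flag each storm object whose secondary ID appears in any warning's
    # linkage list.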
    storm_warned_flags = numpy.array(
        [s in linked_secondary_id_strings for s in secondary_id_strings],
        dtype=bool)

    print(('{0:d} of {1:d} storm objects are linked to an NWS tornado warning!'
           ).format(numpy.sum(storm_warned_flags), len(storm_warned_flags)))
Example #2
def _run(input_activation_file_name, unique_storm_cells,
         num_low_activation_examples, num_high_activation_examples, num_hits,
         num_misses, num_false_alarms, num_correct_nulls, top_target_dir_name,
         output_dir_name):
    """Finds extreme examples (storm objects), based on model activations.

    This is effectively the main method.

    :param input_activation_file_name: See documentation at top of file.
    :param unique_storm_cells: Same.
    :param num_low_activation_examples: Same.
    :param num_high_activation_examples: Same.
    :param num_hits: Same.
    :param num_misses: Same.
    :param num_false_alarms: Same.
    :param num_correct_nulls: Same.
    :param top_target_dir_name: Same.
    :param output_dir_name: Same.
    :raises: ValueError: if the activation file contains activations for more
        than one model component.
    """

    # Check input args.
    example_counts = numpy.array([
        num_low_activation_examples, num_high_activation_examples, num_hits,
        num_misses, num_false_alarms, num_correct_nulls
    ], dtype=int)

    error_checking.assert_is_geq_numpy_array(example_counts, 0)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Read activations.
    print('Reading activations from: "{0:s}"...'.format(
        input_activation_file_name))
    activation_matrix, activation_metadata_dict = model_activation.read_file(
        input_activation_file_name)

    num_model_components = activation_matrix.shape[1]
    if num_model_components > 1:
        error_string = (
            'The file should contain activations for only one model component, '
            'not {0:d}.').format(num_model_components)

        raise ValueError(error_string)

    storm_activations = activation_matrix[:, 0]
    full_id_strings = activation_metadata_dict[model_activation.FULL_IDS_KEY]
    storm_times_unix_sec = activation_metadata_dict[
        model_activation.STORM_TIMES_KEY]

    num_storm_objects = len(full_id_strings)
    error_checking.assert_is_leq(numpy.sum(example_counts), num_storm_objects)

    # Find high- and low-activation examples.
    if num_low_activation_examples + num_high_activation_examples > 0:
        high_indices, low_indices = (
            model_activation.get_hilo_activation_examples(
                storm_activations=storm_activations,
                num_low_activation_examples=num_low_activation_examples,
                num_high_activation_examples=num_high_activation_examples,
                unique_storm_cells=unique_storm_cells,
                full_storm_id_strings=full_id_strings))
    else:
        high_indices = numpy.array([], dtype=int)
        low_indices = numpy.array([], dtype=int)

    # Write high-activation examples to file.
    if len(high_indices) > 0:
        high_activation_file_name = '{0:s}/high_activation_examples.p'.format(
            output_dir_name)

        print((
            'Writing IDs and times for high-activation examples to: "{0:s}"...'
        ).format(high_activation_file_name))

        this_activation_matrix = numpy.reshape(storm_activations[high_indices],
                                               (len(high_indices), 1))

        model_activation.write_file(
            pickle_file_name=high_activation_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in high_indices],
            storm_times_unix_sec=storm_times_unix_sec[high_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])

    # Write low-activation examples to file.
    if len(low_indices) > 0:
        low_activation_file_name = '{0:s}/low_activation_examples.p'.format(
            output_dir_name)

        print(
            ('Writing IDs and times for low-activation examples to: "{0:s}"...'
             ).format(low_activation_file_name))

        this_activation_matrix = numpy.reshape(storm_activations[low_indices],
                                               (len(low_indices), 1))

        model_activation.write_file(
            pickle_file_name=low_activation_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in low_indices],
            storm_times_unix_sec=storm_times_unix_sec[low_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])

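    # If no contingency-table extremes (hits, misses, etc.) were requested,
    # stop here.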
    if num_hits + num_misses + num_false_alarms + num_correct_nulls == 0:
        return

    print(SEPARATOR_STRING)
    target_value_dict = _read_target_values(
        top_target_dir_name=top_target_dir_name,
        storm_activations=storm_activations,
        activation_metadata_dict=activation_metadata_dict)
    print(SEPARATOR_STRING)

    full_id_strings = target_value_dict[FULL_IDS_KEY]
    storm_times_unix_sec = target_value_dict[STORM_TIMES_KEY]
    storm_activations = target_value_dict[STORM_ACTIVATIONS_KEY]
    storm_target_values = target_value_dict[TARGET_VALUES_KEY]

    ct_extreme_dict = model_activation.get_contingency_table_extremes(
        storm_activations=storm_activations,
        storm_target_values=storm_target_values,
        num_hits=num_hits,
        num_misses=num_misses,
        num_false_alarms=num_false_alarms,
        num_correct_nulls=num_correct_nulls,
        unique_storm_cells=unique_storm_cells,
        full_storm_id_strings=full_id_strings)

    hit_indices = ct_extreme_dict[model_activation.HIT_INDICES_KEY]
    miss_indices = ct_extreme_dict[model_activation.MISS_INDICES_KEY]
    false_alarm_indices = (
        ct_extreme_dict[model_activation.FALSE_ALARM_INDICES_KEY])
    correct_null_indices = (
        ct_extreme_dict[model_activation.CORRECT_NULL_INDICES_KEY])

    hit_id_strings = [full_id_strings[k] for k in hit_indices]
    miss_id_strings = [full_id_strings[k] for k in miss_indices]
    false_alarm_id_strings = [full_id_strings[k] for k in false_alarm_indices]
    correct_null_id_strings = [
        full_id_strings[k] for k in correct_null_indices
    ]

    hit_primary_id_strings = (
        temporal_tracking.full_to_partial_ids(hit_id_strings)[0])
    miss_primary_id_strings = (
        temporal_tracking.full_to_partial_ids(miss_id_strings)[0])
    fa_primary_id_strings = (
        temporal_tracking.full_to_partial_ids(false_alarm_id_strings)[0])
    cn_primary_id_strings = (
        temporal_tracking.full_to_partial_ids(correct_null_id_strings)[0])

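    # Report how many primary storm IDs appear in more than one extreme
    # category.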
    these_flags = numpy.array(
        [i in miss_primary_id_strings for i in hit_primary_id_strings],
        dtype=bool)
    print(
        ('Number of primary IDs in best hits AND worst misses = {0:d}').format(
            numpy.sum(these_flags)))

    these_flags = numpy.array(
        [i in fa_primary_id_strings for i in hit_primary_id_strings],
        dtype=bool)
    print(('Number of primary IDs in best hits AND worst false alarms = {0:d}'
           ).format(numpy.sum(these_flags)))

    these_flags = numpy.array(
        [i in cn_primary_id_strings for i in hit_primary_id_strings],
        dtype=bool)
    print(('Number of primary IDs in best hits AND best correct nulls = {0:d}'
           ).format(numpy.sum(these_flags)))

    these_flags = numpy.array(
        [i in fa_primary_id_strings for i in miss_primary_id_strings],
        dtype=bool)
    print(
        ('Number of primary IDs in worst misses AND worst false alarms = {0:d}'
         ).format(numpy.sum(these_flags)))

    these_flags = numpy.array(
        [i in cn_primary_id_strings for i in miss_primary_id_strings],
        dtype=bool)
    print(
        ('Number of primary IDs in worst misses AND best correct nulls = {0:d}'
         ).format(numpy.sum(these_flags)))

    these_flags = numpy.array(
        [i in cn_primary_id_strings for i in fa_primary_id_strings],
        dtype=bool)
    print((
        'Number of primary IDs in worst false alarms AND best correct nulls = '
        '{0:d}').format(numpy.sum(these_flags)))

    # Write best hits to file.
    if len(hit_indices) > 0:
        best_hit_file_name = '{0:s}/best_hits.p'.format(output_dir_name)
        print('Writing IDs and times for best hits to: "{0:s}"...'.format(
            best_hit_file_name))

        this_activation_matrix = numpy.reshape(storm_activations[hit_indices],
                                               (len(hit_indices), 1))

        model_activation.write_file(
            pickle_file_name=best_hit_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in hit_indices],
            storm_times_unix_sec=storm_times_unix_sec[hit_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])

    # Write worst misses to file.
    if len(miss_indices) > 0:
        worst_miss_file_name = '{0:s}/worst_misses.p'.format(output_dir_name)
        print('Writing IDs and times for worst misses to: "{0:s}"...'.format(
            worst_miss_file_name))

        this_activation_matrix = numpy.reshape(storm_activations[miss_indices],
                                               (len(miss_indices), 1))

        model_activation.write_file(
            pickle_file_name=worst_miss_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in miss_indices],
            storm_times_unix_sec=storm_times_unix_sec[miss_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])

    # Write worst false alarms to file.
    if len(false_alarm_indices) > 0:
        worst_fa_file_name = '{0:s}/worst_false_alarms.p'.format(
            output_dir_name)

        print(('Writing IDs and times for worst false alarms to: "{0:s}"...'
               ).format(worst_fa_file_name))

        this_activation_matrix = numpy.reshape(
            storm_activations[false_alarm_indices],
            (len(false_alarm_indices), 1))

        model_activation.write_file(
            pickle_file_name=worst_fa_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in false_alarm_indices],
            storm_times_unix_sec=storm_times_unix_sec[false_alarm_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])

    # Write best correct nulls to file.
    if len(correct_null_indices) > 0:
        best_cn_file_name = '{0:s}/best_correct_nulls.p'.format(
            output_dir_name)

        print(('Writing IDs and times for best correct nulls to: "{0:s}"...'
               ).format(best_cn_file_name))

        this_activation_matrix = numpy.reshape(
            storm_activations[correct_null_indices],
            (len(correct_null_indices), 1))

        model_activation.write_file(
            pickle_file_name=best_cn_file_name,
            activation_matrix=this_activation_matrix,
            full_id_strings=[full_id_strings[k] for k in correct_null_indices],
            storm_times_unix_sec=storm_times_unix_sec[correct_null_indices],
            model_file_name=activation_metadata_dict[
                model_activation.MODEL_FILE_NAME_KEY],
            component_type_string=activation_metadata_dict[
                model_activation.COMPONENT_TYPE_KEY],
            target_class=activation_metadata_dict[
                model_activation.TARGET_CLASS_KEY],
            layer_name=activation_metadata_dict[
                model_activation.LAYER_NAME_KEY],
            neuron_index_matrix=activation_metadata_dict[
                model_activation.NEURON_INDICES_KEY],
            channel_indices=activation_metadata_dict[
                model_activation.CHANNEL_INDICES_KEY])
def _plot_one_example(full_id_string,
                      storm_time_unix_sec,
                      target_name,
                      forecast_probability,
                      tornado_dir_name,
                      top_tracking_dir_name,
                      top_myrorss_dir_name,
                      radar_field_name,
                      radar_height_m_asl,
                      latitude_buffer_deg,
                      longitude_buffer_deg,
                      top_output_dir_name,
                      aux_forecast_probabilities=None,
                      aux_activation_dict=None):
    """Plots one example with surrounding context at several times.

    N = number of storm objects read from auxiliary activation file

    :param full_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    :param target_name: Name of target variable.
    :param forecast_probability: Forecast tornado probability for this example.
    :param tornado_dir_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param top_myrorss_dir_name: Same.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param top_output_dir_name: Same.
    :param aux_forecast_probabilities: length-N numpy array of forecast
        probabilities.  If this is None, will not plot forecast probs in maps.
    :param aux_activation_dict: Dictionary returned by
        `model_activation.read_file` from auxiliary file.  If this is None, will
        not plot forecast probs in maps.
    """

    storm_time_string = time_conversion.unix_sec_to_string(
        storm_time_unix_sec, TIME_FORMAT)

    # Create output directory for this example.
    output_dir_name = '{0:s}/{1:s}_{2:s}'.format(top_output_dir_name,
                                                 full_id_string,
                                                 storm_time_string)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Find tracking files.
    tracking_file_names = _find_tracking_files_one_example(
        valid_time_unix_sec=storm_time_unix_sec,
        top_tracking_dir_name=top_tracking_dir_name,
        target_name=target_name)

    tracking_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    tracking_time_strings = [
        time_conversion.unix_sec_to_string(t, TIME_FORMAT)
        for t in tracking_times_unix_sec
    ]

    # Read tracking files.
    storm_object_table = tracking_io.read_many_files(tracking_file_names)
    print('\n')

    if aux_activation_dict is not None:
        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=aux_activation_dict[model_activation.FULL_IDS_KEY],
            all_times_unix_sec=aux_activation_dict[
                model_activation.STORM_TIMES_KEY],
            id_strings_to_keep=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            times_to_keep_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values,
            allow_missing=True)

        storm_object_probs = numpy.array([
            aux_forecast_probabilities[k] if k >= 0 else numpy.nan
            for k in these_indices
        ])

        storm_object_table = storm_object_table.assign(
            **{FORECAST_PROBABILITY_COLUMN: storm_object_probs})

    primary_id_string = temporal_tracking.full_to_partial_ids(
        [full_id_string])[0][0]

    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.PRIMARY_ID_COLUMN] == primary_id_string]

    latitude_limits_deg, longitude_limits_deg = _get_plotting_limits(
        storm_object_table=this_storm_object_table,
        latitude_buffer_deg=latitude_buffer_deg,
        longitude_buffer_deg=longitude_buffer_deg)

    storm_min_latitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_latitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[1])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_min_longitudes_deg = numpy.array([
        numpy.min(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

    storm_max_longitudes_deg = numpy.array([
        numpy.max(numpy.array(p.exterior.xy[0])) for p in storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values
    ])

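    # Keep only storm objects whose polygon falls at least partly within the
    # plotting limits.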
    min_latitude_flags = numpy.logical_and(
        storm_min_latitudes_deg >= latitude_limits_deg[0],
        storm_min_latitudes_deg <= latitude_limits_deg[1])

    max_latitude_flags = numpy.logical_and(
        storm_max_latitudes_deg >= latitude_limits_deg[0],
        storm_max_latitudes_deg <= latitude_limits_deg[1])

    latitude_flags = numpy.logical_or(min_latitude_flags, max_latitude_flags)

    min_longitude_flags = numpy.logical_and(
        storm_min_longitudes_deg >= longitude_limits_deg[0],
        storm_min_longitudes_deg <= longitude_limits_deg[1])

    max_longitude_flags = numpy.logical_and(
        storm_max_longitudes_deg >= longitude_limits_deg[0],
        storm_max_longitudes_deg <= longitude_limits_deg[1])

    longitude_flags = numpy.logical_or(min_longitude_flags,
                                       max_longitude_flags)
    good_indices = numpy.where(
        numpy.logical_and(latitude_flags, longitude_flags))[0]

    storm_object_table = storm_object_table.iloc[good_indices]

    # Read tornado reports.
    target_param_dict = target_val_utils.target_name_to_params(target_name)
    min_lead_time_seconds = target_param_dict[
        target_val_utils.MIN_LEAD_TIME_KEY]
    max_lead_time_seconds = target_param_dict[
        target_val_utils.MAX_LEAD_TIME_KEY]

    tornado_table = linkage._read_input_tornado_reports(
        input_directory_name=tornado_dir_name,
        storm_times_unix_sec=numpy.array([storm_time_unix_sec], dtype=int),
        max_time_before_storm_start_sec=-1 * min_lead_time_seconds,
        max_time_after_storm_end_sec=max_lead_time_seconds,
        genesis_only=True)

    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LATITUDE_COLUMN] >=
         latitude_limits_deg[0]) &
        (tornado_table[linkage.EVENT_LATITUDE_COLUMN] <=
         latitude_limits_deg[1])
    ]

    tornado_table = tornado_table.loc[
        (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] >=
         longitude_limits_deg[0]) &
        (tornado_table[linkage.EVENT_LONGITUDE_COLUMN] <=
         longitude_limits_deg[1])
    ]

    for i in range(len(tracking_file_names)):
        this_storm_object_table = storm_object_table.loc[storm_object_table[
            tracking_utils.VALID_TIME_COLUMN] == tracking_times_unix_sec[i]]

        _plot_one_example_one_time(
            storm_object_table=this_storm_object_table,
            full_id_string=full_id_string,
            valid_time_unix_sec=tracking_times_unix_sec[i],
            tornado_table=copy.deepcopy(tornado_table),
            top_myrorss_dir_name=top_myrorss_dir_name,
            radar_field_name=radar_field_name,
            radar_height_m_asl=radar_height_m_asl,
            latitude_limits_deg=latitude_limits_deg,
            longitude_limits_deg=longitude_limits_deg)

        if aux_activation_dict is None:
            this_title_string = (
                'Valid time = {0:s} ... forecast prob at {1:s} = {2:.3f}'
            ).format(tracking_time_strings[i], storm_time_string,
                     forecast_probability)

            pyplot.title(this_title_string, fontsize=TITLE_FONT_SIZE)

        this_file_name = '{0:s}/{1:s}.jpg'.format(output_dir_name,
                                                  tracking_time_strings[i])

        print('Saving figure to file: "{0:s}"...\n'.format(this_file_name))
        pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

        imagemagick_utils.trim_whitespace(input_file_name=this_file_name,
                                          output_file_name=this_file_name)
def _plot_one_example_one_time(storm_object_table, full_id_string,
                               valid_time_unix_sec, tornado_table,
                               top_myrorss_dir_name, radar_field_name,
                               radar_height_m_asl, latitude_limits_deg,
                               longitude_limits_deg):
    """Plots one example with surrounding context at one time.

    :param storm_object_table: pandas DataFrame, containing only storm objects
        at one time with the relevant primary ID.  Columns are documented in
        `storm_tracking_io.write_file`.
    :param full_id_string: Full ID of storm of interest.
    :param valid_time_unix_sec: Valid time.
    :param tornado_table: pandas DataFrame created by
        `linkage._read_input_tornado_reports`.
    :param top_myrorss_dir_name: See documentation at top of file.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_limits_deg: See doc for `_get_plotting_limits`.
    :param longitude_limits_deg: Same.
    """

    min_plot_latitude_deg = latitude_limits_deg[0]
    max_plot_latitude_deg = latitude_limits_deg[1]
    min_plot_longitude_deg = longitude_limits_deg[0]
    max_plot_longitude_deg = longitude_limits_deg[1]

    radar_file_name = myrorss_and_mrms_io.find_raw_file(
        top_directory_name=top_myrorss_dir_name,
        spc_date_string=time_conversion.time_to_spc_date_string(
            valid_time_unix_sec),
        unix_time_sec=valid_time_unix_sec,
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_field_name,
        height_m_asl=radar_height_m_asl,
        raise_error_if_missing=True)

    print('Reading data from: "{0:s}"...'.format(radar_file_name))

    radar_metadata_dict = myrorss_and_mrms_io.read_metadata_from_raw_file(
        netcdf_file_name=radar_file_name,
        data_source=radar_utils.MYRORSS_SOURCE_ID)

    sparse_grid_table = myrorss_and_mrms_io.read_data_from_sparse_grid_file(
        netcdf_file_name=radar_file_name,
        field_name_orig=radar_metadata_dict[
            myrorss_and_mrms_io.FIELD_NAME_COLUMN_ORIG],
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        sentinel_values=radar_metadata_dict[radar_utils.SENTINEL_VALUE_COLUMN])

    radar_matrix, grid_point_latitudes_deg, grid_point_longitudes_deg = (
        radar_s2f.sparse_to_full_grid(sparse_grid_table=sparse_grid_table,
                                      metadata_dict=radar_metadata_dict))

    radar_matrix = numpy.flip(radar_matrix, axis=0)
    grid_point_latitudes_deg = grid_point_latitudes_deg[::-1]

    axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg,
            resolution_string='i')[1:])

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR)

    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR)

    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)

    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    radar_plotting.plot_latlng_grid(
        field_matrix=radar_matrix,
        field_name=radar_field_name,
        axes_object=axes_object,
        min_grid_point_latitude_deg=numpy.min(grid_point_latitudes_deg),
        min_grid_point_longitude_deg=numpy.min(grid_point_longitudes_deg),
        latitude_spacing_deg=numpy.diff(grid_point_latitudes_deg[:2])[0],
        longitude_spacing_deg=numpy.diff(grid_point_longitudes_deg[:2])[0])

    colour_map_object, colour_norm_object = (
        radar_plotting.get_default_colour_scheme(radar_field_name))

    plotting_utils.plot_colour_bar(axes_object_or_matrix=axes_object,
                                   data_matrix=radar_matrix,
                                   colour_map_object=colour_map_object,
                                   colour_norm_object=colour_norm_object,
                                   orientation_string='horizontal',
                                   extend_min=False,
                                   extend_max=True,
                                   fraction_of_axis_length=0.8)

    first_list, second_list = temporal_tracking.full_to_partial_ids(
        [full_id_string])
    primary_id_string = first_list[0]
    secondary_id_string = second_list[0]

    # Plot outlines of unrelated storms (with different primary IDs).
    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.PRIMARY_ID_COLUMN] != primary_id_string]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_width=2,
        line_colour='k',
        line_style='dashed')

    # Plot outlines of related storms (with the same primary ID).
    this_storm_object_table = storm_object_table.loc[
        (storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] ==
         primary_id_string) & (storm_object_table[
             tracking_utils.SECONDARY_ID_COLUMN] != secondary_id_string)]

    this_num_storm_objects = len(this_storm_object_table.index)

    if this_num_storm_objects > 0:
        storm_plotting.plot_storm_outlines(
            storm_object_table=this_storm_object_table,
            axes_object=axes_object,
            basemap_object=basemap_object,
            line_width=2,
            line_colour='k',
            line_style='solid')

        for j in range(len(this_storm_object_table)):
            axes_object.text(
                this_storm_object_table[
                    tracking_utils.CENTROID_LONGITUDE_COLUMN].values[j],
                this_storm_object_table[
                    tracking_utils.CENTROID_LATITUDE_COLUMN].values[j],
                'P',
                fontsize=FONT_SIZE,
                color=FONT_COLOUR,
                fontweight='bold',
                horizontalalignment='center',
                verticalalignment='center')

    # Plot outline of storm of interest (same secondary ID).
    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.SECONDARY_ID_COLUMN] == secondary_id_string]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_width=4,
        line_colour='k',
        line_style='solid')

    this_num_storm_objects = len(this_storm_object_table.index)

    plot_forecast = (
        this_num_storm_objects > 0 and
        FORECAST_PROBABILITY_COLUMN in list(this_storm_object_table)
    )

    if plot_forecast:
        this_polygon_object_latlng = this_storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values[0]

        this_latitude_deg = numpy.min(
            numpy.array(this_polygon_object_latlng.exterior.xy[1]))

        this_longitude_deg = this_storm_object_table[
            tracking_utils.CENTROID_LONGITUDE_COLUMN].values[0]

        label_string = 'Prob = {0:.3f}\nat {1:s}'.format(
            this_storm_object_table[FORECAST_PROBABILITY_COLUMN].values[0],
            time_conversion.unix_sec_to_string(valid_time_unix_sec,
                                               TORNADO_TIME_FORMAT))

        bounding_box_dict = {
            'facecolor': plotting_utils.colour_from_numpy_to_tuple(
                PROBABILITY_BACKGROUND_COLOUR),
            'alpha': PROBABILITY_BACKGROUND_OPACITY,
            'edgecolor': 'k',
            'linewidth': 1
        }

        axes_object.text(this_longitude_deg,
                         this_latitude_deg,
                         label_string,
                         fontsize=FONT_SIZE,
                         color=plotting_utils.colour_from_numpy_to_tuple(
                             PROBABILITY_FONT_COLOUR),
                         fontweight='bold',
                         bbox=bounding_box_dict,
                         horizontalalignment='center',
                         verticalalignment='top',
                         zorder=1e10)

    tornado_latitudes_deg = tornado_table[linkage.EVENT_LATITUDE_COLUMN].values
    tornado_longitudes_deg = tornado_table[
        linkage.EVENT_LONGITUDE_COLUMN].values

    tornado_times_unix_sec = tornado_table[linkage.EVENT_TIME_COLUMN].values
    tornado_time_strings = [
        time_conversion.unix_sec_to_string(t, TORNADO_TIME_FORMAT)
        for t in tornado_times_unix_sec
    ]

    axes_object.plot(tornado_longitudes_deg,
                     tornado_latitudes_deg,
                     linestyle='None',
                     marker=TORNADO_MARKER_TYPE,
                     markersize=TORNADO_MARKER_SIZE,
                     markeredgewidth=TORNADO_MARKER_EDGE_WIDTH,
                     markerfacecolor=plotting_utils.colour_from_numpy_to_tuple(
                         TORNADO_MARKER_COLOUR),
                     markeredgecolor=plotting_utils.colour_from_numpy_to_tuple(
                         TORNADO_MARKER_COLOUR))

    num_tornadoes = len(tornado_latitudes_deg)

    for j in range(num_tornadoes):
        axes_object.text(tornado_longitudes_deg[j] + 0.02,
                         tornado_latitudes_deg[j] - 0.02,
                         tornado_time_strings[j],
                         fontsize=FONT_SIZE,
                         color=FONT_COLOUR,
                         fontweight='bold',
                         horizontalalignment='left',
                         verticalalignment='top')
Example #5
def _run(storm_metafile_name, top_tracking_dir_name, lead_time_seconds,
         output_file_name):
    """Plots spatial distribution of examples (storm objects) in file.

    This is effectively the main method.

    :param storm_metafile_name: See documentation at top of file.
    :param top_tracking_dir_name: Same.
    :param lead_time_seconds: Same.
    :param output_file_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    # Read storm metadata.
    print(
        'Reading storm metadata from: "{0:s}"...'.format(storm_metafile_name))
    orig_full_id_strings, orig_times_unix_sec = (
        tracking_io.read_ids_and_times(storm_metafile_name))
    orig_primary_id_strings = temporal_tracking.full_to_partial_ids(
        orig_full_id_strings)[0]

    # Find relevant tracking files.
    spc_date_strings = [
        time_conversion.time_to_spc_date_string(t) for t in orig_times_unix_sec
    ]
    spc_date_strings += [
        time_conversion.time_to_spc_date_string(t + lead_time_seconds)
        for t in orig_times_unix_sec
    ]
    spc_date_strings = list(set(spc_date_strings))

    tracking_file_names = []

    for this_spc_date_string in spc_date_strings:
        tracking_file_names += tracking_io.find_files_one_spc_date(
            top_tracking_dir_name=top_tracking_dir_name,
            tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
            source_name=tracking_utils.SEGMOTION_NAME,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)[0]

    file_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int)

    num_orig_storm_objects = len(orig_full_id_strings)
    num_files = len(file_times_unix_sec)
    keep_file_flags = numpy.full(num_files, False, dtype=bool)

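    # Keep only tracking files whose valid times fall within some storm
    # object's lead-time window.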
    for i in range(num_orig_storm_objects):
        these_flags = numpy.logical_and(
            file_times_unix_sec >= orig_times_unix_sec[i],
            file_times_unix_sec <= orig_times_unix_sec[i] + lead_time_seconds)
        keep_file_flags = numpy.logical_or(keep_file_flags, these_flags)

    del file_times_unix_sec
    keep_file_indices = numpy.where(keep_file_flags)[0]
    tracking_file_names = [tracking_file_names[k] for k in keep_file_indices]

    # Read relevant tracking files.
    num_files = len(tracking_file_names)
    storm_object_tables = [None] * num_files
    print(SEPARATOR_STRING)

    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(tracking_file_names[i]))
        this_table = tracking_io.read_file(tracking_file_names[i])

        storm_object_tables[i] = this_table.loc[this_table[
            tracking_utils.PRIMARY_ID_COLUMN].isin(
                numpy.array(orig_primary_id_strings))]

        if i == 0:
            continue

        storm_object_tables[i] = storm_object_tables[i].align(
            storm_object_tables[0], axis=1)[0]

    storm_object_table = pandas.concat(storm_object_tables,
                                       axis=0,
                                       ignore_index=True)
    print(SEPARATOR_STRING)

    # Find relevant storm objects.
    orig_object_rows = tracking_utils.find_storm_objects(
        all_id_strings=storm_object_table[
            tracking_utils.FULL_ID_COLUMN].values.tolist(),
        all_times_unix_sec=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        id_strings_to_keep=orig_full_id_strings,
        times_to_keep_unix_sec=orig_times_unix_sec)

    good_object_rows = numpy.array([], dtype=int)

    for i in range(num_orig_storm_objects):
        # Non-merging successors only!

        first_rows = temporal_tracking.find_successors(
            storm_object_table=storm_object_table,
            target_row=orig_object_rows[i],
            num_seconds_forward=lead_time_seconds,
            max_num_sec_id_changes=1,
            change_type_string=temporal_tracking.SPLIT_STRING,
            return_all_on_path=True)

        second_rows = temporal_tracking.find_successors(
            storm_object_table=storm_object_table,
            target_row=orig_object_rows[i],
            num_seconds_forward=lead_time_seconds,
            max_num_sec_id_changes=0,
            change_type_string=temporal_tracking.MERGER_STRING,
            return_all_on_path=True)

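        # Keep successors reachable via at most one split and no mergers.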
        first_rows = first_rows.tolist()
        second_rows = second_rows.tolist()
        these_rows = set(first_rows) & set(second_rows)
        these_rows = numpy.array(list(these_rows), dtype=int)

        good_object_rows = numpy.concatenate((good_object_rows, these_rows))

    good_object_rows = numpy.unique(good_object_rows)
    storm_object_table = storm_object_table.iloc[good_object_rows]

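    # Convert valid times to times of day, so that the colour scheme spans one
    # diurnal cycle.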
    times_of_day_sec = numpy.mod(
        storm_object_table[tracking_utils.VALID_TIME_COLUMN].values,
        NUM_SECONDS_IN_DAY)
    storm_object_table = storm_object_table.assign(
        **{tracking_utils.VALID_TIME_COLUMN: times_of_day_sec})

    min_plot_latitude_deg = -LATLNG_BUFFER_DEG + numpy.min(
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values)
    max_plot_latitude_deg = LATLNG_BUFFER_DEG + numpy.max(
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values)
    min_plot_longitude_deg = -LATLNG_BUFFER_DEG + numpy.min(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values)
    max_plot_longitude_deg = LATLNG_BUFFER_DEG + numpy.max(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values)

    _, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg,
            resolution_string='i'))

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR,
                                   line_width=BORDER_WIDTH * 2)
    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR,
                                  line_width=BORDER_WIDTH)
    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR,
                                             line_width=BORDER_WIDTH)
    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS,
                                  line_width=BORDER_WIDTH)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS,
                                  line_width=BORDER_WIDTH)

    # colour_bar_object = storm_plotting.plot_storm_tracks(
    #     storm_object_table=storm_object_table, axes_object=axes_object,
    #     basemap_object=basemap_object, colour_map_object=COLOUR_MAP_OBJECT,
    #     colour_min_unix_sec=0, colour_max_unix_sec=NUM_SECONDS_IN_DAY - 1,
    #     line_width=TRACK_LINE_WIDTH,
    #     start_marker_type=None, end_marker_type=None
    # )

    colour_bar_object = storm_plotting.plot_storm_centroids(
        storm_object_table=storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        colour_map_object=COLOUR_MAP_OBJECT,
        colour_min_unix_sec=0,
        colour_max_unix_sec=NUM_SECONDS_IN_DAY - 1)

    tick_times_unix_sec = numpy.linspace(0,
                                         NUM_SECONDS_IN_DAY,
                                         num=NUM_HOURS_IN_DAY + 1,
                                         dtype=int)
    tick_times_unix_sec = tick_times_unix_sec[:-1]
    tick_times_unix_sec = tick_times_unix_sec[::2]

    tick_time_strings = [
        time_conversion.unix_sec_to_string(t, COLOUR_BAR_TIME_FORMAT)
        for t in tick_times_unix_sec
    ]

    colour_bar_object.set_ticks(tick_times_unix_sec)
    colour_bar_object.set_ticklabels(tick_time_strings)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    pyplot.savefig(output_file_name,
                   dpi=FIGURE_RESOLUTION_DPI,
                   pad_inches=0,
                   bbox_inches='tight')
    pyplot.close()
Example #6
def _plot_one_example_one_time(
        storm_object_table, full_id_string, valid_time_unix_sec,
        tornado_table, top_myrorss_dir_name, radar_field_name,
        radar_height_m_asl, latitude_limits_deg, longitude_limits_deg):
    """Plots one example with surrounding context at one time.

    :param storm_object_table: pandas DataFrame, containing only storm objects
        at one time with the relevant primary ID.  Columns are documented in
        `storm_tracking_io.write_file`.
    :param full_id_string: Full ID of storm of interest.
    :param valid_time_unix_sec: Valid time.
    :param tornado_table: pandas DataFrame created by
        `linkage._read_input_tornado_reports`.
    :param top_myrorss_dir_name: See documentation at top of file.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_limits_deg: See doc for `_get_plotting_limits`.
    :param longitude_limits_deg: Same.
    """

    min_plot_latitude_deg = latitude_limits_deg[0]
    max_plot_latitude_deg = latitude_limits_deg[1]
    min_plot_longitude_deg = longitude_limits_deg[0]
    max_plot_longitude_deg = longitude_limits_deg[1]

    radar_file_name = myrorss_and_mrms_io.find_raw_file_inexact_time(
        top_directory_name=top_myrorss_dir_name,
        desired_time_unix_sec=valid_time_unix_sec,
        spc_date_string=time_conversion.time_to_spc_date_string(
            valid_time_unix_sec),
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_field_name, height_m_asl=radar_height_m_asl,
        max_time_offset_sec=(
            myrorss_and_mrms_io.DEFAULT_MAX_TIME_OFFSET_FOR_NON_SHEAR_SEC),
        raise_error_if_missing=True)

    print('Reading data from: "{0:s}"...'.format(radar_file_name))

    radar_metadata_dict = myrorss_and_mrms_io.read_metadata_from_raw_file(
        netcdf_file_name=radar_file_name,
        data_source=radar_utils.MYRORSS_SOURCE_ID)

    sparse_grid_table = (
        myrorss_and_mrms_io.read_data_from_sparse_grid_file(
            netcdf_file_name=radar_file_name,
            field_name_orig=radar_metadata_dict[
                myrorss_and_mrms_io.FIELD_NAME_COLUMN_ORIG],
            data_source=radar_utils.MYRORSS_SOURCE_ID,
            sentinel_values=radar_metadata_dict[
                radar_utils.SENTINEL_VALUE_COLUMN]
        )
    )

    radar_matrix, grid_point_latitudes_deg, grid_point_longitudes_deg = (
        radar_s2f.sparse_to_full_grid(
            sparse_grid_table=sparse_grid_table,
            metadata_dict=radar_metadata_dict)
    )

    radar_matrix = numpy.flip(radar_matrix, axis=0)
    grid_point_latitudes_deg = grid_point_latitudes_deg[::-1]

    axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg, resolution_string='h'
        )[1:]
    )

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=plotting_utils.DEFAULT_COUNTRY_COLOUR)

    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object)

    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object)

    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS, line_width=0)

    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS, line_width=0)

    radar_plotting.plot_latlng_grid(
        field_matrix=radar_matrix, field_name=radar_field_name,
        axes_object=axes_object,
        min_grid_point_latitude_deg=numpy.min(grid_point_latitudes_deg),
        min_grid_point_longitude_deg=numpy.min(grid_point_longitudes_deg),
        latitude_spacing_deg=numpy.diff(grid_point_latitudes_deg[:2])[0],
        longitude_spacing_deg=numpy.diff(grid_point_longitudes_deg[:2])[0]
    )

    colour_map_object, colour_norm_object = (
        radar_plotting.get_default_colour_scheme(radar_field_name)
    )

    plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=radar_matrix,
        colour_map_object=colour_map_object,
        colour_norm_object=colour_norm_object, orientation_string='horizontal',
        padding=0.05, extend_min=False, extend_max=True,
        fraction_of_axis_length=0.8)

    first_list, second_list = temporal_tracking.full_to_partial_ids(
        [full_id_string]
    )
    primary_id_string = first_list[0]
    secondary_id_string = second_list[0]

    # Plot outlines of unrelated storms (with different primary IDs).
    this_storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] !=
        primary_id_string
    ]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table, axes_object=axes_object,
        basemap_object=basemap_object, line_width=AUXILIARY_STORM_WIDTH,
        line_colour='k', line_style='dashed')

    # Plot outlines of related storms (with the same primary ID).
    this_storm_object_table = storm_object_table.loc[
        (storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] ==
         primary_id_string) &
        (storm_object_table[tracking_utils.SECONDARY_ID_COLUMN] !=
         secondary_id_string)
    ]

    this_num_storm_objects = len(this_storm_object_table.index)

    if this_num_storm_objects > 0:
        storm_plotting.plot_storm_outlines(
            storm_object_table=this_storm_object_table, axes_object=axes_object,
            basemap_object=basemap_object, line_width=AUXILIARY_STORM_WIDTH,
            line_colour='k', line_style='solid'
        )

        for j in range(len(this_storm_object_table)):
            axes_object.text(
                this_storm_object_table[
                    tracking_utils.CENTROID_LONGITUDE_COLUMN
                ].values[j],
                this_storm_object_table[
                    tracking_utils.CENTROID_LATITUDE_COLUMN
                ].values[j],
                'P',
                fontsize=MAIN_FONT_SIZE, color=FONT_COLOUR, fontweight='bold',
                horizontalalignment='center', verticalalignment='center'
            )

    # Plot outline of storm of interest (same secondary ID).
    this_storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.SECONDARY_ID_COLUMN] ==
        secondary_id_string
    ]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table, axes_object=axes_object,
        basemap_object=basemap_object, line_width=MAIN_STORM_WIDTH,
        line_colour='k', line_style='solid')

    this_num_storm_objects = len(this_storm_object_table.index)

    plot_forecast = (
        this_num_storm_objects > 0 and
        FORECAST_PROBABILITY_COLUMN in list(this_storm_object_table)
    )

    if plot_forecast:
        label_string = 'Prob = {0:.3f}\nat {1:s}'.format(
            this_storm_object_table[FORECAST_PROBABILITY_COLUMN].values[0],
            time_conversion.unix_sec_to_string(
                valid_time_unix_sec, TORNADO_TIME_FORMAT)
        )

        axes_object.set_title(
            label_string.replace('\n', ' '), fontsize=TITLE_FONT_SIZE
        )

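    # Plot each tornado track, with its reports sorted chronologically.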
    tornado_id_strings = tornado_table[tornado_io.TORNADO_ID_COLUMN].values

    for this_tornado_id_string in numpy.unique(tornado_id_strings):
        these_rows = numpy.where(
            tornado_id_strings == this_tornado_id_string
        )[0]

        this_tornado_table = tornado_table.iloc[these_rows].sort_values(
            linkage.EVENT_TIME_COLUMN, axis=0, ascending=True, inplace=False
        )
        _plot_one_tornado(
            tornado_table=this_tornado_table, axes_object=axes_object
        )
Example #7
def _run(top_input_dir_name, target_name_for_downsampling,
         first_spc_date_string, last_spc_date_string, downsampling_classes,
         downsampling_fractions, for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name_for_downsampling: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    all_spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    downsampling_dict = dict(
        list(zip(downsampling_classes, downsampling_fractions)))

    target_param_dict = target_val_utils.target_name_to_params(
        target_name_for_downsampling)
    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in all_spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=event_type_string,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)

        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    target_dict_by_file = [None] * num_files

    full_id_strings = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    storm_to_file_indices = numpy.array([], dtype=int)

    target_names = []
    target_matrix = None

    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(
            input_target_file_names[i]))

        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i])

        if i == 0:
            target_names = (
                target_dict_by_file[i][target_val_utils.TARGET_NAMES_KEY])

        these_full_id_strings = (
            target_dict_by_file[i][target_val_utils.FULL_IDS_KEY])

        full_id_strings += these_full_id_strings
        this_num_storm_objects = len(these_full_id_strings)

        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))

        storm_to_file_indices = numpy.concatenate(
            (storm_to_file_indices,
             numpy.full(this_num_storm_objects, i, dtype=int)))

        this_target_matrix = (
            target_dict_by_file[i][target_val_utils.TARGET_MATRIX_KEY])

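        # "+ 0" forces a copy of the first file's target matrix.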
        if target_matrix is None:
            target_matrix = this_target_matrix + 0
        else:
            target_matrix = numpy.concatenate(
                (target_matrix, this_target_matrix), axis=0)

    print(SEPARATOR_STRING)

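    # Keep only storm objects with a valid target value for the downsampling
    # target.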
    downsampling_index = target_names.index(target_name_for_downsampling)
    good_indices = numpy.where(target_matrix[:, downsampling_index] !=
                               target_val_utils.INVALID_STORM_INTEGER)[0]

    full_id_strings = [full_id_strings[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_matrix = target_matrix[good_indices, :]
    storm_to_file_indices = storm_to_file_indices[good_indices]

    primary_id_strings = temporal_tracking.full_to_partial_ids(
        full_id_strings)[0]

    if for_training:
        indices_to_keep = fancy_downsampling.downsample_for_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)
    else:
        indices_to_keep = fancy_downsampling.downsample_for_non_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_object_subindices = numpy.where(
            storm_to_file_indices[indices_to_keep] == i)[0]

        these_object_indices = indices_to_keep[these_object_subindices]
        if len(these_object_indices) == 0:
            continue

        these_indices_in_file = tracking_utils.find_storm_objects(
            all_id_strings=target_dict_by_file[i][
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_id_strings[k] for k in these_object_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_object_indices],
            allow_missing=False)

        this_output_dict = {
            tracking_utils.FULL_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.FULL_IDS_KEY][k]
                for k in these_indices_in_file
            ],
            tracking_utils.VALID_TIME_COLUMN: target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY][these_indices_in_file]
        }

        for j in range(len(target_names)):
            this_output_dict[target_names[j]] = (target_dict_by_file[i][
                target_val_utils.TARGET_MATRIX_KEY][these_indices_in_file, j])

        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=event_type_string,
            spc_date_string=spc_date_string_by_file[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...').format(
                len(this_output_table.index),
                len(target_dict_by_file[i][target_val_utils.FULL_IDS_KEY]),
                this_new_file_name))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=target_names,
            netcdf_file_name=this_new_file_name)