Example #1
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, radar_heights_m_agl,
        spc_date_string, top_radar_dir_name, top_tracking_dir_name,
        elevation_dir_name, tracking_scale_metres2, target_name,
        top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered radar images from GridRad data.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param radar_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param top_radar_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )

        print('\n')

    # Find storm objects on the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    # Read storm objects on the given SPC date.
    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]

    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_gridrad(
        storm_object_table=storm_object_table,
        top_radar_dir_name=top_radar_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        radar_heights_m_agl=radar_heights_m_agl)
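
The ID-and-time matching used above (tracking_utils.find_storm_objects) pairs each storm object with its target value by full storm ID and valid time. A minimal sketch of that matching logic, using plain numpy and hypothetical storm IDs and times (the real function, with its allow_missing option, lives in tracking_utils):

import numpy

# Hypothetical storm objects read from tracking files.
all_id_strings = ['storm_A', 'storm_B', 'storm_C', 'storm_B']
all_times_unix_sec = numpy.array([0, 0, 300, 300], dtype=int)

# Hypothetical (ID, time) pairs that have target values.
id_strings_to_keep = ['storm_B', 'storm_C']
times_to_keep_unix_sec = numpy.array([300, 300], dtype=int)

# For each pair to keep, find its index in the full list.
indices_to_keep = []

for this_id, this_time in zip(id_strings_to_keep, times_to_keep_unix_sec):
    these_flags = numpy.logical_and(
        numpy.array(all_id_strings) == this_id,
        all_times_unix_sec == this_time
    )
    indices_to_keep.append(int(numpy.where(these_flags)[0][0]))

print(indices_to_keep)  # [3, 2]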
Example #2
def _read_target_values(top_target_dir_name, storm_activations,
                        activation_metadata_dict):
    """Reads target value for each storm object.

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param storm_activations: length-E numpy array of activations.
    :param activation_metadata_dict: Dictionary returned by
        `model_activation.read_file`.
    :return: target_dict: Dictionary with the following keys.
    target_dict['full_id_strings']: length-E list of full storm IDs.
    target_dict['storm_times_unix_sec']: length-E numpy array of storm times.
    target_dict['storm_activations']: length-E numpy array of model activations.
    target_dict['storm_target_values']: length-E numpy array of target values.

    :raises: ValueError: if the target variable is multiclass and not binarized.
    """

    # Convert input args.
    full_id_strings = activation_metadata_dict[model_activation.FULL_IDS_KEY]
    storm_times_unix_sec = activation_metadata_dict[
        model_activation.STORM_TIMES_KEY]

    storm_spc_date_strings_numpy = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in storm_times_unix_sec
    ], dtype=object)

    unique_spc_date_strings_numpy = numpy.unique(storm_spc_date_strings_numpy)

    # Read metadata for machine-learning model.
    model_file_name = activation_metadata_dict[
        model_activation.MODEL_FILE_NAME_KEY]
    model_metadata_file_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print('Reading metadata from: "{0:s}"...'.format(model_metadata_file_name))
    model_metadata_dict = cnn.read_model_metadata(model_metadata_file_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    target_name = training_option_dict[trainval_io.TARGET_NAME_KEY]
    num_classes = target_val_utils.target_name_to_num_classes(
        target_name=target_name, include_dead_storms=False)

    binarize_target = (
        training_option_dict[trainval_io.BINARIZE_TARGET_KEY]
        and num_classes > 2
    )

    if num_classes > 2 and not binarize_target:
        error_string = (
            'The target variable ("{0:s}") is multiclass and not binarized, '
            'which this script cannot handle.'
        ).format(target_name)

        raise ValueError(error_string)

    event_type_string = target_val_utils.target_name_to_params(target_name)[
        target_val_utils.EVENT_TYPE_KEY]

    # Read target values.
    storm_target_values = numpy.array([], dtype=int)
    id_sort_indices = numpy.array([], dtype=int)
    num_spc_dates = len(unique_spc_date_strings_numpy)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings_numpy[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name, target_names=[target_name])

        these_indices = numpy.where(storm_spc_date_strings_numpy ==
                                    unique_spc_date_strings_numpy[i])[0]
        id_sort_indices = numpy.concatenate((id_sort_indices, these_indices))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[full_id_strings[k] for k in these_indices],
            times_to_keep_unix_sec=storm_times_unix_sec[these_indices])

        if len(these_indices) == 0:
            continue

        these_target_values = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_indices, :]

        these_target_values = numpy.reshape(these_target_values,
                                            these_target_values.size)

        storm_target_values = numpy.concatenate(
            (storm_target_values, these_target_values))

    good_indices = numpy.where(
        storm_target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_target_values = storm_target_values[good_indices]
    id_sort_indices = id_sort_indices[good_indices]

    if binarize_target:
        storm_target_values = (
            storm_target_values == num_classes - 1
        ).astype(int)

    return {
        FULL_IDS_KEY: [full_id_strings[k] for k in id_sort_indices],
        STORM_TIMES_KEY: storm_times_unix_sec[id_sort_indices],
        STORM_ACTIVATIONS_KEY: storm_activations[id_sort_indices],
        TARGET_VALUES_KEY: storm_target_values
    }
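
When binarize_target is True above, multiclass labels collapse to "highest class vs. everything else" via the comparison against num_classes - 1. A minimal sketch of that step with hypothetical labels:

import numpy

num_classes = 3  # hypothetical multiclass target variable
storm_target_values = numpy.array([0, 2, 1, 2, 0], dtype=int)

# 1 for the highest class, 0 for all other classes.
binarized_values = (storm_target_values == num_classes - 1).astype(int)
print(binarized_values)  # [0 1 0 1 0]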
def _read_new_target_values(top_target_dir_name, new_target_name,
                            full_storm_id_strings, storm_times_unix_sec,
                            orig_target_values):
    """Reads new target values (for upgraded minimum EF rating).

    E = number of examples (storm objects)

    :param top_target_dir_name: See documentation at top of file.
    :param new_target_name: Name of new target variable (with upgraded minimum EF
        rating).
    :param full_storm_id_strings: length-E list of storm IDs.
    :param storm_times_unix_sec: length-E numpy array of valid times.
    :param orig_target_values: length-E numpy array of original target values
        (for original minimum EF rating), all integers in 0...1.
    :return: new_target_values: length-E numpy array of new target values
        (integers in -1...1).  -1 means that increasing minimum EF rating
        flipped the value from 1 to 0.
    """

    storm_spc_date_strings = numpy.array([
        time_conversion.time_to_spc_date_string(t)
        for t in storm_times_unix_sec
    ])
    unique_spc_date_strings = numpy.unique(storm_spc_date_strings)

    event_type_string = target_val_utils.target_name_to_params(
        new_target_name)[target_val_utils.EVENT_TYPE_KEY]

    num_spc_dates = len(unique_spc_date_strings)
    num_storm_objects = len(full_storm_id_strings)
    new_target_values = numpy.full(num_storm_objects, numpy.nan)

    for i in range(num_spc_dates):
        this_target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=event_type_string,
            spc_date_string=unique_spc_date_strings[i])

        print('Reading data from: "{0:s}"...'.format(this_target_file_name))
        this_target_value_dict = target_val_utils.read_target_values(
            netcdf_file_name=this_target_file_name,
            target_names=[new_target_name])

        these_storm_indices = numpy.where(
            storm_spc_date_strings == unique_spc_date_strings[i])[0]

        these_target_indices = tracking_utils.find_storm_objects(
            all_id_strings=this_target_value_dict[
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=this_target_value_dict[
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_storm_id_strings[k] for k in these_storm_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_storm_indices],
            allow_missing=False)

        new_target_values[these_storm_indices] = this_target_value_dict[
            target_val_utils.TARGET_MATRIX_KEY][these_target_indices, 0]

    assert not numpy.any(numpy.isnan(new_target_values))
    new_target_values = numpy.round(new_target_values).astype(int)

    bad_indices = numpy.where(new_target_values != orig_target_values)[0]
    print((
        '\n{0:d} of {1:d} new target values do not match the original values.'
    ).format(len(bad_indices), num_storm_objects))

    new_target_values[bad_indices] = -1
    return new_target_values
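
The flip-detection step above compares the new labels against the original ones and marks disagreements with -1. A minimal standalone sketch with hypothetical label arrays:

import numpy

# Hypothetical labels for the original and upgraded minimum EF rating.
orig_target_values = numpy.array([0, 1, 1, 0], dtype=int)
new_target_values = numpy.array([0, 1, 0, 0], dtype=int)

# -1 marks storm objects whose label flipped when the rating was upgraded.
bad_indices = numpy.where(new_target_values != orig_target_values)[0]
new_target_values[bad_indices] = -1
print(new_target_values)  # [ 0  1 -1  0]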
def _extract_storm_images(num_image_rows, num_image_columns, rotate_grids,
                          rotated_grid_spacing_metres, radar_field_names,
                          refl_heights_m_agl, spc_date_string,
                          tarred_myrorss_dir_name, untarred_myrorss_dir_name,
                          top_tracking_dir_name, elevation_dir_name,
                          tracking_scale_metres2, target_name,
                          top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    """

    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name])
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files with azimuthal shear.
    az_shear_field_names = list(
        set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(az_shear_field_names) > 0:
        az_shear_tar_file_name = (
            '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar'
        ).format(
            tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
        )

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=az_shear_tar_file_name,
            field_names=az_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name)
        print(SEPARATOR_STRING)

    # Untar files with other radar fields.
    non_shear_field_names = list(
        set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES))

    if len(non_shear_field_names) > 0:
        non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
            tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

        myrorss_io.unzip_1day_tar_file(
            tar_file_name=non_shear_tar_file_name,
            field_names=non_shear_field_names,
            spc_date_string=spc_date_string,
            top_target_directory_name=untarred_myrorss_dir_name,
            refl_heights_m_asl=refl_heights_m_asl)
        print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2)[0]

    storm_object_table = tracking_io.read_many_files(tracking_file_names)[
        storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print(('Removing storm objects without target values (variable = '
               '"{0:s}")...').format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects,
            num_storm_objects_orig))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns,
        rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    myrorss_io.remove_unzipped_data_1day(
        spc_date_string=spc_date_string,
        top_directory_name=untarred_myrorss_dir_name,
        field_names=radar_field_names,
        refl_heights_m_asl=refl_heights_m_asl)
    print(SEPARATOR_STRING)
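
The tar files above are organized by year, with the year taken from the first four characters of the SPC date string (yyyymmdd). A quick illustration with a hypothetical top directory:

spc_date_string = '20110520'  # hypothetical SPC date (yyyymmdd)
tarred_myrorss_dir_name = '/data/myrorss_tarred'  # hypothetical directory

az_shear_tar_file_name = (
    '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar'
).format(tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
    tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string)

print(az_shear_tar_file_name)
# /data/myrorss_tarred/2011/azimuthal_shear_only/20110520.tar
print(non_shear_tar_file_name)
# /data/myrorss_tarred/2011/20110520.tar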
Example #5
def _extract_storm_images(
        num_image_rows, num_image_columns, rotate_grids,
        rotated_grid_spacing_metres, radar_field_names, refl_heights_m_agl,
        spc_date_string, first_time_string, last_time_string,
        tarred_myrorss_dir_name, untarred_myrorss_dir_name,
        top_tracking_dir_name, elevation_dir_name, tracking_scale_metres2,
        target_name, top_target_dir_name, top_output_dir_name):
    """Extracts storm-centered img for each field/height pair and storm object.

    :param num_image_rows: See documentation at top of file.
    :param num_image_columns: Same.
    :param rotate_grids: Same.
    :param rotated_grid_spacing_metres: Same.
    :param radar_field_names: Same.
    :param refl_heights_m_agl: Same.
    :param spc_date_string: Same.
    :param first_time_string: Same.
    :param last_time_string: Same.
    :param tarred_myrorss_dir_name: Same.
    :param untarred_myrorss_dir_name: Same.
    :param top_tracking_dir_name: Same.
    :param elevation_dir_name: Same.
    :param tracking_scale_metres2: Same.
    :param target_name: Same.
    :param top_target_dir_name: Same.
    :param top_output_dir_name: Same.
    :raises: ValueError: if `first_time_string` and `last_time_string` have
        different SPC dates.
    """

    if elevation_dir_name in ['', 'None']:
        elevation_dir_name = None

    if elevation_dir_name is None:
        host_name = socket.gethostname()

        if 'casper' in host_name:
            elevation_dir_name = '/glade/work/ryanlage/elevation'
        else:
            elevation_dir_name = '/condo/swatwork/ralager/elevation'

    if spc_date_string in ['', 'None']:
        first_time_unix_sec = time_conversion.string_to_unix_sec(
            first_time_string, TIME_FORMAT)
        last_time_unix_sec = time_conversion.string_to_unix_sec(
            last_time_string, TIME_FORMAT)

        first_spc_date_string = time_conversion.time_to_spc_date_string(
            first_time_unix_sec)
        last_spc_date_string = time_conversion.time_to_spc_date_string(
            last_time_unix_sec)

        if first_spc_date_string != last_spc_date_string:
            error_string = (
                'First ({0:s}) and last ({1:s}) times have different SPC dates.'
                '  This script can handle only one SPC date.'
            ).format(first_time_string, last_time_string)

            raise ValueError(error_string)

        spc_date_string = first_spc_date_string
    else:
        # With an explicit SPC date, no further time filtering is needed.
        first_time_unix_sec = 0
        last_time_unix_sec = int(1e12)

    if tarred_myrorss_dir_name in ['', 'None']:
        tarred_myrorss_dir_name = None
    if target_name in ['', 'None']:
        target_name = None

    if target_name is not None:
        target_param_dict = target_val_utils.target_name_to_params(target_name)

        target_file_name = target_val_utils.find_target_file(
            top_directory_name=top_target_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_string)

        print('Reading data from: "{0:s}"...'.format(target_file_name))
        target_dict = target_val_utils.read_target_values(
            netcdf_file_name=target_file_name, target_names=[target_name]
        )
        print('\n')

    refl_heights_m_asl = radar_utils.get_valid_heights(
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_utils.REFL_NAME)

    # Untar files.
    if tarred_myrorss_dir_name is not None:
        az_shear_field_names = list(
            set(radar_field_names) & set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(az_shear_field_names) > 0:
            az_shear_tar_file_name = (
                '{0:s}/{1:s}/azimuthal_shear_only/{2:s}.tar'
            ).format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=az_shear_tar_file_name,
                field_names=az_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name)
            print(SEPARATOR_STRING)

        non_shear_field_names = list(
            set(radar_field_names) - set(ALL_AZ_SHEAR_FIELD_NAMES)
        )

        if len(non_shear_field_names) > 0:
            non_shear_tar_file_name = '{0:s}/{1:s}/{2:s}.tar'.format(
                tarred_myrorss_dir_name, spc_date_string[:4], spc_date_string
            )

            myrorss_io.unzip_1day_tar_file(
                tar_file_name=non_shear_tar_file_name,
                field_names=non_shear_field_names,
                spc_date_string=spc_date_string,
                top_target_directory_name=untarred_myrorss_dir_name,
                refl_heights_m_asl=refl_heights_m_asl)
            print(SEPARATOR_STRING)

    # Read storm tracks for the given SPC date.
    tracking_file_names = tracking_io.find_files_one_spc_date(
        spc_date_string=spc_date_string,
        source_name=tracking_utils.SEGMOTION_NAME,
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=tracking_scale_metres2
    )[0]

    file_times_unix_sec = numpy.array(
        [tracking_io.file_name_to_time(f) for f in tracking_file_names],
        dtype=int
    )

    good_indices = numpy.where(numpy.logical_and(
        file_times_unix_sec >= first_time_unix_sec,
        file_times_unix_sec <= last_time_unix_sec
    ))[0]

    tracking_file_names = [tracking_file_names[k] for k in good_indices]

    storm_object_table = tracking_io.read_many_files(
        tracking_file_names
    )[storm_images.STORM_COLUMNS_NEEDED]
    print(SEPARATOR_STRING)

    if target_name is not None:
        print((
            'Removing storm objects without target values (variable = '
            '"{0:s}")...'
        ).format(target_name))

        these_indices = tracking_utils.find_storm_objects(
            all_id_strings=storm_object_table[
                tracking_utils.FULL_ID_COLUMN].values.tolist(),
            all_times_unix_sec=storm_object_table[
                tracking_utils.VALID_TIME_COLUMN].values.astype(int),
            id_strings_to_keep=target_dict[target_val_utils.FULL_IDS_KEY],
            times_to_keep_unix_sec=target_dict[
                target_val_utils.VALID_TIMES_KEY],
            allow_missing=False)

        num_storm_objects_orig = len(storm_object_table.index)
        storm_object_table = storm_object_table.iloc[these_indices]
        num_storm_objects = len(storm_object_table.index)

        print('Removed {0:d} of {1:d} storm objects!\n'.format(
            num_storm_objects_orig - num_storm_objects, num_storm_objects_orig
        ))

    # Extract storm-centered radar images.
    storm_images.extract_storm_images_myrorss_or_mrms(
        storm_object_table=storm_object_table,
        radar_source=radar_utils.MYRORSS_SOURCE_ID,
        top_radar_dir_name=untarred_myrorss_dir_name,
        top_output_dir_name=top_output_dir_name,
        elevation_dir_name=elevation_dir_name,
        num_storm_image_rows=num_image_rows,
        num_storm_image_columns=num_image_columns, rotate_grids=rotate_grids,
        rotated_grid_spacing_metres=rotated_grid_spacing_metres,
        radar_field_names=radar_field_names,
        reflectivity_heights_m_agl=refl_heights_m_agl)
    print(SEPARATOR_STRING)

    # Remove untarred MYRORSS files.
    if tarred_myrorss_dir_name is not None:
        myrorss_io.remove_unzipped_data_1day(
            spc_date_string=spc_date_string,
            top_directory_name=untarred_myrorss_dir_name,
            field_names=radar_field_names,
            refl_heights_m_asl=refl_heights_m_asl)
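
The time-window filter above keeps only tracking files whose valid time falls in [first_time_unix_sec, last_time_unix_sec]. A minimal sketch with hypothetical file times:

import numpy

# Hypothetical valid times (Unix seconds) parsed from tracking file names.
file_times_unix_sec = numpy.array([100, 400, 700, 1000], dtype=int)
tracking_file_names = ['file0.p', 'file1.p', 'file2.p', 'file3.p']

first_time_unix_sec = 300
last_time_unix_sec = 900

# Keep only files inside the desired time window.
good_indices = numpy.where(numpy.logical_and(
    file_times_unix_sec >= first_time_unix_sec,
    file_times_unix_sec <= last_time_unix_sec
))[0]

tracking_file_names = [tracking_file_names[k] for k in good_indices]
print(tracking_file_names)  # ['file1.p', 'file2.p']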
Example #6
def _run(top_input_dir_name, target_name, first_spc_date_string,
         last_spc_date_string, class_fraction_keys, class_fraction_values,
         for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param class_fraction_keys: Same.
    :param class_fraction_values: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    class_fraction_dict = dict(zip(class_fraction_keys, class_fraction_values))
    target_param_dict = target_val_utils.target_name_to_params(target_name)

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)

        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    spc_date_strings = spc_date_string_by_file
    target_dict_by_file = [None] * num_files

    storm_ids = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    target_values = numpy.array([], dtype=int)

    for i in range(num_files):
        print('Reading "{0:s}" from: "{1:s}"...'.format(
            target_name, input_target_file_names[i]))
        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i],
            target_name=target_name)

        storm_ids += target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]
        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))
        target_values = numpy.concatenate(
            (target_values,
             target_dict_by_file[i][target_val_utils.TARGET_VALUES_KEY]))

    print(SEPARATOR_STRING)

    good_indices = numpy.where(
        target_values != target_val_utils.INVALID_STORM_INTEGER)[0]

    storm_ids = [storm_ids[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_values = target_values[good_indices]

    if for_training:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))
    else:
        storm_ids, storm_times_unix_sec, target_values = (
            fancy_downsampling.downsample_for_non_training(
                storm_ids=storm_ids,
                storm_times_unix_sec=storm_times_unix_sec,
                target_values=target_values,
                target_name=target_name,
                class_fraction_dict=class_fraction_dict))

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_indices = tracking_utils.find_storm_objects(
            all_storm_ids=target_dict_by_file[i][
                target_val_utils.STORM_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            storm_ids_to_keep=storm_ids,
            times_to_keep_unix_sec=storm_times_unix_sec,
            allow_missing=True)

        these_indices = these_indices[these_indices >= 0]
        if len(these_indices) == 0:
            continue

        this_output_dict = {
            tracking_utils.STORM_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.STORM_IDS_KEY][k]
                for k in these_indices
            ],
            tracking_utils.TIME_COLUMN:
            target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY][these_indices],
            target_name:
            target_dict_by_file[i][target_val_utils.TARGET_VALUES_KEY]
            [these_indices]
        }
        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=target_param_dict[
                target_val_utils.EVENT_TYPE_KEY],
            spc_date_string=spc_date_strings[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...'
        ).format(
            len(this_output_table.index),
            len(target_dict_by_file[i][target_val_utils.STORM_IDS_KEY]),
            this_new_file_name
        ))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=[target_name],
            netcdf_file_name=this_new_file_name)
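
fancy_downsampling.downsample_for_training is not shown here, but the core class-fraction idea is simple: for each class, sample enough examples that the kept set matches the requested fractions. A minimal sketch under that assumption, with hypothetical labels (the real function also handles storm IDs and times):

import numpy

# Hypothetical binary labels: 90 negative examples, 10 positive.
target_values = numpy.array([0] * 90 + [1] * 10, dtype=int)
class_fraction_dict = {0: 0.5, 1: 0.5}
num_examples_desired = 20

numpy.random.seed(6695)
indices_to_keep = numpy.array([], dtype=int)

for this_class, this_fraction in class_fraction_dict.items():
    these_class_indices = numpy.where(target_values == this_class)[0]
    this_num_desired = int(numpy.round(this_fraction * num_examples_desired))

    these_sampled_indices = numpy.random.choice(
        these_class_indices, size=this_num_desired, replace=False)
    indices_to_keep = numpy.concatenate(
        (indices_to_keep, these_sampled_indices))

print(numpy.mean(target_values[indices_to_keep]))  # 0.5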
Example #7
def _run(top_input_dir_name, target_name_for_downsampling,
         first_spc_date_string, last_spc_date_string, downsampling_classes,
         downsampling_fractions, for_training, top_output_dir_name):
    """Downsamples storm objects, based on target values.

    This is effectively the main method.

    :param top_input_dir_name: See documentation at top of file.
    :param target_name_for_downsampling: Same.
    :param first_spc_date_string: Same.
    :param last_spc_date_string: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param for_training: Same.
    :param top_output_dir_name: Same.
    """

    all_spc_date_strings = time_conversion.get_spc_dates_in_range(
        first_spc_date_string=first_spc_date_string,
        last_spc_date_string=last_spc_date_string)

    downsampling_dict = dict(
        list(zip(downsampling_classes, downsampling_fractions)))

    target_param_dict = target_val_utils.target_name_to_params(
        target_name_for_downsampling)
    event_type_string = target_param_dict[target_val_utils.EVENT_TYPE_KEY]

    input_target_file_names = []
    spc_date_string_by_file = []

    for this_spc_date_string in all_spc_date_strings:
        this_file_name = target_val_utils.find_target_file(
            top_directory_name=top_input_dir_name,
            event_type_string=event_type_string,
            spc_date_string=this_spc_date_string,
            raise_error_if_missing=False)

        if not os.path.isfile(this_file_name):
            continue

        input_target_file_names.append(this_file_name)
        spc_date_string_by_file.append(this_spc_date_string)

    num_files = len(input_target_file_names)
    target_dict_by_file = [None] * num_files

    full_id_strings = []
    storm_times_unix_sec = numpy.array([], dtype=int)
    storm_to_file_indices = numpy.array([], dtype=int)

    target_names = []
    target_matrix = None

    for i in range(num_files):
        print('Reading data from: "{0:s}"...'.format(
            input_target_file_names[i]))

        target_dict_by_file[i] = target_val_utils.read_target_values(
            netcdf_file_name=input_target_file_names[i])

        if i == 0:
            target_names = (
                target_dict_by_file[i][target_val_utils.TARGET_NAMES_KEY])

        these_full_id_strings = (
            target_dict_by_file[i][target_val_utils.FULL_IDS_KEY])

        full_id_strings += these_full_id_strings
        this_num_storm_objects = len(these_full_id_strings)

        storm_times_unix_sec = numpy.concatenate(
            (storm_times_unix_sec,
             target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]))

        storm_to_file_indices = numpy.concatenate(
            (storm_to_file_indices,
             numpy.full(this_num_storm_objects, i, dtype=int)))

        this_target_matrix = (
            target_dict_by_file[i][target_val_utils.TARGET_MATRIX_KEY])

        if target_matrix is None:
            # Adding 0 forces a copy, so the file's matrix is not modified.
            target_matrix = this_target_matrix + 0
        else:
            target_matrix = numpy.concatenate(
                (target_matrix, this_target_matrix), axis=0)

    print(SEPARATOR_STRING)

    downsampling_index = target_names.index(target_name_for_downsampling)
    good_indices = numpy.where(target_matrix[:, downsampling_index] !=
                               target_val_utils.INVALID_STORM_INTEGER)[0]

    full_id_strings = [full_id_strings[k] for k in good_indices]
    storm_times_unix_sec = storm_times_unix_sec[good_indices]
    target_matrix = target_matrix[good_indices, :]
    storm_to_file_indices = storm_to_file_indices[good_indices]

    primary_id_strings = temporal_tracking.full_to_partial_ids(
        full_id_strings)[0]

    if for_training:
        indices_to_keep = fancy_downsampling.downsample_for_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)
    else:
        indices_to_keep = fancy_downsampling.downsample_for_non_training(
            primary_id_strings=primary_id_strings,
            storm_times_unix_sec=storm_times_unix_sec,
            target_values=target_matrix[:, downsampling_index],
            target_name=target_name_for_downsampling,
            class_fraction_dict=downsampling_dict)

    print(SEPARATOR_STRING)

    for i in range(num_files):
        these_object_subindices = numpy.where(
            storm_to_file_indices[indices_to_keep] == i)[0]

        these_object_indices = indices_to_keep[these_object_subindices]
        if len(these_object_indices) == 0:
            continue

        these_indices_in_file = tracking_utils.find_storm_objects(
            all_id_strings=target_dict_by_file[i][
                target_val_utils.FULL_IDS_KEY],
            all_times_unix_sec=target_dict_by_file[i][
                target_val_utils.VALID_TIMES_KEY],
            id_strings_to_keep=[
                full_id_strings[k] for k in these_object_indices
            ],
            times_to_keep_unix_sec=storm_times_unix_sec[these_object_indices],
            allow_missing=False)

        this_output_dict = {
            tracking_utils.FULL_ID_COLUMN: [
                target_dict_by_file[i][target_val_utils.FULL_IDS_KEY][k]
                for k in these_indices_in_file
            ],
            tracking_utils.VALID_TIME_COLUMN:
            target_dict_by_file[i][target_val_utils.VALID_TIMES_KEY]
            [these_indices_in_file]
        }

        for j in range(len(target_names)):
            this_output_dict[target_names[j]] = (target_dict_by_file[i][
                target_val_utils.TARGET_MATRIX_KEY][these_indices_in_file, j])

        this_output_table = pandas.DataFrame.from_dict(this_output_dict)

        this_new_file_name = target_val_utils.find_target_file(
            top_directory_name=top_output_dir_name,
            event_type_string=event_type_string,
            spc_date_string=spc_date_string_by_file[i],
            raise_error_if_missing=False)

        print((
            'Writing {0:d} downsampled storm objects (out of {1:d} total) to: '
            '"{2:s}"...').format(
                len(this_output_table.index),
                len(target_dict_by_file[i][target_val_utils.FULL_IDS_KEY]),
                this_new_file_name))

        target_val_utils.write_target_values(
            storm_to_events_table=this_output_table,
            target_names=target_names,
            netcdf_file_name=this_new_file_name)
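
After downsampling, the script regroups the kept storm objects by source file, using storm_to_file_indices to remember which file each object came from. A minimal sketch of that regrouping with hypothetical indices:

import numpy

# Hypothetical mapping from storm object to source file (3 files).
storm_to_file_indices = numpy.array([0, 0, 1, 1, 2, 2, 2], dtype=int)

# Hypothetical storm objects kept after downsampling.
indices_to_keep = numpy.array([1, 2, 5, 6], dtype=int)

num_files = 3

for i in range(num_files):
    these_subindices = numpy.where(
        storm_to_file_indices[indices_to_keep] == i)[0]
    these_object_indices = indices_to_keep[these_subindices]

    print('File {0:d} keeps storm objects {1:s}'.format(
        i, str(these_object_indices.tolist())))

# File 0 keeps storm objects [1]
# File 1 keeps storm objects [2]
# File 2 keeps storm objects [5, 6]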