def test_first_and_last_times_in_year(self):
        """Ensures correct output from first_and_last_times_in_year.

        Calls the function for the year 2017 and compares both returned times
        with the expected-value constants.
        """

        this_start_time_unix_sec, this_end_time_unix_sec = (
            time_conversion.first_and_last_times_in_year(2017))

        # assertEqual (rather than assertTrue on an `==` expression) reports
        # both the actual and the expected value when the check fails.
        self.assertEqual(this_start_time_unix_sec, START_TIME_2017_UNIX_SEC)
        self.assertEqual(this_end_time_unix_sec, END_TIME_2017_UNIX_SEC)
def _remove_examples_in_wrong_year(example_dict, desired_year):
    """Drops every example that does not belong to the desired year.

    The number of removed examples is printed.

    :param example_dict: Dictionary in format created by `example_io.read_file`.
    :param desired_year: Year that belongs in this dictionary (integer).
    :return: example_dict: Same as input but maybe with fewer examples.
    """

    year_start_unix_sec, year_end_unix_sec = (
        time_conversion.first_and_last_times_in_year(desired_year))

    num_examples_before = len(example_dict[example_utils.VALID_TIMES_KEY])

    # Only the first element returned by `subset_by_time` (the subset
    # dictionary) is needed here.
    subset_result = example_utils.subset_by_time(
        example_dict=example_dict,
        first_time_unix_sec=year_start_unix_sec,
        last_time_unix_sec=year_end_unix_sec)
    example_dict = subset_result[0]

    num_examples_after = len(example_dict[example_utils.VALID_TIMES_KEY])
    num_examples_removed = num_examples_before - num_examples_after

    message_string = (
        'Removed {0:d} of {1:d} examples for being in the wrong year.'
    ).format(num_examples_removed, num_examples_before)
    print(message_string)

    return example_dict
def _find_rrtm_files(rrtm_directory_name, year):
    """Locates the daily RRTM files that exist for one year.

    One candidate path is built per day of the year, in the form
    "<rrtm_directory_name>/<yyyy><ddd>/output_file.<yyyy>.cdf" with `ddd` being
    the 1-based day number.  Only candidates that exist on disk are returned.

    :param rrtm_directory_name: See documentation at top of file.
    :param year: Year (integer).
    :return: example_file_names: 1-D list of file paths.
    :raises: ValueError: if no daily files can be found for the given year.
    """

    year_start_unix_sec, year_end_unix_sec = (
        time_conversion.first_and_last_times_in_year(year))

    # The "+ 1" presumably accounts for the last time being inclusive --
    # TODO confirm against `time_conversion`.
    seconds_in_year = year_end_unix_sec - year_start_unix_sec + 1
    days_in_year = int(
        numpy.round(float(seconds_in_year) / DAYS_TO_SECONDS))

    candidate_file_names = [
        '{0:s}/{1:04d}{2:03d}/output_file.{1:04d}.cdf'.format(
            rrtm_directory_name, year, day_of_year)
        for day_of_year in range(1, days_in_year + 1)
    ]
    example_file_names = [
        f for f in candidate_file_names if os.path.isfile(f)
    ]

    if len(example_file_names) == 0:
        error_string = (
            'Cannot find daily examples for year {0:d} in directory "{1:s}".'
        ).format(year, rrtm_directory_name)

        raise ValueError(error_string)

    return example_file_names
def _run(tropical_example_dir_name, non_tropical_example_dir_name,
         num_histogram_bins, output_dir_name):
    """Plots a histogram for each target variable, then joins the panels.

    This is effectively the main method.

    :param tropical_example_dir_name: See documentation at top of file.
    :param non_tropical_example_dir_name: Same.
    :param num_histogram_bins: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # Search window: start of FIRST_YEAR through end of LAST_YEAR.
    period_start_unix_sec = (
        time_conversion.first_and_last_times_in_year(FIRST_YEAR)[0])
    period_end_unix_sec = (
        time_conversion.first_and_last_times_in_year(LAST_YEAR)[-1])

    # Tropical files come first, then non-tropical ones.
    example_file_names = []
    input_dir_names = (
        tropical_example_dir_name, non_tropical_example_dir_name)

    for this_dir_name in input_dir_names:
        example_file_names += example_io.find_many_files(
            directory_name=this_dir_name,
            first_time_unix_sec=period_start_unix_sec,
            last_time_unix_sec=period_end_unix_sec,
            raise_error_if_all_missing=True,
            raise_error_if_any_missing=True)

    # Read each file, keeping only the target fields stored on disk.
    example_dicts = []

    for this_file_name in example_file_names:
        print('Reading data from: "{0:s}"...'.format(this_file_name))
        example_dicts.append(example_utils.subset_by_field(
            example_dict=example_io.read_file(this_file_name),
            field_names=TARGET_NAMES_IN_FILE))

    example_dict = example_utils.concat_examples(example_dicts)
    del example_dicts

    panel_file_names = []

    for panel_index, this_target_name in enumerate(TARGET_NAMES):
        if this_target_name not in TARGET_NAMES_IN_FILE:
            # Target is not stored in the files: compute it as surface
            # downwelling flux minus top-of-atmosphere upwelling flux.
            down_fluxes_w_m02 = example_utils.get_field_from_dict(
                example_dict=example_dict,
                field_name=example_utils.SHORTWAVE_SURFACE_DOWN_FLUX_NAME)
            up_fluxes_w_m02 = example_utils.get_field_from_dict(
                example_dict=example_dict,
                field_name=example_utils.SHORTWAVE_TOA_UP_FLUX_NAME)
            these_target_values = down_fluxes_w_m02 - up_fluxes_w_m02
        else:
            these_target_values = example_utils.get_field_from_dict(
                example_dict=example_dict, field_name=this_target_name)

        # Panels are labelled 'a', 'b', 'c', ... in plotting order.
        panel_file_names.append(_plot_histogram_one_target(
            target_values=numpy.ravel(these_target_values),
            target_name=this_target_name,
            num_bins=num_histogram_bins,
            letter_label=chr(ord('a') + panel_index),
            output_dir_name=output_dir_name))

    concat_file_name = '{0:s}/target_distributions.jpg'.format(output_dir_name)
    print('Concatenating panels to: "{0:s}"...'.format(concat_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=concat_file_name,
        num_panel_rows=2,
        num_panel_columns=2,
        border_width_pixels=25)

    # Trim the concatenated image in place.
    imagemagick_utils.trim_whitespace(
        input_file_name=concat_file_name,
        output_file_name=concat_file_name)
Beispiel #5
0
def _run(tropical_example_dir_name, non_tropical_example_dir_name,
         output_file_name):
    """Plots all sites with data.

    This is effectively the main method.

    Unique (rounded) site coordinates are drawn on an equidistant cylindrical
    map, with one marker colour per latitude band.

    :param tropical_example_dir_name: See documentation at top of file.
    :param non_tropical_example_dir_name: Same.
    :param output_file_name: Same.
    """

    # Search window: start of FIRST_YEAR through end of LAST_YEAR.
    period_start_unix_sec = (
        time_conversion.first_and_last_times_in_year(FIRST_YEAR)[0])
    period_end_unix_sec = (
        time_conversion.first_and_last_times_in_year(LAST_YEAR)[-1])

    # Both directories are searched before any file is read; tropical files
    # are read first.
    file_name_lists = [
        example_io.find_many_files(
            directory_name=this_dir_name,
            first_time_unix_sec=period_start_unix_sec,
            last_time_unix_sec=period_end_unix_sec,
            raise_error_if_all_missing=True,
            raise_error_if_any_missing=False)
        for this_dir_name in
        (tropical_example_dir_name, non_tropical_example_dir_name)
    ]

    latitudes_deg_n = numpy.array([])
    longitudes_deg_e = numpy.array([])

    for these_file_names in file_name_lists:
        for this_file_name in these_file_names:
            print('Reading data from: "{0:s}"...'.format(this_file_name))
            this_example_dict = example_io.read_file(this_file_name)

            these_latitudes_deg_n = example_utils.get_field_from_dict(
                example_dict=this_example_dict,
                field_name=example_utils.LATITUDE_NAME)
            these_longitudes_deg_e = example_utils.get_field_from_dict(
                example_dict=this_example_dict,
                field_name=example_utils.LONGITUDE_NAME)

            latitudes_deg_n = numpy.concatenate(
                (latitudes_deg_n, these_latitudes_deg_n))
            longitudes_deg_e = numpy.concatenate(
                (longitudes_deg_e, these_longitudes_deg_e))

    # Reduce to unique sites: stack as (lat, lng) rows, round to the
    # tolerance, then drop duplicate rows.
    coord_matrix = numpy.unique(
        number_rounding.round_to_nearest(
            numpy.transpose(
                numpy.vstack((latitudes_deg_n, longitudes_deg_e))),
            LATLNG_TOLERANCE_DEG),
        axis=0)

    latitudes_deg_n = coord_matrix[:, 0]
    longitudes_deg_e = coord_matrix[:, 1]

    figure_object, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=MIN_PLOT_LATITUDE_DEG_N,
            max_latitude_deg=MAX_PLOT_LATITUDE_DEG_N,
            min_longitude_deg=MIN_PLOT_LONGITUDE_DEG_E,
            max_longitude_deg=MAX_PLOT_LONGITUDE_DEG_E,
            resolution_string='l'))

    # Map background: borders first, then the lat/long grid.
    border_kwargs = {
        'basemap_object': basemap_object,
        'axes_object': axes_object,
        'line_colour': BORDER_COLOUR,
        'line_width': BORDER_WIDTH
    }
    plotting_utils.plot_coastlines(**border_kwargs)
    plotting_utils.plot_countries(**border_kwargs)

    grid_kwargs = {
        'basemap_object': basemap_object,
        'axes_object': axes_object,
        'line_colour': GRID_LINE_COLOUR,
        'line_width': GRID_LINE_WIDTH,
        'font_size': FONT_SIZE
    }
    plotting_utils.plot_parallels(num_parallels=NUM_PARALLELS, **grid_kwargs)
    plotting_utils.plot_meridians(num_meridians=NUM_MERIDIANS, **grid_kwargs)

    # One (membership flags, marker colour) pair per latitude band, plotted in
    # the order Arctic, mid-latitude, tropical.
    # NOTE(review): the last band is "latitude < 30 deg N", which also covers
    # every southern-hemisphere site -- confirm this is intended.
    latitude_bands = (
        (latitudes_deg_n >= 66.5, ARCTIC_COLOUR),
        (
            numpy.logical_and(
                latitudes_deg_n >= 30., latitudes_deg_n < 66.5),
            MID_LATITUDE_COLOUR
        ),
        (latitudes_deg_n < 30., TROPICAL_COLOUR)
    )

    for these_flags, this_colour in latitude_bands:
        these_indices = numpy.where(these_flags)[0]
        print(len(these_indices))

        # Convert lat/long to map (x, y) coordinates.
        these_x_coords, these_y_coords = basemap_object(
            longitudes_deg_e[these_indices], latitudes_deg_n[these_indices])

        axes_object.plot(
            these_x_coords,
            these_y_coords,
            linestyle='None',
            marker=MARKER_TYPE,
            markersize=MARKER_SIZE,
            markeredgewidth=0,
            markerfacecolor=this_colour,
            markeredgecolor=this_colour)

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name,
        dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)
Beispiel #6
0
def get_examples_for_inference(model_metadata_dict, example_file_name,
                               num_examples, example_dir_name,
                               example_id_file_name):
    """Returns examples to be used by a model at inference stage.

    There are two modes, selected by `example_file_name`:

    - Specific examples (`example_file_name` is the empty string): example IDs
      are read from `example_id_file_name` and the matching examples are
      created from files in `example_dir_name`.
    - Random examples (`example_file_name` is non-empty): examples are created
      for the whole year of that file, searching the directory that contains
      it.  If more than `num_examples` are found, `num_examples` of them are
      drawn at random without replacement.

    :param model_metadata_dict: Dictionary read by `neural_net.read_metafile`.
    :param example_file_name: [use only if you want random examples]
        Path to file with data examples (to be read by `example_io.read_file`).
        Make this the empty string if you want specific examples.
    :param num_examples: [use only if you want random examples]
        Number of examples to use (positive integer).  If this is at least the
        number of examples available, all examples are used.
    :param example_dir_name: [use only if you want specific examples]
        Name of directory with data examples.  Files therein will be found by
        `example_io.find_file` and read by `example_io.read_file`.
    :param example_id_file_name: [use only if you want specific examples]
        Path to file with desired IDs.  Will be read by
        `read_example_ids_from_netcdf`.
    :return: Same output variables as `neural_net.data_generator`
        (predictor matrix, target array or list of target arrays, and list of
        example IDs).
    """

    error_checking.assert_is_string(example_file_name)
    use_specific_ids = example_file_name == ''

    # Work on a copy so that the caller's metadata is never modified.
    generator_option_dict = copy.deepcopy(
        model_metadata_dict[neural_net.TRAINING_OPTIONS_KEY])

    if use_specific_ids:
        # `example_dir_name` is consumed only in this branch, so this is where
        # it must be validated.
        error_checking.assert_is_string(example_id_file_name)
        error_checking.assert_is_string(example_dir_name)

        print('Reading desired example IDs from: "{0:s}"...'.format(
            example_id_file_name))
        example_id_strings = read_example_ids_from_netcdf(example_id_file_name)

        generator_option_dict[neural_net.EXAMPLE_DIRECTORY_KEY] = (
            example_dir_name)

        predictor_matrix, target_array = (
            neural_net.create_data_specific_examples(
                option_dict=generator_option_dict,
                net_type_string=model_metadata_dict[neural_net.NET_TYPE_KEY],
                example_id_strings=example_id_strings))

        return predictor_matrix, target_array, example_id_strings

    error_checking.assert_is_integer(num_examples)
    error_checking.assert_is_greater(num_examples, 0)

    # In random mode the input argument `example_dir_name` is ignored: the
    # directory and the time period both come from the example file itself.
    random_example_dir_name = os.path.split(example_file_name)[0]
    year = example_io.file_name_to_year(example_file_name)
    first_time_unix_sec, last_time_unix_sec = (
        time_conversion.first_and_last_times_in_year(year))

    generator_option_dict[neural_net.EXAMPLE_DIRECTORY_KEY] = (
        random_example_dir_name)
    generator_option_dict[neural_net.FIRST_TIME_KEY] = first_time_unix_sec
    generator_option_dict[neural_net.LAST_TIME_KEY] = last_time_unix_sec

    predictor_matrix, target_array, example_id_strings = neural_net.create_data(
        option_dict=generator_option_dict,
        for_inference=True,
        net_type_string=model_metadata_dict[neural_net.NET_TYPE_KEY],
        is_loss_constrained_mse=False)

    num_examples_total = len(example_id_strings)
    if num_examples >= num_examples_total:
        return predictor_matrix, target_array, example_id_strings

    # Passing an integer makes `choice` sample from 0...(N - 1), so no explicit
    # index array is needed.
    good_indices = numpy.random.choice(
        num_examples_total, size=num_examples, replace=False)

    # Apply the same subset to predictors, IDs and targets so they stay
    # aligned along the example axis.
    predictor_matrix = predictor_matrix[good_indices, ...]
    example_id_strings = [example_id_strings[i] for i in good_indices]

    if isinstance(target_array, list):
        target_array = [t[good_indices, ...] for t in target_array]
    else:
        target_array = target_array[good_indices, ...]

    return predictor_matrix, target_array, example_id_strings