def _run():
    """Analyzes backwards-optimization experiment.

    This is effectively the main method.  For each (L2 weight, min-max weight)
    pair, reads the corresponding backwards-optimization file and extracts the
    mean final activation.  It then plots all values as a hyperparameter grid
    with a colour bar, and saves the figure to
    "<TOP_EXPERIMENT_DIR_NAME>/mean_final_activations.jpg".
    """

    # Rows index L2 weights; columns index min-max weights.
    activation_matrix = numpy.full(
        (len(L2_WEIGHTS), len(MINMAX_WEIGHTS)), numpy.nan)

    for row, l2_weight in enumerate(L2_WEIGHTS):
        for column, minmax_weight in enumerate(MINMAX_WEIGHTS):
            bwo_file_name = (
                '{0:s}/bwo_pmm_l2-weight={1:.10f}_minmax-weight={2:.10f}.p'
            ).format(TOP_EXPERIMENT_DIR_NAME, l2_weight, minmax_weight)

            print('Reading data from: "{0:s}"...'.format(bwo_file_name))
            bwo_dict = backwards_opt.read_file(bwo_file_name)[0]
            activation_matrix[row, column] = (
                bwo_dict[backwards_opt.MEAN_FINAL_ACTIVATION_KEY])

    axes_object = model_evaluation.plot_hyperparam_grid(
        score_matrix=activation_matrix,
        colour_map_object=COLOUR_MAP_OBJECT,
        min_colour_value=0.,
        max_colour_value=1.)

    # Tick labels show base-10 logarithms of the weights.
    axes_object.set_xticklabels(
        ['{0:.1f}'.format(v) for v in numpy.log10(MINMAX_WEIGHTS)],
        rotation=90.)
    axes_object.set_yticklabels(
        ['{0:.1f}'.format(v) for v in numpy.log10(L2_WEIGHTS)])

    axes_object.set_xlabel(r'Min-max weight (log$_{10}$)')
    axes_object.set_ylabel(r'L$_2$ weight (log$_{10}$)')

    plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=activation_matrix,
        colour_map_object=COLOUR_MAP_OBJECT,
        min_value=0.,
        max_value=1.,
        orientation_string='vertical',
        extend_min=False,
        extend_max=False,
        font_size=FONT_SIZE)

    figure_file_name = '{0:s}/mean_final_activations.jpg'.format(
        TOP_EXPERIMENT_DIR_NAME)
    print('Saving figure to: "{0:s}"...'.format(figure_file_name))

    pyplot.savefig(
        figure_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close()
def _plot_feature_maps_one_layer(
        feature_matrix, example_id_strings, layer_name, output_dir_name):
    """Plots feature maps for one layer.

    E = number of examples
    H = number of heights
    C = number of channels

    Creates one figure per example, titled with the layer name and example
    ID, and saves it to "<output_dir_name>/<example ID>.jpg".  All figures
    share one colour scale.  The scale is symmetric about zero, and its
    magnitude is the 99th percentile of absolute feature values.  If that
    percentile is zero, the maximum absolute value is used instead, and 1 if
    all values are zero.

    :param feature_matrix: E-by-H-by-C numpy array of feature maps.
    :param example_id_strings: length-E list of example IDs.
    :param layer_name: Name of layer that generated feature maps.
    :param output_dir_name: Name of output directory.  Figures will be saved
        here.
    """

    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=3)
    num_examples = feature_matrix.shape[0]

    # TODO(thunderhoser): Maybe define colour limits differently?
    absolute_feature_matrix = numpy.absolute(feature_matrix)
    max_colour_value = numpy.percentile(absolute_feature_matrix, 99)

    # Sparse feature maps (e.g., mostly-zero activations) can make the 99th
    # percentile exactly zero, which would collapse the colour scale to
    # [0, 0].  Fall back to the largest magnitude and, if every value is
    # zero, to an arbitrary positive limit.
    if max_colour_value == 0:
        max_colour_value = numpy.max(absolute_feature_matrix)
    if max_colour_value == 0:
        max_colour_value = 1.

    min_colour_value = -1 * max_colour_value

    for i in range(num_examples):
        this_figure_object, this_axes_object_matrix = (
            feature_map_plotting.plot_many_1d_feature_maps(
                feature_matrix=feature_matrix[i, ...],
                colour_map_object=COLOUR_MAP_OBJECT,
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value)
        )

        plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=this_axes_object_matrix,
            data_matrix=feature_matrix[i, ...],
            colour_map_object=COLOUR_MAP_OBJECT,
            min_value=min_colour_value, max_value=max_colour_value,
            orientation_string='horizontal', padding=0.01,
            extend_min=True, extend_max=True
        )

        this_title_string = 'Layer "{0:s}", example "{1:s}"'.format(
            layer_name, example_id_strings[i]
        )
        this_figure_object.suptitle(this_title_string, fontsize=25)

        this_file_name = '{0:s}/{1:s}.jpg'.format(
            output_dir_name, example_id_strings[i]
        )

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        this_figure_object.savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight'
        )
        pyplot.close(this_figure_object)
def _add_colour_bar(axes_object, basemap_object, colour_map_object,
                    colour_norm_object):
    """Adds colour bar to figure.

    The colour bar is vertical if the map's latitude range exceeds its
    longitude range, and horizontal otherwise.  Tick values are treated as
    Unix times (seconds) and labelled with `COLOUR_BAR_TIME_FORMAT`.

    :param axes_object: See input doc for `plot_storm_tracks`.
    :param basemap_object: Same.
    :param colour_map_object: See output doc for `_process_colour_args`.
    :param colour_norm_object: Same.
    :return: colour_bar_object: Handle for colour bar.
    """

    latitude_range_deg = basemap_object.urcrnrlat - basemap_object.llcrnrlat
    longitude_range_deg = basemap_object.urcrnrlon - basemap_object.llcrnrlon

    if latitude_range_deg > longitude_range_deg:
        orientation_string = 'vertical'
        padding = None
    else:
        orientation_string = 'horizontal'
        padding = 0.05

    # Dummy data for the colour bar; its limits are passed explicitly from
    # `colour_norm_object`.  Use floats: with `dtype=int`, 1e12 does not fit
    # where the default integer is 32 bits (e.g., Windows with NumPy < 2).
    dummy_values = numpy.array([0., 1e12])

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=dummy_values,
        colour_map_object=colour_map_object,
        min_value=colour_norm_object.vmin,
        max_value=colour_norm_object.vmax,
        orientation_string=orientation_string,
        padding=padding,
        extend_min=False,
        extend_max=False,
        fraction_of_axis_length=0.9,
        font_size=COLOUR_BAR_FONT_SIZE)

    # Ticks are Unix times in seconds.  Use 64-bit integers explicitly so that
    # times beyond 2038 cannot overflow a 32-bit default int.
    tick_times_unix_sec = numpy.round(
        colour_bar_object.get_ticks()).astype(numpy.int64)

    tick_time_strings = [
        time_conversion.unix_sec_to_string(t, COLOUR_BAR_TIME_FORMAT)
        for t in tick_times_unix_sec
    ]

    colour_bar_object.set_ticks(tick_times_unix_sec)
    colour_bar_object.set_ticklabels(tick_time_strings)
    return colour_bar_object
Exemple #4
0
def _plot_one_feature_map(feature_matrix_2d, max_colour_value, plot_colour_bar,
                          axes_object):
    """Plots one feature map.

    M = number of rows in grid
    N = number of columns in grid

    The colour scale runs from -max_colour_value to +max_colour_value.  Axis
    ticks are removed.  If `plot_colour_bar` is True, a horizontal colour bar
    with one-decimal tick labels is added below the axes.

    :param feature_matrix_2d: M-by-N numpy array of feature values.
    :param max_colour_value: Max value in colour scheme.
    :param plot_colour_bar: Boolean flag.
    :param axes_object: Will plot on these axes (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    min_colour_value = -1 * max_colour_value

    axes_object.pcolormesh(
        feature_matrix_2d, cmap=COLOUR_MAP_OBJECT, vmin=min_colour_value,
        vmax=max_colour_value, shading='flat', edgecolors='None')

    axes_object.set_xlim(0., feature_matrix_2d.shape[1])
    axes_object.set_ylim(0., feature_matrix_2d.shape[0])
    axes_object.set_xticks([])
    axes_object.set_yticks([])

    if not plot_colour_bar:
        return

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=feature_matrix_2d,
        colour_map_object=COLOUR_MAP_OBJECT, min_value=min_colour_value,
        max_value=max_colour_value, orientation_string='horizontal',
        padding=0.015, fraction_of_axis_length=0.9,
        extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE)

    # Use the Colorbar API for tick values.  `colour_bar_object.ax.get_xticks()`
    # returns axis coordinates, which do not match colour-bar data values in
    # some Matplotlib versions.  This also matches the other plotting helpers.
    tick_values = colour_bar_object.get_ticks()
    tick_label_strings = ['{0:.1f}'.format(x) for x in tick_values]
    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_label_strings)
Exemple #5
0
def _plot_saliency_vector_p_vector_t(saliency_matrix, predictor_names,
                                     target_names, height_labels,
                                     example_id_string, colour_map_object,
                                     max_colour_percentile, output_dir_name):
    """Plots saliency for one example: vector predictors, vector targets.

    P = number of predictor variables
    T = number of target variables
    H = number of heights

    One figure is saved per (predictor, target) pair.  In each figure:
    - the x-axis is predictor height and the y-axis is target height;
    - a dashed line runs along the diagonal;
    - the colour scale is symmetric about zero.  Its magnitude is the
      `max_colour_percentile`th percentile of absolute saliency for that
      pair, floored at 0.001.

    :param saliency_matrix: H-by-P-by-H-by-T numpy array of saliency values.
    :param predictor_names: length-P list of predictor names.
    :param target_names: length-T list of target names.
    :param height_labels: length-H list of height labels (strings).
    :param example_id_string: Example ID.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    predictor_names_verbose = [
        PREDICTOR_NAME_TO_VERBOSE[n] for n in predictor_names
    ]
    target_names_verbose = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    # The same tick positions serve both axes (one tick per height).
    num_heights = len(height_labels)
    height_tick_values = numpy.linspace(
        0, num_heights - 1, num=num_heights, dtype=float)

    for j, predictor_name in enumerate(predictor_names):
        for k, target_name in enumerate(target_names):
            this_saliency_matrix = saliency_matrix[:, j, :, k]

            max_colour_value = numpy.maximum(
                numpy.percentile(numpy.abs(this_saliency_matrix),
                                 max_colour_percentile),
                0.001)
            min_colour_value = -1 * max_colour_value

            figure_object, axes_object = pyplot.subplots(
                1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES))

            # Transpose so that rows (y) are target heights and columns (x)
            # are predictor heights.
            axes_object.imshow(
                numpy.transpose(this_saliency_matrix), cmap=colour_map_object,
                vmin=min_colour_value, vmax=max_colour_value, origin='lower')

            axes_object.set_xticks(height_tick_values)
            axes_object.set_yticks(height_tick_values)
            axes_object.set_xticklabels(
                height_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
            axes_object.set_yticklabels(
                height_labels, fontsize=TICK_LABEL_FONT_SIZE)

            axes_object.set_xlabel('Predictor height (km AGL)')
            axes_object.set_ylabel('Target height (km AGL)')

            # Dashed reference line from the lower-left to the upper-right
            # corner of the current axis limits.
            axes_object.plot(
                axes_object.get_xlim(), axes_object.get_ylim(),
                color=REFERENCE_LINE_COLOUR, linestyle='dashed',
                linewidth=REFERENCE_LINE_WIDTH)

            colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object,
                data_matrix=this_saliency_matrix,
                colour_map_object=colour_map_object,
                min_value=min_colour_value,
                max_value=max_colour_value,
                orientation_string='horizontal',
                padding=0.1,
                extend_min=True,
                extend_max=True,
                fraction_of_axis_length=0.8,
                font_size=DEFAULT_FONT_SIZE)

            colour_bar_ticks = colour_bar_object.get_ticks()
            colour_bar_object.set_ticks(colour_bar_ticks)
            colour_bar_object.set_ticklabels(
                ['{0:.1f}'.format(v) for v in colour_bar_ticks])

            axes_object.set_title(
                'Saliency for {0:s} with respect to {1:s}'.format(
                    target_names_verbose[k], predictor_names_verbose[j]),
                fontsize=DEFAULT_FONT_SIZE)

            output_file_name = '{0:s}/{1:s}_{2:s}_{3:s}.jpg'.format(
                output_dir_name, example_id_string.replace('_', '-'),
                predictor_name.replace('_', '-'),
                target_name.replace('_', '-'))
            print('Saving figure to: "{0:s}"...'.format(output_file_name))

            figure_object.savefig(
                output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
                bbox_inches='tight')
            pyplot.close(figure_object)
Exemple #6
0
def _plot_saliency_scalar_p_scalar_t(saliency_matrix, predictor_names,
                                     target_names, example_id_string,
                                     colour_map_object, max_colour_percentile,
                                     output_dir_name):
    """Plots saliency for one example: scalar predictors, scalar targets.

    P = number of predictor variables
    T = number of target variables

    Draws a single P-by-T image (predictors on x, targets on y).  The colour
    scale is symmetric about zero, with magnitude equal to the
    `max_colour_percentile`th percentile of absolute saliency, floored at
    0.001.  The colour bar is horizontal when P >= T and vertical otherwise.

    :param saliency_matrix: P-by-T numpy array of saliency values.
    :param predictor_names: length-P list of predictor names.
    :param target_names: length-T list of target names.
    :param example_id_string: Example ID.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    def _capitalize_first(name_string):
        # Upper-cases only the first character, leaving the rest untouched.
        return name_string[0].upper() + name_string[1:]

    predictor_names_verbose = [
        PREDICTOR_NAME_TO_VERBOSE[n] for n in predictor_names
    ]
    target_names_verbose = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    max_colour_value = numpy.maximum(
        numpy.percentile(numpy.absolute(saliency_matrix),
                         max_colour_percentile),
        0.001)
    min_colour_value = -1 * max_colour_value

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES))

    # Transpose so that predictors lie along x and targets along y.
    axes_object.imshow(
        numpy.transpose(saliency_matrix), cmap=colour_map_object,
        vmin=min_colour_value, vmax=max_colour_value, origin='lower')

    num_predictors = len(predictor_names)
    num_targets = len(target_names)

    axes_object.set_xticks(numpy.linspace(
        0, num_predictors - 1, num=num_predictors, dtype=float))
    axes_object.set_yticks(numpy.linspace(
        0, num_targets - 1, num=num_targets, dtype=float))

    axes_object.set_xticklabels(
        [_capitalize_first(n) for n in predictor_names_verbose],
        fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
    axes_object.set_yticklabels(
        [_capitalize_first(n) for n in target_names_verbose],
        fontsize=TICK_LABEL_FONT_SIZE)

    axes_object.set_xlabel('Predictor')
    axes_object.set_ylabel('Target')

    # Put the colour bar along the longer dimension of the grid.
    if num_predictors >= num_targets:
        orientation_string = 'horizontal'
        colour_bar_padding = 0.1
    else:
        orientation_string = 'vertical'
        colour_bar_padding = 0.01

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=saliency_matrix,
        colour_map_object=colour_map_object,
        min_value=min_colour_value,
        max_value=max_colour_value,
        orientation_string=orientation_string,
        padding=colour_bar_padding,
        extend_min=True,
        extend_max=True,
        fraction_of_axis_length=0.8,
        font_size=DEFAULT_FONT_SIZE)

    colour_bar_ticks = colour_bar_object.get_ticks()
    colour_bar_object.set_ticks(colour_bar_ticks)
    colour_bar_object.set_ticklabels(
        ['{0:.1f}'.format(v) for v in colour_bar_ticks])

    axes_object.set_title(
        'Saliency for scalar targets with respect to scalar predictors',
        fontsize=DEFAULT_FONT_SIZE)

    output_file_name = '{0:s}/{1:s}_scalars.jpg'.format(
        output_dir_name, example_id_string.replace('_', '-'))
    print('Saving figure to: "{0:s}"...'.format(output_file_name))

    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)
Exemple #7
0
def _plot_one_value(data_matrix,
                    grid_metadata_dict,
                    colour_map_object,
                    min_colour_value,
                    max_colour_value,
                    plot_cbar_min_arrow,
                    plot_cbar_max_arrow,
                    log_scale=False):
    """Plots one value (score, num examples, or num positive examples).

    M = number of rows in grid
    N = number of columns in grid

    Draws the gridded values on a basemap with coastlines, borders,
    parallels and meridians, then adds a horizontal colour bar.  NaN cells
    are masked, so they are left unfilled.  Colour-bar tick labels depend on
    the data:
    - if `log_scale`, ticks are shown as rounded 10**tick;
    - else if max(data) >= 6, ticks are shown as rounded integers;
    - otherwise, ticks are shown with two decimals.

    :param data_matrix: M-by-N numpy array of values to plot.
    :param grid_metadata_dict: Dictionary returned by
        `grids.read_equidistant_metafile`.
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param plot_cbar_min_arrow: Boolean flag.  If True, will plot arrow at
        bottom of colour bar (to signify that lower values are possible).
    :param plot_cbar_max_arrow: Boolean flag.  If True, will plot arrow at top
        of colour bar (to signify that higher values are possible).
    :param log_scale: Boolean flag (True if `data_matrix` contains data in log
        scale).
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES))

    basemap_object, basemap_x_matrix_metres, basemap_y_matrix_metres = (
        _get_basemap(grid_metadata_dict))

    # Infer uniform grid spacing from the first row / column of the basemap
    # coordinates.
    num_grid_rows = data_matrix.shape[0]
    num_grid_columns = data_matrix.shape[1]
    x_min_metres = basemap_x_matrix_metres[0, 0]
    y_min_metres = basemap_y_matrix_metres[0, 0]

    x_spacing_metres = (
        (basemap_x_matrix_metres[0, -1] - x_min_metres) /
        (num_grid_columns - 1))
    y_spacing_metres = (
        (basemap_y_matrix_metres[-1, 0] - y_min_metres) /
        (num_grid_rows - 1))

    data_matrix_at_edges, edge_x_coords_metres, edge_y_coords_metres = (
        grids.xy_field_grid_points_to_edges(
            field_matrix=data_matrix,
            x_min_metres=x_min_metres,
            y_min_metres=y_min_metres,
            x_spacing_metres=x_spacing_metres,
            y_spacing_metres=y_spacing_metres))

    # Mask NaN cells so that pcolormesh leaves them blank.
    data_matrix_at_edges = numpy.ma.masked_where(
        numpy.isnan(data_matrix_at_edges), data_matrix_at_edges)

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)
    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)
    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)
    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS)

    basemap_object.pcolormesh(
        edge_x_coords_metres, edge_y_coords_metres, data_matrix_at_edges,
        cmap=colour_map_object, vmin=min_colour_value, vmax=max_colour_value,
        shading='flat', edgecolors='None', axes=axes_object, zorder=-1e12)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=data_matrix,
        colour_map_object=colour_map_object,
        min_value=min_colour_value,
        max_value=max_colour_value,
        orientation_string='horizontal',
        extend_min=plot_cbar_min_arrow,
        extend_max=plot_cbar_max_arrow,
        padding=0.05)

    tick_values = colour_bar_object.get_ticks()

    if log_scale:
        # Data are log10 values, so label ticks with the linear value.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(10 ** v))) for v in tick_values]
    elif numpy.nanmax(data_matrix) >= 6:
        # With a data maximum of at least 6, label ticks as rounded integers.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(v))) for v in tick_values]
    else:
        tick_strings = ['{0:.2f}'.format(v) for v in tick_values]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_strings)

    return figure_object, axes_object
def _plot_2d_radar_saliency(
        saliency_matrix, colour_map_object, max_colour_value, half_num_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        significance_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots saliency map for 2-D radar data.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    Saliency contours (and, optionally, significance markers) are drawn on
    `axes_object_matrices[i]`, where i = 1 for a 2D/3D ("conv_2d3d") model and
    0 otherwise.  A colour bar is added, and the corresponding figure is
    saved and closed.

    :param saliency_matrix: M-by-N-by-C numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.  If None, this is set to the
        `MAX_COLOUR_PERCENTILE`th percentile of absolute saliency.  If that
        percentile is zero, the max absolute saliency is used instead, and
        0.001 if every value is zero.
    :param half_num_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param significance_matrix: M-by-N-by-C numpy array of Boolean flags,
        indicating where differences with some other saliency map are
        significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    if max_colour_value is None:
        absolute_saliency_matrix = numpy.absolute(saliency_matrix)
        max_colour_value = numpy.percentile(
            absolute_saliency_matrix, MAX_COLOUR_PERCENTILE
        )

        # A zero maximum would make `contour_interval` (below) zero as well.
        # Fall back to the largest magnitude, then to a small positive floor.
        if max_colour_value == 0:
            max_colour_value = numpy.max(absolute_saliency_matrix)
        if max_colour_value == 0:
            max_colour_value = 0.001

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        figure_index = 1
        radar_field_name = 'shear'
    else:
        figure_index = 0
        radar_field_name = None

    saliency_plotting.plot_many_2d_grids_with_contours(
        saliency_matrix_3d=numpy.flip(saliency_matrix, axis=0),
        axes_object_matrix=axes_object_matrices[figure_index],
        colour_map_object=colour_map_object,
        max_absolute_contour_level=max_colour_value,
        contour_interval=max_colour_value / half_num_contours,
        row_major=False)

    if significance_matrix is not None:
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(significance_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            row_major=False)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrices[figure_index],
        data_matrix=saliency_matrix,
        colour_map_object=colour_map_object, min_value=0.,
        max_value=max_colour_value, orientation_string='horizontal',
        fraction_of_axis_length=colour_bar_length / (1 + int(conv_2d3d)),
        extend_min=False, extend_max=True, font_size=COLOUR_BAR_FONT_SIZE)

    if label_colour_bars:
        colour_bar_object.set_label(
            'Absolute saliency', fontsize=COLOUR_BAR_FONT_SIZE)

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_objects[figure_index].savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_objects[figure_index])
Exemple #9
0
def _plot_one_score(score_matrix, colour_map_object, min_colour_value,
                    max_colour_value, colour_bar_label, is_score_bias,
                    best_model_index, output_file_name):
    """Plots one score.

    Creates one panel per (num dense layers, data-augmentation flag) pair,
    stacked in a single column.  Each panel shows a dropout-rate vs. L2-weight
    grid.  The best model is marked with `BEST_MODEL_MARKER_TYPE`, and each
    model whose score is NaN is marked with `CORRUPT_MODEL_MARKER_TYPE`.

    :param score_matrix: 4-D numpy array of scores, where the first axis
        represents dropout rate; second represents L2 weight; third represents
        num dense layers; and fourth is data augmentation (yes or no).
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param colour_bar_label: Label string for colour bar.
    :param is_score_bias: Boolean flag.  If True, score to be plotted is
        frequency bias, which changes settings for colour scheme.
    :param best_model_index: Linear index of best model.
    :param output_file_name: Path to output file (figure will be saved here).
    """

    colour_norm_object = None
    if is_score_bias:
        colour_map_object, colour_norm_object = _get_bias_colour_scheme(
            max_value=max_colour_value)

    num_dense_layer_counts = len(DENSE_LAYER_COUNTS)
    num_data_aug_flags = len(DATA_AUGMENTATION_FLAGS)

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_dense_layer_counts * num_data_aug_flags,
        num_columns=1,
        horizontal_spacing=0.15,
        vertical_spacing=0.15,
        shared_x_axis=False,
        shared_y_axis=False,
        keep_aspect_ratio=True)

    # The panels form one column; reshape so that [k, m] addresses the panel
    # for DENSE_LAYER_COUNTS[k] and DATA_AUGMENTATION_FLAGS[m].
    axes_object_matrix = numpy.reshape(
        axes_object_matrix, (num_dense_layer_counts, num_data_aug_flags))

    x_tick_labels = ['{0:.1f}'.format(w) for w in numpy.log10(L2_WEIGHTS)]
    y_tick_labels = ['{0:.3f}'.format(d) for d in DROPOUT_RATES]

    best_model_index_tuple = numpy.unravel_index(
        best_model_index, score_matrix.shape)

    for k, num_dense_layers in enumerate(DENSE_LAYER_COUNTS):
        for m, data_aug_flag in enumerate(DATA_AUGMENTATION_FLAGS):
            this_axes_object = axes_object_matrix[k, m]

            model_eval.plot_hyperparam_grid(
                score_matrix=score_matrix[..., k, m],
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value,
                colour_map_object=colour_map_object,
                colour_norm_object=colour_norm_object,
                axes_object=this_axes_object)

            this_axes_object.set_xticklabels(
                x_tick_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
            this_axes_object.set_yticklabels(
                y_tick_labels, fontsize=TICK_LABEL_FONT_SIZE)
            this_axes_object.set_ylabel(
                'Dropout rate', fontsize=TICK_LABEL_FONT_SIZE)

            # Only the bottom panel keeps its x-axis label and ticks.
            is_bottom_panel = (
                k == num_dense_layer_counts - 1 and
                m == num_data_aug_flags - 1)

            if is_bottom_panel:
                this_axes_object.set_xlabel(r'L$_2$ weight (log$_{10}$)')
            else:
                this_axes_object.set_xticks([], [])

            this_axes_object.set_title(
                '{0:d} dense layer{1:s}, DA {2:s}'.format(
                    num_dense_layers,
                    's' if num_dense_layers > 1 else '',
                    'on' if data_aug_flag else 'off'))

    best_i, best_j, best_k, best_m = best_model_index_tuple

    # Grid x-coordinate is the L2-weight index; y is the dropout index.
    axes_object_matrix[best_k, best_m].plot(
        best_j, best_i, linestyle='None', marker=BEST_MODEL_MARKER_TYPE,
        markersize=BEST_MODEL_MARKER_SIZE, markerfacecolor=MARKER_COLOUR,
        markeredgecolor=MARKER_COLOUR, markeredgewidth=BEST_MODEL_MARKER_WIDTH)

    # Models with NaN score are treated as corrupt.
    for i, j, k, m in numpy.argwhere(numpy.isnan(score_matrix)):
        axes_object_matrix[k, m].plot(
            j, i, linestyle='None', marker=CORRUPT_MODEL_MARKER_TYPE,
            markersize=CORRUPT_MODEL_MARKER_SIZE,
            markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR,
            markeredgewidth=CORRUPT_MODEL_MARKER_WIDTH)

    if not is_score_bias:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            min_value=min_colour_value,
            max_value=max_colour_value,
            orientation_string='vertical',
            extend_min=True,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)
    else:
        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='vertical',
            extend_min=False,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)

        bias_tick_values = colour_bar_object.get_ticks()
        bias_tick_strings = ['{0:.1f}'.format(v) for v in bias_tick_values]
        colour_bar_object.set_ticks(bias_tick_values)
        colour_bar_object.set_ticklabels(bias_tick_strings)

    colour_bar_object.set_label(colour_bar_label)
    print('Saving figure to: "{0:s}"...'.format(output_file_name))

    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)
Exemple #10
0
def _add_colour_bar(figure_file_name, colour_map_object, max_colour_value,
                    temporary_dir_name):
    """Adds colour bar to saved image file.

    The colour bar is drawn in a separate figure whose size (in inches) matches
    the existing image at `FIGURE_RESOLUTION_DPI`.  That figure is saved to a
    temporary file, concatenated to the right of the original image
    (overwriting the original file), and the temporary file is deleted.
    Finally, surrounding whitespace is trimmed from the combined image.

    :param figure_file_name: Path to saved image file.  Colour bar will be added
        to this image (the file is overwritten in place).
    :param colour_map_object: Colour scheme (instance of `matplotlib.pyplot.cm`
        or similar).
    :param max_colour_value: Max value in colour scheme.  The colour bar spans
        [0, `max_colour_value`] and is extended at the top end only.
    :param temporary_dir_name: Name of temporary output directory.
    """

    # Only the pixel size is needed.  Close the image right away so that no
    # file handle is held on a file that gets overwritten below.
    with Image.open(figure_file_name) as image_object:
        figure_width_px, figure_height_px = image_object.size

    figure_width_inches = float(figure_width_px) / FIGURE_RESOLUTION_DPI
    figure_height_inches = float(figure_height_px) / FIGURE_RESOLUTION_DPI

    extra_figure_object, extra_axes_object = pyplot.subplots(
        1, 1, figsize=(figure_width_inches, figure_height_inches))
    extra_axes_object.axis('off')

    extra_file_name = '{0:s}/saliency_colour-bar.jpg'.format(
        temporary_dir_name)

    try:
        dummy_values = numpy.array([0., max_colour_value])

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=extra_axes_object,
            data_matrix=dummy_values,
            colour_map_object=colour_map_object,
            min_value=0.,
            max_value=max_colour_value,
            orientation_string='vertical',
            fraction_of_axis_length=1.25,
            extend_min=False,
            extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE,
            aspect_ratio=50.)

        tick_values = colour_bar_object.get_ticks()

        # Small colour ranges need more decimal places so that tick labels
        # remain distinct.
        if max_colour_value <= 0.005:
            tick_format_string = '{0:.4f}'
        elif max_colour_value <= 0.05:
            tick_format_string = '{0:.3f}'
        else:
            tick_format_string = '{0:.2f}'

        tick_strings = [tick_format_string.format(v) for v in tick_values]

        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(tick_strings)

        print('Saving colour bar to: "{0:s}"...'.format(extra_file_name))
        extra_figure_object.savefig(extra_file_name,
                                    dpi=FIGURE_RESOLUTION_DPI,
                                    pad_inches=0,
                                    bbox_inches='tight')
    finally:
        # Always release the temporary figure, even if drawing/saving fails.
        pyplot.close(extra_figure_object)

    print('Concatenating colour bar to: "{0:s}"...'.format(figure_file_name))

    try:
        imagemagick_utils.concatenate_images(
            input_file_names=[figure_file_name, extra_file_name],
            output_file_name=figure_file_name,
            num_panel_rows=1,
            num_panel_columns=2,
            extra_args_string='-gravity Center')
    finally:
        # Do not leave the temporary colour-bar image behind on failure.
        os.remove(extra_file_name)

    imagemagick_utils.trim_whitespace(input_file_name=figure_file_name,
                                      output_file_name=figure_file_name)
Exemple #11
0
def plot_storm_tracks(storm_object_table,
                      axes_object,
                      basemap_object,
                      colour_map_object='random',
                      line_colour=DEFAULT_TRACK_COLOUR,
                      line_width=DEFAULT_TRACK_WIDTH,
                      start_marker_type=DEFAULT_START_MARKER_TYPE,
                      end_marker_type=DEFAULT_END_MARKER_TYPE,
                      start_marker_size=DEFAULT_START_MARKER_SIZE,
                      end_marker_size=DEFAULT_END_MARKER_SIZE):
    """Plots one or more storm tracks on the same map.

    Each track is drawn as line segments from every storm object to each of its
    immediate successors.  Markers are drawn at track starts (objects with no
    predecessor) and track ends (objects with no successor).  Tracks consisting
    of a single storm object always get both markers, even when markers are
    otherwise disabled, so that they remain visible.

    :param storm_object_table: See doc for `plot_storm_outlines`.
    :param axes_object: Same.
    :param basemap_object: Same.
    :param colour_map_object: There are 3 cases.

    If "random", each track will be plotted in a random colour from
    `get_storm_track_colours`.

    If None, each track will be plotted in `line_colour` (the next input arg).

    If real colour map (instance of `matplotlib.pyplot.cm`), track segments will
    be coloured by time, according to this colour map, and a colour bar with
    time labels is added to the axes.

    :param line_colour: [used only if `colour_map_object is None`]
        length-3 numpy array with (R, G, B).  Will be used for all tracks.
    :param line_width: Width of each storm track.
    :param start_marker_type: Marker type for beginning of track (in any format
        accepted by `matplotlib.lines`).  If `start_marker_type is None`,
        markers will not be used to show beginning of each track (except for
        single-object tracks).
    :param end_marker_type: Same but for end of track.
    :param start_marker_size: Size of each start-point marker.
    :param end_marker_size: Size of each end-point marker.
    """

    plot_start_markers = start_marker_type is not None
    plot_end_markers = end_marker_type is not None

    # Single-object tracks always get markers (see below), so a concrete
    # marker type/size is needed even when markers are otherwise disabled.
    if start_marker_type is None:
        start_marker_type = DEFAULT_START_MARKER_TYPE
        start_marker_size = DEFAULT_START_MARKER_SIZE

    if end_marker_type is None:
        end_marker_type = DEFAULT_END_MARKER_TYPE
        end_marker_size = DEFAULT_END_MARKER_SIZE

    x_coords_metres, y_coords_metres = basemap_object(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values,
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values)

    storm_object_table = storm_object_table.assign(
        **{
            tracking_utils.CENTROID_X_COLUMN: x_coords_metres,
            tracking_utils.CENTROID_Y_COLUMN: y_coords_metres
        })

    # Hoist column lookups out of the per-object loops.
    x_coords_metres = storm_object_table[
        tracking_utils.CENTROID_X_COLUMN].values
    y_coords_metres = storm_object_table[
        tracking_utils.CENTROID_Y_COLUMN].values
    valid_times_unix_sec = storm_object_table[
        tracking_utils.VALID_TIME_COLUMN].values

    rgb_matrix = None
    num_colours = None
    colour_norm_object = None
    first_time_unix_sec = None
    last_time_unix_sec = None

    if colour_map_object is None:
        error_checking.assert_is_numpy_array(line_colour,
                                             exact_dimensions=numpy.array(
                                                 [3], dtype=int))

        rgb_matrix = numpy.reshape(line_colour, (1, 3))
        num_colours = rgb_matrix.shape[0]

    elif isinstance(colour_map_object, str) and colour_map_object == 'random':
        rgb_matrix = get_storm_track_colours()
        num_colours = rgb_matrix.shape[0]

        # From here on, "random" behaves like fixed per-track colours.
        colour_map_object = None

    else:
        first_time_unix_sec = numpy.min(valid_times_unix_sec)
        last_time_unix_sec = numpy.max(valid_times_unix_sec)

        colour_norm_object = pyplot.Normalize(first_time_unix_sec,
                                              last_time_unix_sec)

    track_primary_id_strings, object_to_track_indices = numpy.unique(
        storm_object_table[tracking_utils.PRIMARY_ID_COLUMN].values,
        return_inverse=True)

    num_tracks = len(track_primary_id_strings)

    for k in range(num_tracks):
        if colour_map_object is None:
            this_colour = plotting_utils.colour_from_numpy_to_tuple(
                rgb_matrix[numpy.mod(k, num_colours), :])
        else:
            this_colour = None

        these_object_indices = numpy.where(object_to_track_indices == k)[0]
        is_single_object_track = len(these_object_indices) == 1

        for i in these_object_indices:
            these_next_indices = temporal_tracking.find_immediate_successors(
                storm_object_table=storm_object_table, target_row=i)

            for j in these_next_indices:
                these_x_coords_metres = x_coords_metres[[i, j]]
                these_y_coords_metres = y_coords_metres[[i, j]]

                if colour_map_object is None:
                    axes_object.plot(these_x_coords_metres,
                                     these_y_coords_metres,
                                     color=this_colour,
                                     linestyle='solid',
                                     linewidth=line_width)
                    continue

                # One-segment LineCollection, coloured by the mean time of
                # its two endpoints.
                this_point_matrix = numpy.array(
                    [these_x_coords_metres,
                     these_y_coords_metres]).T.reshape(-1, 1, 2)

                this_segment_matrix = numpy.concatenate(
                    [this_point_matrix[:-1], this_point_matrix[1:]],
                    axis=1)

                this_time_unix_sec = numpy.mean(valid_times_unix_sec[[i, j]])

                this_line_collection_object = LineCollection(
                    this_segment_matrix,
                    cmap=colour_map_object,
                    norm=colour_norm_object)

                this_line_collection_object.set_array(
                    numpy.array([this_time_unix_sec]))
                this_line_collection_object.set_linewidth(line_width)
                axes_object.add_collection(this_line_collection_object)

            these_prev_indices = temporal_tracking.find_immediate_predecessors(
                storm_object_table=storm_object_table, target_row=i)

            # (plot this marker?, marker type, marker size), start then end.
            marker_specs = [
                ((plot_start_markers and len(these_prev_indices) == 0)
                 or is_single_object_track,
                 start_marker_type, start_marker_size),
                ((plot_end_markers and len(these_next_indices) == 0)
                 or is_single_object_track,
                 end_marker_type, end_marker_size)
            ]

            for plot_this_marker, this_marker_type, this_marker_size in (
                    marker_specs):
                if not plot_this_marker:
                    continue

                if colour_map_object is not None:
                    this_colour = colour_map_object(
                        colour_norm_object(valid_times_unix_sec[i]))

                # "x" markers are drawn only by their edges, so thicken them.
                this_edge_width = 2 if this_marker_type == 'x' else 1

                axes_object.plot(
                    x_coords_metres[i], y_coords_metres[i],
                    linestyle='None',
                    marker=this_marker_type,
                    markerfacecolor=this_colour,
                    markeredgecolor=this_colour,
                    markersize=this_marker_size,
                    markeredgewidth=this_edge_width)

    if colour_map_object is None:
        return

    latitude_range_deg = basemap_object.urcrnrlat - basemap_object.llcrnrlat
    longitude_range_deg = basemap_object.urcrnrlon - basemap_object.llcrnrlon

    if latitude_range_deg > longitude_range_deg:
        orientation_string = 'vertical'
    else:
        orientation_string = 'horizontal'

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=valid_times_unix_sec,
        colour_map_object=colour_map_object,
        min_value=colour_norm_object.vmin,
        max_value=colour_norm_object.vmax,
        orientation_string=orientation_string,
        extend_min=False,
        extend_max=False,
        fraction_of_axis_length=0.9,
        font_size=COLOUR_BAR_FONT_SIZE)

    if orientation_string == 'horizontal':
        tick_values = colour_bar_object.ax.get_xticks()
    else:
        tick_values = colour_bar_object.ax.get_yticks()

    tick_times_unix_sec = numpy.round(
        colour_norm_object.inverse(tick_values)).astype(int)

    # Stretch tick times linearly so that the first/last ticks are labelled
    # with the first/last storm times.  Skip this when there are fewer than two
    # distinct ticks; otherwise the slope would be inf/NaN and the cast to int
    # would produce garbage times.
    if len(tick_times_unix_sec) > 1:
        tick_time_span_sec = tick_times_unix_sec[-1] - tick_times_unix_sec[0]

        if tick_time_span_sec != 0:
            slope_sec_per_sec = (
                float(last_time_unix_sec - first_time_unix_sec) /
                tick_time_span_sec)

            tick_times_unix_sec = numpy.round(
                first_time_unix_sec + slope_sec_per_sec *
                (tick_times_unix_sec - tick_times_unix_sec[0])).astype(int)

    tick_time_strings = [
        time_conversion.unix_sec_to_string(t, COLOUR_BAR_TIME_FORMAT)
        for t in tick_times_unix_sec
    ]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_time_strings)
Exemple #12
0
def _plot_rapruc_one_example(
        full_storm_id_string, storm_time_unix_sec, top_tracking_dir_name,
        latitude_buffer_deg, longitude_buffer_deg, lead_time_seconds,
        field_name_grib1, output_dir_name, rap_file_name=None,
        ruc_file_name=None):
    """Plots RAP or RUC field for one example.

    The map covers the storm's current centroid and its centroid extrapolated
    `lead_time_seconds` ahead, padded by the lat/long buffers.  It shows the
    requested field, wind barbs (earth-relative), and markers at both
    centroids.  The figure is saved to
    "<output_dir_name>/<storm ID with "_" -> "-">_<time>.jpg".

    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Valid time.
    :param top_tracking_dir_name: See documentation at top of file.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param lead_time_seconds: Same.
    :param field_name_grib1: Same.
    :param output_dir_name: Same.
    :param rap_file_name: Path to file with RAP analysis.
    :param ruc_file_name: [used only if `rap_file_name is None`]
        Path to file with RUC analysis.
    :raises: ValueError: if the storm is not found in the tracking file.
    """

    tracking_file_name = tracking_io.find_file(
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
        source_name=tracking_utils.SEGMOTION_NAME,
        valid_time_unix_sec=storm_time_unix_sec,
        spc_date_string=
        time_conversion.time_to_spc_date_string(storm_time_unix_sec),
        raise_error_if_missing=True
    )

    print('Reading data from: "{0:s}"...'.format(tracking_file_name))
    storm_object_table = tracking_io.read_file(tracking_file_name)
    storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.FULL_ID_COLUMN] ==
        full_storm_id_string
    ]

    # Fail with a clear message instead of an opaque IndexError further down.
    if storm_object_table.empty:
        raise ValueError(
            'Cannot find storm "{0:s}" in tracking file: "{1:s}"'.format(
                full_storm_id_string, tracking_file_name)
        )

    extrap_times_sec = numpy.array([0, lead_time_seconds], dtype=int)
    storm_object_table = soundings._create_target_points_for_interp(
        storm_object_table=storm_object_table,
        lead_times_seconds=extrap_times_sec
    )

    # Row 0 = current position, row 1 = extrapolated position (same order as
    # `extrap_times_sec`).
    latitudes_deg = (
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values
    )
    longitudes_deg = (
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values
    )
    orig_latitude_deg = latitudes_deg[0]
    orig_longitude_deg = longitudes_deg[0]
    extrap_latitude_deg = latitudes_deg[1]
    extrap_longitude_deg = longitudes_deg[1]

    if rap_file_name is None:
        grib_file_name = ruc_file_name
        model_name = nwp_model_utils.RUC_MODEL_NAME
    else:
        grib_file_name = rap_file_name
        model_name = nwp_model_utils.RAP_MODEL_NAME

    pathless_grib_file_name = os.path.split(grib_file_name)[-1]
    grid_name = pathless_grib_file_name.split('_')[1]

    host_name = socket.gethostname()

    if 'casper' in host_name:
        wgrib_exe_name = '/glade/work/ryanlage/wgrib/wgrib'
        wgrib2_exe_name = '/glade/work/ryanlage/wgrib2/wgrib2/wgrib2'
    else:
        wgrib_exe_name = '/condo/swatwork/ralager/wgrib/wgrib'
        wgrib2_exe_name = '/condo/swatwork/ralager/grib2/wgrib2/wgrib2'

    def _read_grib_field(this_field_name_grib1):
        """Reads one field from `grib_file_name` (full grid)."""

        print('Reading field "{0:s}" from: "{1:s}"...'.format(
            this_field_name_grib1, grib_file_name
        ))
        return nwp_model_io.read_field_from_grib_file(
            grib_file_name=grib_file_name,
            field_name_grib1=this_field_name_grib1,
            model_name=model_name, grid_id=grid_name,
            wgrib_exe_name=wgrib_exe_name, wgrib2_exe_name=wgrib2_exe_name
        )

    main_field_matrix = _read_grib_field(field_name_grib1)

    # Winds are read at the same level as the main field, except that 2-m
    # fields are paired with 10-m winds.
    u_wind_name_grib1 = 'UGRD:{0:s}'.format(
        field_name_grib1.split(':')[-1]
    )
    u_wind_name_grib1 = u_wind_name_grib1.replace('2 m', '10 m')
    u_wind_matrix_m_s01 = _read_grib_field(u_wind_name_grib1)

    v_wind_name_grib1 = 'VGRD:{0:s}'.format(
        u_wind_name_grib1.split(':')[-1]
    )
    v_wind_matrix_m_s01 = _read_grib_field(v_wind_name_grib1)

    latitude_matrix_deg, longitude_matrix_deg = (
        nwp_model_utils.get_latlng_grid_point_matrices(
            model_name=model_name, grid_name=grid_name)
    )
    cosine_matrix, sine_matrix = nwp_model_utils.get_wind_rotation_angles(
        latitudes_deg=latitude_matrix_deg, longitudes_deg=longitude_matrix_deg,
        model_name=model_name
    )
    u_wind_matrix_m_s01, v_wind_matrix_m_s01 = (
        nwp_model_utils.rotate_winds_to_earth_relative(
            u_winds_grid_relative_m_s01=u_wind_matrix_m_s01,
            v_winds_grid_relative_m_s01=v_wind_matrix_m_s01,
            rotation_angle_cosines=cosine_matrix,
            rotation_angle_sines=sine_matrix)
    )

    min_plot_latitude_deg = (
        min([orig_latitude_deg, extrap_latitude_deg]) - latitude_buffer_deg
    )
    max_plot_latitude_deg = (
        max([orig_latitude_deg, extrap_latitude_deg]) + latitude_buffer_deg
    )
    min_plot_longitude_deg = (
        min([orig_longitude_deg, extrap_longitude_deg]) - longitude_buffer_deg
    )
    max_plot_longitude_deg = (
        max([orig_longitude_deg, extrap_longitude_deg]) + longitude_buffer_deg
    )

    row_limits, column_limits = nwp_plotting.latlng_limits_to_rowcol_limits(
        min_latitude_deg=min_plot_latitude_deg,
        max_latitude_deg=max_plot_latitude_deg,
        min_longitude_deg=min_plot_longitude_deg,
        max_longitude_deg=max_plot_longitude_deg,
        model_name=model_name, grid_id=grid_name
    )

    # Row/column limits are inclusive, hence the +1.
    subgrid_indices = (
        slice(row_limits[0], row_limits[1] + 1),
        slice(column_limits[0], column_limits[1] + 1)
    )
    main_field_matrix = main_field_matrix[subgrid_indices]
    u_wind_matrix_m_s01 = u_wind_matrix_m_s01[subgrid_indices]
    v_wind_matrix_m_s01 = v_wind_matrix_m_s01[subgrid_indices]

    figure_object, axes_object, basemap_object = nwp_plotting.init_basemap(
        model_name=model_name, grid_id=grid_name,
        first_row_in_full_grid=row_limits[0],
        last_row_in_full_grid=row_limits[1],
        first_column_in_full_grid=column_limits[0],
        last_column_in_full_grid=column_limits[1]
    )

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS
    )
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS
    )

    # Colour limits are symmetric percentiles of the plotted subgrid.
    min_colour_value = numpy.nanpercentile(
        main_field_matrix, 100. - MAX_COLOUR_PERCENTILE
    )
    max_colour_value = numpy.nanpercentile(
        main_field_matrix, MAX_COLOUR_PERCENTILE
    )

    nwp_plotting.plot_subgrid(
        field_matrix=main_field_matrix,
        model_name=model_name, grid_id=grid_name,
        axes_object=axes_object, basemap_object=basemap_object,
        colour_map_object=COLOUR_MAP_OBJECT, min_colour_value=min_colour_value,
        max_colour_value=max_colour_value,
        first_row_in_full_grid=row_limits[0],
        first_column_in_full_grid=column_limits[0]
    )

    nwp_plotting.plot_wind_barbs_on_subgrid(
        u_wind_matrix_m_s01=u_wind_matrix_m_s01,
        v_wind_matrix_m_s01=v_wind_matrix_m_s01,
        model_name=model_name, grid_id=grid_name,
        axes_object=axes_object, basemap_object=basemap_object,
        first_row_in_full_grid=row_limits[0],
        first_column_in_full_grid=column_limits[0],
        plot_every_k_rows=PLOT_EVERY_KTH_WIND_BARB,
        plot_every_k_columns=PLOT_EVERY_KTH_WIND_BARB,
        barb_length=WIND_BARB_LENGTH, empty_barb_radius=EMPTY_WIND_BARB_RADIUS,
        fill_empty_barb=True, colour_map=WIND_COLOUR_MAP_OBJECT,
        colour_minimum_kt=MIN_WIND_SPEED_KT, colour_maximum_kt=MAX_WIND_SPEED_KT
    )

    orig_x_metres, orig_y_metres = basemap_object(
        orig_longitude_deg, orig_latitude_deg
    )
    axes_object.plot(
        orig_x_metres, orig_y_metres, linestyle='None',
        marker=ORIGIN_MARKER_TYPE, markersize=ORIGIN_MARKER_SIZE,
        markeredgewidth=ORIGIN_MARKER_EDGE_WIDTH,
        markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR
    )

    extrap_x_metres, extrap_y_metres = basemap_object(
        extrap_longitude_deg, extrap_latitude_deg
    )
    axes_object.plot(
        extrap_x_metres, extrap_y_metres, linestyle='None',
        marker=EXTRAP_MARKER_TYPE, markersize=EXTRAP_MARKER_SIZE,
        markeredgewidth=EXTRAP_MARKER_EDGE_WIDTH,
        markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR
    )

    plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=main_field_matrix,
        colour_map_object=COLOUR_MAP_OBJECT,
        min_value=min_colour_value, max_value=max_colour_value,
        orientation_string='vertical'
    )

    output_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
        output_dir_name, full_storm_id_string.replace('_', '-'),
        time_conversion.unix_sec_to_string(
            storm_time_unix_sec, FILE_NAME_TIME_FORMAT
        )
    )

    # Save/close the figure created above explicitly, rather than whatever
    # figure happens to be "current" in pyplot's global state.
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
Exemple #13
0
def _plot_one_example(
        input_feature_matrix, feature_matrix_after_conv,
        feature_matrix_after_activn, feature_matrix_after_bn,
        feature_matrix_after_pooling, output_file_name):
    """Plots entire figure for one example (storm object).

    The figure has one row per output channel.  Column 0 shows only the first
    input channel (top row) plus a horizontal colour bar; the other panels in
    that column are hidden.  Columns 1-4 show every output channel after
    convolution, activation, batch normalization and pooling, respectively.
    Convolution and activation share one colour scale; batch norm and pooling
    each get their own.  Panels are labelled "(a)", "(b)", ... column by column.

    :param input_feature_matrix: 2-D numpy array with input features.
    :param feature_matrix_after_conv: 2-D numpy array with features after
        convolution.
    :param feature_matrix_after_activn: 2-D numpy array with features after
        activation.
    :param feature_matrix_after_bn: 2-D numpy array with features after batch
        normalization.
    :param feature_matrix_after_pooling: 2-D numpy array with features after
        pooling.
    :param output_file_name: Path to output file.  Figure will be saved here.
    """

    channel_count = feature_matrix_after_conv.shape[-1]
    bottom_row = channel_count - 1

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=channel_count, num_columns=NUM_PANEL_COLUMNS,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True)

    def _add_letter(panel_axes, letter):
        """Writes "(letter)" above the given panel."""
        plotting_utils.label_axes(
            axes_object=panel_axes,
            label_string='({0:s})'.format(letter),
            font_size=TITLE_FONT_SIZE,
            x_coord_normalized=0.125, y_coord_normalized=1.025
        )

    # ---- Column 0: first input channel only. ----
    input_colour_max = numpy.percentile(
        numpy.absolute(input_feature_matrix), MAX_COLOUR_PERCENTILE
    )

    axes_object_matrix[0, 0].set_title('Input', fontsize=TITLE_FONT_SIZE)
    _plot_one_feature_map(
        feature_matrix_2d=input_feature_matrix[..., 0],
        max_colour_value=input_colour_max, plot_colour_bar=False,
        axes_object=axes_object_matrix[0, 0]
    )

    for row in range(1, channel_count):
        axes_object_matrix[row, 0].axis('off')

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix[bottom_row, 0],
        data_matrix=input_feature_matrix[..., 0],
        colour_map_object=COLOUR_MAP_OBJECT, min_value=-1 * input_colour_max,
        max_value=input_colour_max, orientation_string='horizontal',
        padding=0.015, fraction_of_axis_length=0.9,
        extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE)

    bar_ticks = colour_bar_object.ax.get_xticks()
    colour_bar_object.set_ticks(bar_ticks)
    colour_bar_object.set_ticklabels(
        ['{0:.1f}'.format(x) for x in bar_ticks]
    )

    letter_ordinal = ord('a')
    _add_letter(axes_object_matrix[0, 0], chr(letter_ordinal))

    # ---- Columns 1-4: one panel per output channel. ----
    conv_activn_colour_max = numpy.percentile(
        numpy.absolute(numpy.stack(
            (feature_matrix_after_conv, feature_matrix_after_activn), axis=0
        )),
        MAX_COLOUR_PERCENTILE
    )
    bn_colour_max = numpy.percentile(
        numpy.absolute(feature_matrix_after_bn), MAX_COLOUR_PERCENTILE
    )
    pooling_colour_max = numpy.percentile(
        numpy.absolute(feature_matrix_after_pooling), MAX_COLOUR_PERCENTILE
    )

    # (column, title, feature matrix, colour max)
    stage_columns = (
        (1, '  After convolution', feature_matrix_after_conv,
         conv_activn_colour_max),
        (2, ' After activation', feature_matrix_after_activn,
         conv_activn_colour_max),
        (3, '  After batch norm', feature_matrix_after_bn, bn_colour_max),
        (4, 'After pooling', feature_matrix_after_pooling, pooling_colour_max),
    )

    for column, title_string, stage_matrix, colour_max in stage_columns:
        axes_object_matrix[0, column].set_title(
            title_string, fontsize=TITLE_FONT_SIZE)

        for row in range(channel_count):
            panel_axes = axes_object_matrix[row, column]

            _plot_one_feature_map(
                feature_matrix_2d=stage_matrix[..., row],
                max_colour_value=colour_max,
                plot_colour_bar=row == bottom_row,
                axes_object=panel_axes
            )

            letter_ordinal += 1
            _add_letter(panel_axes, chr(letter_ordinal))

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
Exemple #14
0
def _plot_gradcam_one_example(
        gradcam_dict, example_index, model_metadata_dict, colour_map_object,
        max_colour_percentile, output_dir_name):
    """Plots class-activation maps for one example, one figure per target.

    Each figure shows the 2-D class-activation map for one vector target.
    Predictor height is on the x-axis and target height on the y-axis.  The
    figure also has a dashed corner-to-corner reference line and a horizontal
    colour bar.  Figures are saved as .jpg files in `output_dir_name`, named
    after the example ID and the target variable.

    :param gradcam_dict: Dictionary read by `gradcam.read_all_targets_file`.
    :param example_index: Array index of the example whose class-activation
        maps will be plotted.
    :param model_metadata_dict: Dictionary read by `neural_net.read_metafile`.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    option_dict = model_metadata_dict[neural_net.TRAINING_OPTIONS_KEY]
    target_names = option_dict[neural_net.VECTOR_TARGET_NAMES_KEY]
    verbose_target_names = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    all_height_labels = profile_plotting.create_height_labels(
        tick_values_km_agl=METRES_TO_KM * option_dict[neural_net.HEIGHTS_KEY],
        use_log_scale=False
    )

    # Keep only every fourth label, so that tick labels do not overlap.
    height_labels = [
        this_label if j % 4 == 0 else ' '
        for j, this_label in enumerate(all_height_labels)
    ]

    example_id_string = gradcam_dict[gradcam.EXAMPLE_IDS_KEY][example_index]
    activation_matrix_3d = (
        gradcam_dict[gradcam.CLASS_ACTIVATIONS_KEY][example_index, ...]
    )

    # The same height ticks are used on both axes of every figure.
    num_heights = len(height_labels)
    height_tick_positions = numpy.linspace(
        0, num_heights - 1, num=num_heights, dtype=float
    )

    for k, (target_name, verbose_name) in enumerate(
            zip(target_names, verbose_target_names)):
        activation_matrix_2d = activation_matrix_3d[..., k]

        # Floor the colour-scale max above vmin (0), so that the colour
        # scale never collapses to a single value.
        max_colour_value = numpy.maximum(
            numpy.percentile(activation_matrix_2d, max_colour_percentile),
            0.001
        )

        figure_object, axes_object = pyplot.subplots(
            1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
        )
        axes_object.imshow(
            numpy.transpose(activation_matrix_2d), cmap=colour_map_object,
            vmin=0., vmax=max_colour_value, origin='lower'
        )

        axes_object.set_xticks(height_tick_positions)
        axes_object.set_yticks(height_tick_positions)
        axes_object.set_xticklabels(
            height_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.
        )
        axes_object.set_yticklabels(
            height_labels, fontsize=TICK_LABEL_FONT_SIZE
        )
        axes_object.set_xlabel('Predictor height (km AGL)')
        axes_object.set_ylabel('Target height (km AGL)')

        # Dashed line from corner to corner of the current axis limits.
        axes_object.plot(
            axes_object.get_xlim(), axes_object.get_ylim(),
            color=REFERENCE_LINE_COLOUR, linestyle='dashed',
            linewidth=REFERENCE_LINE_WIDTH
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=activation_matrix_2d,
            colour_map_object=colour_map_object,
            min_value=0., max_value=max_colour_value,
            orientation_string='horizontal', padding=0.1,
            extend_min=False, extend_max=True,
            fraction_of_axis_length=0.8, font_size=DEFAULT_FONT_SIZE
        )

        colour_bar_ticks = colour_bar_object.get_ticks()
        colour_bar_object.set_ticks(colour_bar_ticks)
        colour_bar_object.set_ticklabels(
            ['{0:.1f}'.format(v) for v in colour_bar_ticks]
        )

        axes_object.set_title(
            'Class-activation map for {0:s}'.format(verbose_name),
            fontsize=DEFAULT_FONT_SIZE
        )

        output_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, example_id_string.replace('_', '-'),
            target_name.replace('_', '-')
        )
        print('Saving figure to: "{0:s}"...'.format(output_file_name))

        figure_object.savefig(
            output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_object)
def _plot_2d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 2-D radar data.

    If `cam_matrix` is given, the unguided class-activation map is contoured
    in log10 space and repeated in every channel panel.  Otherwise the
    M-by-N-by-F guided class-activation map is contoured.  In either case a
    horizontal colour bar is added and the figure is saved and closed.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cam`.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    # For a hybrid 2-D/3-D model, the 2-D data live in the second figure and
    # the output file is named after the field "shear".
    if conv_2d3d:
        figure_index, radar_field_name = 1, 'shear'
    else:
        figure_index, radar_field_name = 0, None

    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    if layer_operation_dicts is not None:
        num_channels = len(layer_operation_dicts)
    else:
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        num_channels = len(training_option_dict[trainval_io.RADAR_FIELDS_KEY])

    # Unguided class activations are contoured in log10 space.
    min_unguided_value_log10 = numpy.log10(min_unguided_value)
    max_unguided_value_log10 = numpy.log10(max_unguided_value)
    contour_interval_log10 = (
        (max_unguided_value_log10 - min_unguided_value_log10) /
        (num_unguided_contours - 1)
    )

    axes_object_matrix = axes_object_matrices[figure_index]

    # The colour bar is half as long for a hybrid 2-D/3-D model.
    colour_bar_fraction = colour_bar_length / (1 + int(conv_2d3d))

    if cam_matrix is None:
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(guided_cam_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_guided_value,
            contour_interval=max_guided_value / half_num_guided_contours,
            row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=guided_cam_matrix,
            colour_map_object=colour_map_object, min_value=0.,
            max_value=max_guided_value, orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_fraction,
            extend_min=False, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )
        colour_bar_label = 'Absolute guided class activation'
    else:
        # The single unguided map is repeated so that every channel panel
        # shows it.
        cam_matrix_log10 = numpy.repeat(
            numpy.log10(numpy.expand_dims(cam_matrix, axis=-1)),
            repeats=num_channels, axis=-1
        )

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(cam_matrix_log10, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=min_unguided_value_log10,
            max_contour_level=max_unguided_value_log10,
            contour_interval=contour_interval_log10, row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=cam_matrix_log10,
            colour_map_object=colour_map_object,
            min_value=min_unguided_value_log10,
            max_value=max_unguided_value_log10,
            orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_fraction,
            extend_min=True, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )

        # Tick positions are in log10 space; label them in linear space.
        colour_bar_ticks = colour_bar_object.get_ticks()
        colour_bar_object.set_ticks(colour_bar_ticks)
        colour_bar_object.set_ticklabels([
            '{0:.2f}'.format(10 ** v)[:4] for v in colour_bar_ticks
        ])
        colour_bar_label = 'Class activation'

    if label_colour_bars:
        colour_bar_object.set_label(
            colour_bar_label, fontsize=COLOUR_BAR_FONT_SIZE
        )

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name
    )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object = figure_objects[figure_index]
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_3d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 3-D radar data.

    One figure is saved (and closed) per radar field.  If `cam_matrix` is
    given, the unguided class-activation map is contoured in log10 space, and
    the same map is used for every field.  Otherwise the guided
    class-activation map for each field is contoured.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See documentation at top of file.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_3d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-H-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        # Hybrid 2-D/3-D model: only the first figure, for the field
        # "reflectivity", is plotted here.
        loop_max = 1
        radar_field_names = ['reflectivity']
    else:
        loop_max = len(figure_objects)
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    min_unguided_value_log10 = numpy.log10(min_unguided_value)
    max_unguided_value_log10 = numpy.log10(max_unguided_value)
    contour_interval_log10 = (
        (max_unguided_value_log10 - min_unguided_value_log10) /
        (num_unguided_contours - 1)
    )

    # The unguided map does not depend on the radar field, so take its log10
    # once here instead of recomputing it on every pass through the loop.
    if cam_matrix is None:
        cam_matrix_log10 = None
    else:
        cam_matrix_log10 = numpy.log10(cam_matrix)

    for j in range(loop_max):
        if cam_matrix is None:
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(
                    guided_cam_matrix[..., j], axis=0
                ),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=max_guided_value,
                contour_interval=max_guided_value / half_num_guided_contours
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=guided_cam_matrix[..., j],
                colour_map_object=colour_map_object, min_value=0.,
                max_value=max_guided_value, orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=False, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Absolute guided class activation',
                    fontsize=COLOUR_BAR_FONT_SIZE
                )
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=numpy.flip(cam_matrix_log10, axis=0),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                min_contour_level=min_unguided_value_log10,
                max_contour_level=max_unguided_value_log10,
                contour_interval=contour_interval_log10
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=cam_matrix_log10,
                colour_map_object=colour_map_object,
                min_value=min_unguided_value_log10,
                max_value=max_unguided_value_log10,
                orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=True, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            # Tick positions are in log10 space; label them in linear space.
            these_tick_values = this_colour_bar_object.get_ticks()
            these_tick_strings = [
                '{0:.2f}'.format(10 ** v)[:4] for v in these_tick_values
            ]
            this_colour_bar_object.set_ticks(these_tick_values)
            this_colour_bar_object.set_ticklabels(these_tick_strings)

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Class activation', fontsize=COLOUR_BAR_FONT_SIZE
                )

        this_file_name = plot_examples.metadata_to_file_name(
            output_dir_name=output_dir_name, is_sounding=False,
            pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j]
        )

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        figure_objects[j].savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_objects[j])
Exemple #17
0
def plot_saliency_for_sounding(saliency_matrix,
                               sounding_field_names,
                               pressure_levels_mb,
                               colour_map_object,
                               max_absolute_colour_value,
                               min_font_size=DEFAULT_MIN_SOUNDING_FONT_SIZE,
                               max_font_size=DEFAULT_MAX_SOUNDING_FONT_SIZE):
    """Plots saliency for one sounding.

    Each non-wind field is drawn as a column of "+" (non-negative saliency)
    or "_" (negative saliency) symbols.  The colour and font size of each
    symbol come from `_saliency_to_colour_and_size`.  If both u- and v-wind
    are present, they are replaced by one wind column.  In that column,
    saliency is drawn as a wind barb coloured by the magnitude of the
    (u, v) saliency vector.  The y-axis is pressure on a log scale.

    P = number of pressure levels
    F = number of fields

    :param saliency_matrix: P-by-F numpy array of saliency values.
    :param sounding_field_names: length-F list of field names.
    :param pressure_levels_mb: length-P numpy array of pressure levels
        (millibars).
    :param colour_map_object: See doc for `plot_2d_grid`.
    :param max_absolute_colour_value: Same.
    :param min_font_size: Same.
    :param max_font_size: Same.
    """

    error_checking.assert_is_geq(max_absolute_colour_value, 0.)
    # Floor the colour-scale max, so that the colour scale is never degenerate.
    max_absolute_colour_value = max([max_absolute_colour_value, 0.001])

    error_checking.assert_is_greater_numpy_array(pressure_levels_mb, 0.)
    error_checking.assert_is_numpy_array(pressure_levels_mb, num_dimensions=1)

    error_checking.assert_is_list(sounding_field_names)
    error_checking.assert_is_numpy_array(numpy.array(sounding_field_names),
                                         num_dimensions=1)

    num_pressure_levels = len(pressure_levels_mb)
    num_sounding_fields = len(sounding_field_names)

    error_checking.assert_is_numpy_array_without_nan(saliency_matrix)
    error_checking.assert_is_numpy_array(saliency_matrix,
                                         exact_dimensions=numpy.array([
                                             num_pressure_levels,
                                             num_sounding_fields
                                         ]))

    # Wind barbs are plotted only if both wind components are present.
    try:
        u_wind_index = sounding_field_names.index(soundings.U_WIND_NAME)
        v_wind_index = sounding_field_names.index(soundings.V_WIND_NAME)
        plot_wind_barbs = True
    except ValueError:
        plot_wind_barbs = False

    if plot_wind_barbs:
        u_wind_saliency_values = saliency_matrix[:, u_wind_index]
        v_wind_saliency_values = saliency_matrix[:, v_wind_index]
        wind_saliency_magnitudes = numpy.sqrt(u_wind_saliency_values**2 +
                                              v_wind_saliency_values**2)

        colour_norm_object = pyplot.Normalize(vmin=0.,
                                              vmax=max_absolute_colour_value)

        # Drop the alpha channel, keeping RGB only.
        rgb_matrix_for_wind = colour_map_object(
            colour_norm_object(wind_saliency_magnitudes))[..., :-1]

        # Replace the wind-component columns with a single wind column at the
        # end.  A new list is built, so the caller's list is not mutated.
        non_wind_flags = numpy.array(
            [f not in WIND_COMPONENT_NAMES for f in sounding_field_names],
            dtype=bool)

        non_wind_indices = numpy.where(non_wind_flags)[0]
        saliency_matrix = saliency_matrix[:, non_wind_indices]
        sounding_field_names = [
            sounding_field_names[k] for k in non_wind_indices
        ]

        sounding_field_names.append(WIND_NAME)
        num_sounding_fields = len(sounding_field_names)

    rgb_matrix, font_size_matrix = _saliency_to_colour_and_size(
        saliency_matrix=saliency_matrix,
        colour_map_object=colour_map_object,
        max_absolute_colour_value=max_absolute_colour_value,
        min_font_size=min_font_size,
        max_font_size=max_font_size)

    _, axes_object = pyplot.subplots(1,
                                     1,
                                     figsize=(FIGURE_WIDTH_INCHES,
                                              FIGURE_HEIGHT_INCHES))

    axes_object.set_facecolor(
        plotting_utils.colour_from_numpy_to_tuple(
            SOUNDING_SALIENCY_BACKGROUND_COLOUR))

    for k in range(num_sounding_fields):
        if sounding_field_names[k] == WIND_NAME:
            for j in range(num_pressure_levels):
                this_vector = numpy.array(
                    [u_wind_saliency_values[j], v_wind_saliency_values[j]])

                # The barb's length is fixed and its colour encodes the
                # magnitude.  A zero vector has no direction: leave it as
                # (0, 0), which is drawn as an empty barb, instead of
                # dividing by zero and producing NaN.
                this_norm = numpy.linalg.norm(this_vector, ord=2)
                if this_norm > 0:
                    this_vector = (
                        WIND_SALIENCY_MULTIPLIER * this_vector / this_norm)

                this_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
                    rgb_matrix_for_wind[j, ...])

                axes_object.barbs(k,
                                  pressure_levels_mb[j],
                                  this_vector[0],
                                  this_vector[1],
                                  length=WIND_BARB_LENGTH,
                                  fill_empty=True,
                                  rounding=False,
                                  sizes={'emptybarb': EMPTY_WIND_BARB_RADIUS},
                                  color=this_colour_tuple)

            continue

        for j in range(num_pressure_levels):
            this_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
                rgb_matrix[j, k, ...])

            if saliency_matrix[j, k] >= 0:
                axes_object.text(k,
                                 pressure_levels_mb[j],
                                 '+',
                                 fontsize=font_size_matrix[j, k],
                                 color=this_colour_tuple,
                                 horizontalalignment='center',
                                 verticalalignment='center')
            else:
                axes_object.text(k,
                                 pressure_levels_mb[j],
                                 '_',
                                 fontsize=font_size_matrix[j, k],
                                 color=this_colour_tuple,
                                 horizontalalignment='center',
                                 verticalalignment='bottom')

    axes_object.set_xlim(-0.5, num_sounding_fields - 0.5)
    axes_object.set_ylim(100, 1000)
    axes_object.invert_yaxis()
    pyplot.yscale('log')
    pyplot.minorticks_off()

    y_tick_locations = numpy.linspace(100, 1000, num=10, dtype=int)
    y_tick_labels = ['{0:d}'.format(p) for p in y_tick_locations]
    pyplot.yticks(y_tick_locations, y_tick_labels)

    x_tick_locations = numpy.linspace(0,
                                      num_sounding_fields - 1,
                                      num=num_sounding_fields,
                                      dtype=float)
    x_tick_labels = [FIELD_NAME_TO_LATEX_DICT[f] for f in sounding_field_names]
    pyplot.xticks(x_tick_locations, x_tick_labels)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=saliency_matrix,
        colour_map_object=colour_map_object,
        min_value=0.,
        max_value=max_absolute_colour_value,
        orientation_string='vertical',
        extend_min=True,
        extend_max=True)

    colour_bar_object.set_label('Saliency (absolute value)')
def _plot_data(num_days_matrix, grid_metadata_dict, colour_map_object):
    """Maps the number of convective days for each grid cell.

    Borders, parallels and meridians are drawn first.  The day counts are
    then drawn as a mesh; cells with zero days are masked, so they are not
    coloured.  A horizontal colour bar with integer tick labels is added.

    M = number of rows in grid
    N = number of columns in grid

    :param num_days_matrix: M-by-N numpy array with number of convective days
        for which grid cell is in domain.
    :param grid_metadata_dict: Dictionary created by
        `grids.create_equidistant_grid`.
    :param colour_map_object: See documentation at top of file.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES))

    basemap_object, x_matrix_metres, y_matrix_metres = _get_basemap(
        grid_metadata_dict)

    num_rows = num_days_matrix.shape[0]
    num_columns = num_days_matrix.shape[1]

    # Grid spacing is taken from the corner points of the projected grid.
    x_min_metres = x_matrix_metres[0, 0]
    y_min_metres = y_matrix_metres[0, 0]
    x_spacing_metres = (
        (x_matrix_metres[0, -1] - x_min_metres) / (num_columns - 1))
    y_spacing_metres = (
        (y_matrix_metres[-1, 0] - y_min_metres) / (num_rows - 1))

    matrix_to_plot, edge_x_coords_metres, edge_y_coords_metres = (
        grids.xy_field_grid_points_to_edges(
            field_matrix=num_days_matrix,
            x_min_metres=x_min_metres,
            y_min_metres=y_min_metres,
            x_spacing_metres=x_spacing_metres,
            y_spacing_metres=y_spacing_metres))

    # Zero-day cells are masked, so they are not coloured.
    matrix_to_plot = numpy.ma.masked_where(matrix_to_plot == 0, matrix_to_plot)

    border_plotting_functions = (
        plotting_utils.plot_coastlines,
        plotting_utils.plot_countries,
        plotting_utils.plot_states_and_provinces
    )
    for plot_border_function in border_plotting_functions:
        plot_border_function(basemap_object=basemap_object,
                             axes_object=axes_object,
                             line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    max_num_days = numpy.max(num_days_matrix)

    basemap_object.pcolormesh(edge_x_coords_metres,
                              edge_y_coords_metres,
                              matrix_to_plot,
                              cmap=colour_map_object,
                              vmin=1,
                              vmax=max_num_days,
                              shading='flat',
                              edgecolors='None',
                              axes=axes_object,
                              zorder=-1e12)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=num_days_matrix,
        colour_map_object=colour_map_object,
        min_value=1,
        max_value=max_num_days,
        orientation_string='horizontal',
        extend_min=False,
        extend_max=False,
        padding=0.05)

    # Day counts are integers, so label the colour bar with whole numbers.
    colour_bar_ticks = colour_bar_object.get_ticks()
    colour_bar_object.set_ticks(colour_bar_ticks)
    colour_bar_object.set_ticklabels(
        ['{0:d}'.format(int(numpy.round(v))) for v in colour_bar_ticks])

    axes_object.set_title('Number of convective days by grid cell')

    return figure_object, axes_object
Exemple #19
0
def _plot_feature_maps_one_layer(feature_matrix, full_id_strings,
                                 storm_times_unix_sec, layer_name,
                                 output_dir_name):
    """Plots all feature maps for one layer.

    One figure is saved per storm object for a 2-D layer, or per storm object
    and height otherwise.  Each figure shows every channel as its own panel.
    All figures use the seismic colour map with limits symmetric about zero:
    +/- the 99th percentile of absolute values over the whole layer.

    E = number of examples (storm objects)
    M = number of spatial rows
    N = number of spatial columns
    H = number of spatial depths (heights)
    C = number of channels

    :param feature_matrix: numpy array (E x M x N x C or E x M x N x H x C) of
        feature maps.
    :param full_id_strings: length-E list of full storm IDs.
    :param storm_times_unix_sec: length-E numpy array of storm times.
    :param layer_name: Name of layer.
    :param output_dir_name: Name of output directory for this layer.
    """

    num_spatial_dimensions = len(feature_matrix.shape) - 2
    num_storm_objects = feature_matrix.shape[0]
    num_channels = feature_matrix.shape[-1]
    num_heights = (
        feature_matrix.shape[-2] if num_spatial_dimensions == 3 else None
    )

    num_panel_rows = int(numpy.round(numpy.sqrt(num_channels)))

    # With very many panels, annotations are blanked out entirely.
    hide_annotations = num_channels >= NUM_PANELS_FOR_NO_FONT
    annotation_string_by_channel = (
        [''] if hide_annotations else [None]
    ) * num_channels

    if hide_annotations or num_channels >= NUM_PANELS_FOR_TINY_FONT:
        font_size = TINY_FONT_SIZE
    elif num_channels >= NUM_PANELS_FOR_SMALL_FONT:
        font_size = SMALL_FONT_SIZE
    else:
        font_size = MAIN_FONT_SIZE

    max_colour_value = numpy.percentile(numpy.absolute(feature_matrix), 99)
    min_colour_value = -max_colour_value

    def _save_panels(matrix_3d, title_string, figure_file_name):
        # Plots all channels of one 2-D slice (channels last), adds a colour
        # bar and title, then saves and closes the current figure.
        _, axes_object_matrix = feature_map_plotting.plot_many_2d_feature_maps(
            feature_matrix=numpy.flip(matrix_3d, axis=0),
            annotation_string_by_panel=annotation_string_by_channel,
            num_panel_rows=num_panel_rows,
            colour_map_object=pyplot.cm.seismic,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            font_size=font_size)

        plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=matrix_3d,
            colour_map_object=pyplot.cm.seismic,
            min_value=min_colour_value,
            max_value=max_colour_value,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

        pyplot.suptitle(title_string, fontsize=MAIN_FONT_SIZE)

        print('Saving figure to: "{0:s}"...'.format(figure_file_name))
        pyplot.savefig(figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

    for i in range(num_storm_objects):
        time_string = time_conversion.unix_sec_to_string(
            storm_times_unix_sec[i], TIME_FORMAT)
        storm_id_for_file = full_id_strings[i].replace('_', '-')

        if num_spatial_dimensions == 2:
            _save_panels(
                matrix_3d=feature_matrix[i, ...],
                title_string='Layer "{0:s}", storm "{1:s}" at {2:s}'.format(
                    layer_name, full_id_strings[i], time_string),
                figure_file_name='{0:s}/storm={1:s}_{2:s}_features.jpg'.format(
                    output_dir_name, storm_id_for_file, time_string))
            continue

        for k in range(num_heights):
            _save_panels(
                matrix_3d=feature_matrix[i, :, :, k, :],
                title_string=(
                    'Layer "{0:s}", height {1:d} of {2:d}, storm "{3:s}" at '
                    '{4:s}'
                ).format(layer_name, k + 1, num_heights, full_id_strings[i],
                         time_string),
                figure_file_name=(
                    '{0:s}/storm={1:s}_{2:s}_features_height{3:02d}.jpg'
                ).format(output_dir_name, storm_id_for_file, time_string,
                         k + 1))
Exemple #20
0
def _plot_score_one_field(latitude_matrix_deg,
                          longitude_matrix_deg,
                          score_matrix,
                          colour_map_object,
                          min_colour_value,
                          max_colour_value,
                          taper_cbar_top,
                          taper_cbar_bottom,
                          log_scale=False):
    """Plots one score for one field.

    M = number of rows in grid
    N = number of columns in grid

    :param latitude_matrix_deg: M-by-N numpy array of latitudes (deg N).  The
        map extent is taken from element [0, 0] (min) and [-1, -1] (max), so
        latitudes presumably increase along both axes -- confirm with callers.
    :param longitude_matrix_deg: M-by-N numpy array of longitudes (deg E), with
        the same ordering assumption as `latitude_matrix_deg`.
    :param score_matrix: M-by-N numpy array of score values.  NaN values are
        masked out (left blank) on the map.
    :param colour_map_object: Colour scheme (instance of `matplotlib.pyplot.cm`).
    :param min_colour_value: Minimum value in colour bar.
    :param max_colour_value: Max value in colour bar.
    :param taper_cbar_top: Boolean flag.  If True, will taper top of colour bar,
        implying that higher values are possible.
    :param taper_cbar_bottom: Boolean flag.  If True, will taper bottom of
        colour bar, implying that lower values are possible.
    :param log_scale: Boolean flag.  If True, colour-bar tick labels are shown
        as 10**tick (rounded to integers).  This does not transform the data,
        so `score_matrix` and the colour limits are presumably already in
        log10 units when this is True.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    (figure_object, axes_object,
     basemap_object) = plotting_utils.create_equidist_cylindrical_map(
         min_latitude_deg=latitude_matrix_deg[0, 0],
         max_latitude_deg=latitude_matrix_deg[-1, -1],
         min_longitude_deg=longitude_matrix_deg[0, 0],
         max_longitude_deg=longitude_matrix_deg[-1, -1],
         resolution_string=RESOLUTION_STRING)

    # Grid spacing is inferred from the first two rows/columns, which assumes
    # a regular lat-long grid.
    latitude_spacing_deg = (latitude_matrix_deg[1, 0] -
                            latitude_matrix_deg[0, 0])
    longitude_spacing_deg = (longitude_matrix_deg[0, 1] -
                             longitude_matrix_deg[0, 0])

    # pcolormesh with shading='flat' needs cell edges, not cell centres.
    (score_matrix_at_edges, grid_edge_latitudes_deg,
     grid_edge_longitudes_deg) = grids.latlng_field_grid_points_to_edges(
         field_matrix=score_matrix,
         min_latitude_deg=latitude_matrix_deg[0, 0],
         min_longitude_deg=longitude_matrix_deg[0, 0],
         lat_spacing_deg=latitude_spacing_deg,
         lng_spacing_deg=longitude_spacing_deg)

    # Mask NaNs so that missing values are left blank rather than coloured.
    score_matrix_at_edges = numpy.ma.masked_where(
        numpy.isnan(score_matrix_at_edges), score_matrix_at_edges)

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR,
                                   line_width=BORDER_WIDTH)
    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR,
                                  line_width=BORDER_WIDTH)
    # plotting_utils.plot_states_and_provinces(
    #     basemap_object=basemap_object, axes_object=axes_object,
    #     line_colour=BORDER_COLOUR
    # )
    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS,
                                  line_width=0)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS,
                                  line_width=0)

    # Very low zorder keeps the score field underneath borders and grid lines.
    pyplot.pcolormesh(grid_edge_longitudes_deg,
                      grid_edge_latitudes_deg,
                      score_matrix_at_edges,
                      cmap=colour_map_object,
                      vmin=min_colour_value,
                      vmax=max_colour_value,
                      shading='flat',
                      edgecolors='None',
                      axes=axes_object,
                      zorder=-1e12)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=score_matrix,
        colour_map_object=colour_map_object,
        min_value=min_colour_value,
        max_value=max_colour_value,
        orientation_string='horizontal',
        extend_min=taper_cbar_bottom,
        extend_max=taper_cbar_top,
        padding=0.05,
        font_size=COLOUR_BAR_FONT_SIZE)

    tick_values = colour_bar_object.get_ticks()

    if log_scale:
        # Ticks are in log10 units; label them with the original values.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(10**v))) for v in tick_values
        ]
    elif numpy.nanmax(numpy.absolute(score_matrix)) >= 6:
        # Large magnitudes: decimals would only add clutter.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(v))) for v in tick_values
        ]
    else:
        tick_strings = ['{0:.2f}'.format(v) for v in tick_values]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_strings)

    return figure_object, axes_object