示例#1
0
def plot_roc_curve(axes_object,
                   pod_by_threshold,
                   pofd_by_threshold,
                   line_colour=ROC_CURVE_COLOUR,
                   plot_background=True):
    """Plots ROC (receiver operating characteristic) curve.

    T = number of binarization thresholds

    For the definition of a "binarization threshold" and the role they play in
    ROC curves, see `model_evaluation.get_points_in_roc_curve`.

    :param axes_object: Instance of `matplotlib.axes._subplots.AxesSubplot`.
        Everything (background and curve) is drawn on these axes.
    :param pod_by_threshold: length-T numpy array of POD (probability of
        detection) values.  NaN is allowed.
    :param pofd_by_threshold: length-T numpy array of POFD (probability of false
        detection) values.  NaN is allowed.
    :param line_colour: Line colour (numpy array, converted with
        `plotting_utils.colour_from_numpy_to_tuple`).
    :param plot_background: Boolean flag.  If True, will plot background
        (Peirce-score contours with colour bar, plus the dashed ROC curve of a
        random predictor).
    :return: line_handle: Line handle for ROC curve, or None if every
        (POFD, POD) pair contains a NaN (nothing was plotted).
    """

    error_checking.assert_is_numpy_array(pod_by_threshold, num_dimensions=1)
    error_checking.assert_is_geq_numpy_array(pod_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pod_by_threshold,
                                             1.,
                                             allow_nan=True)

    num_thresholds = len(pod_by_threshold)
    expected_dim = numpy.array([num_thresholds], dtype=int)

    error_checking.assert_is_numpy_array(pofd_by_threshold,
                                         exact_dimensions=expected_dim)
    error_checking.assert_is_geq_numpy_array(pofd_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pofd_by_threshold,
                                             1.,
                                             allow_nan=True)

    error_checking.assert_is_boolean(plot_background)

    if plot_background:
        pofd_matrix, pod_matrix = model_eval.get_pofd_pod_grid()
        peirce_score_matrix = pod_matrix - pofd_matrix

        this_colour_map_object, this_colour_norm_object = (
            _get_peirce_colour_scheme())

        # Draw on the caller's axes explicitly.  `pyplot.contourf` always draws
        # on the current axes (`pyplot.gca()`); passing `axes=` to it does not
        # redirect the drawing, so the background could land on another
        # subplot.
        axes_object.contourf(pofd_matrix,
                             pod_matrix,
                             peirce_score_matrix,
                             CSI_LEVELS,
                             cmap=this_colour_map_object,
                             norm=this_colour_norm_object,
                             vmin=0.,
                             vmax=1.)

        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=peirce_score_matrix,
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='vertical',
            extend_min=False,
            extend_max=False,
            fraction_of_axis_length=0.8)

        colour_bar_object.set_label('Peirce score (POD minus POFD)')

        # Reference line: ROC curve of a random predictor.
        random_x_coords, random_y_coords = model_eval.get_random_roc_curve()
        axes_object.plot(
            random_x_coords,
            random_y_coords,
            color=plotting_utils.colour_from_numpy_to_tuple(RANDOM_ROC_COLOUR),
            linestyle='dashed',
            linewidth=RANDOM_ROC_WIDTH)

    # Drop thresholds where either coordinate is undefined.
    real_flags = numpy.invert(numpy.logical_or(
        numpy.isnan(pofd_by_threshold), numpy.isnan(pod_by_threshold)))

    if not numpy.any(real_flags):
        line_handle = None
    else:
        line_handle = axes_object.plot(
            pofd_by_threshold[real_flags],
            pod_by_threshold[real_flags],
            color=plotting_utils.colour_from_numpy_to_tuple(line_colour),
            linestyle='solid',
            linewidth=ROC_CURVE_WIDTH)[0]

    axes_object.set_xlabel('POFD (probability of false detection)')
    axes_object.set_ylabel('POD (probability of detection)')
    axes_object.set_xlim(0., 1.)
    axes_object.set_ylim(0., 1.)

    return line_handle
示例#2
0
def plot_taylor_diagram_many_heights(
        target_stdevs,
        prediction_stdevs,
        correlations,
        heights_m_agl,
        figure_object,
        colour_map_object=DEFAULT_HEIGHT_CMAP_OBJECT):
    """Draws a single Taylor diagram with one pair of markers per height.

    Every point must describe the same variable; only the height differs.
    For each height, one marker shows the target (correlation 1) and one shows
    the prediction.  Both are coloured by height on a logarithmic scale.

    H = number of heights

    :param target_stdevs: length-H numpy array with standard deviation of the
        target (actual) values at each height.
    :param prediction_stdevs: length-H numpy array with standard deviation of
        the predicted values at each height.
    :param correlations: length-H numpy array of correlations.  Heights whose
        correlation is NaN are left out of the diagram.
    :param heights_m_agl: length-H numpy array of heights (metres above ground
        level).  Converted to km and mapped to colours through a `LogNorm`.
    :param figure_object: Figure to draw on (instance of
        `matplotlib.figure.Figure`).
    :param colour_map_object: Colour map (instance of `matplotlib.pyplot.cm` or
        similar) used to colour the markers by height.
    :return: taylor_diagram_object: Handle for Taylor diagram (instance of
        `taylor_diagram.TaylorDiagram`).
    """

    error_checking.assert_is_geq_numpy_array(target_stdevs, 0.)
    error_checking.assert_is_numpy_array(target_stdevs, num_dimensions=1)

    num_heights = len(target_stdevs)
    length_h_dims = numpy.array([num_heights], dtype=int)

    error_checking.assert_is_geq_numpy_array(prediction_stdevs, 0.)
    error_checking.assert_is_numpy_array(
        prediction_stdevs, exact_dimensions=length_h_dims)

    error_checking.assert_is_geq_numpy_array(correlations, -1., allow_nan=True)
    error_checking.assert_is_leq_numpy_array(correlations, 1., allow_nan=True)
    error_checking.assert_is_numpy_array(
        correlations, exact_dimensions=length_h_dims)

    error_checking.assert_is_geq_numpy_array(heights_m_agl, 0.)
    error_checking.assert_is_numpy_array(
        heights_m_agl, exact_dimensions=length_h_dims)

    heights_km_agl = heights_m_agl * METRES_TO_KM
    height_norm_object = matplotlib.colors.LogNorm(
        vmin=numpy.min(heights_km_agl), vmax=numpy.max(heights_km_agl))

    # The radial range is passed as a ratio to the reference stdev
    # (presumably `srange` is in units of `refstd` -- confirm in
    # `taylor_diagram`).
    reference_stdev = numpy.mean(target_stdevs)
    largest_stdev = numpy.maximum(
        numpy.max(target_stdevs), numpy.max(prediction_stdevs))
    max_stdev_ratio = largest_stdev / reference_stdev

    taylor_diagram_object = taylor_diagram.TaylorDiagram(
        refstd=reference_stdev, fig=figure_object,
        srange=(0, max_stdev_ratio), extend=False, plot_reference_line=False)

    # Hide the sample point that exists before any sample has been added.
    taylor_diagram_object.samplePoints[0].set_visible(False)

    for height_index in range(num_heights):
        this_correlation = correlations[height_index]
        if numpy.isnan(this_correlation):
            continue

        height_colour = colour_map_object(
            height_norm_object(heights_km_agl[height_index]))

        # (stdev, correlation, marker type, marker size) for target, then
        # prediction.
        marker_specs = [
            (target_stdevs[height_index], 1.,
             TAYLOR_TARGET_MARKER_TYPE, TAYLOR_TARGET_MARKER_SIZE),
            (prediction_stdevs[height_index], this_correlation,
             TAYLOR_PREDICTION_MARKER_TYPE, TAYLOR_PREDICTION_MARKER_SIZE)
        ]

        for this_stdev, this_corr, this_type, this_size in marker_specs:
            taylor_diagram_object.add_sample(
                stddev=this_stdev, corrcoef=this_corr)

            new_marker_object = taylor_diagram_object.samplePoints[-1]
            new_marker_object.set_marker(this_type)
            new_marker_object.set_markersize(this_size)
            new_marker_object.set_markerfacecolor(height_colour)
            new_marker_object.set_markeredgewidth(0)

    crmse_contour_object = taylor_diagram_object.add_contours(
        levels=5, colors='0.5')
    pyplot.clabel(crmse_contour_object, inline=1, fmt='%.0f')

    taylor_diagram_object.add_grid()
    taylor_diagram_object._ax.axis[:].major_ticks.set_tick_out(True)

    colour_bar_object = plotting_utils.plot_colour_bar(
        axes_object_or_matrix=figure_object.axes[0],
        data_matrix=heights_km_agl,
        colour_map_object=colour_map_object,
        colour_norm_object=height_norm_object,
        orientation_string='vertical',
        extend_min=False,
        extend_max=False,
        fraction_of_axis_length=0.85,
        font_size=FONT_SIZE)

    # Replace the default tick labels with human-readable height labels.
    colour_bar_ticks = colour_bar_object.get_ticks()
    colour_bar_labels = profile_plotting.create_height_labels(
        tick_values_km_agl=colour_bar_ticks, use_log_scale=True)
    colour_bar_object.set_ticks(colour_bar_ticks)
    colour_bar_object.set_ticklabels(colour_bar_labels)

    colour_bar_object.set_label('Height (km AGL)', fontsize=FONT_SIZE)

    return taylor_diagram_object
def plot_many_2d_grids(data_matrix,
                       field_names,
                       axes_objects,
                       panel_names=None,
                       plot_grid_lines=True,
                       colour_map_objects=None,
                       colour_norm_objects=None,
                       refl_opacity=DEFAULT_OPACITY,
                       plot_colour_bar_flags=None,
                       panel_name_font_size=DEFAULT_FONT_SIZE,
                       colour_bar_font_size=DEFAULT_FONT_SIZE,
                       colour_bar_length=DEFAULT_COLOUR_BAR_LENGTH):
    """Draws one 2-D radar grid per field, each on its own axes.

    M = number of rows in grid
    N = number of columns in grid
    C = number of fields

    :param data_matrix: M-by-N-by-C numpy array of radar values.
    :param field_names: length-C list of field names.
    :param axes_objects: length-C list of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).  Field k is drawn on
        `axes_objects[k]`.
    :param panel_names: length-C list of panel names, passed to each panel as
        its annotation string.  If None, no panel names are drawn.
    :param plot_grid_lines: Boolean flag.  If True, grid lines are drawn over
        the radar images.
    :param colour_map_objects: length-C list of colour schemes (instances of
        `matplotlib.pyplot.cm` or similar).  If this or `colour_norm_objects`
        is None, the default scheme is used for every field.
    :param colour_norm_objects: length-C list of colour-normalizers (instances
        of `matplotlib.colors.BoundaryNorm` or similar).  Same None rule as
        `colour_map_objects`.
    :param refl_opacity: Opacity for reflectivity colour scheme.  Used only
        when the default colour schemes are in effect.
    :param plot_colour_bar_flags: length-C numpy array of Boolean flags; a
        horizontal colour bar is drawn for panel k iff
        `plot_colour_bar_flags[k]`.  If None, no colour bars are drawn.
    :param panel_name_font_size: Font size for panel names.
    :param colour_bar_font_size: Font size for colour-bar tick marks.
    :param colour_bar_length: Length of colour bars (as fraction of axis
        length).
    :return: colour_bar_objects: length-C list of colour bars, with None for
        every panel whose flag is False.
    """

    error_checking.assert_is_numpy_array(data_matrix, num_dimensions=3)
    num_fields = data_matrix.shape[-1]
    field_dims = numpy.array([num_fields], dtype=int)

    error_checking.assert_is_string_list(field_names)
    error_checking.assert_is_numpy_array(
        numpy.array(field_names), exact_dimensions=field_dims)
    error_checking.assert_is_numpy_array(
        numpy.array(axes_objects), exact_dimensions=field_dims)

    if panel_names is not None:
        error_checking.assert_is_string_list(panel_names)
        error_checking.assert_is_numpy_array(
            numpy.array(panel_names), exact_dimensions=field_dims)
    else:
        panel_names = [None] * num_fields

    # Custom colours are honoured only when BOTH maps and norms are given.
    use_default_colours = (
        colour_map_objects is None or colour_norm_objects is None)

    if use_default_colours:
        colour_map_objects = [None] * num_fields
        colour_norm_objects = [None] * num_fields
    else:
        for these_colour_objects in (colour_map_objects, colour_norm_objects):
            error_checking.assert_is_numpy_array(
                numpy.array(these_colour_objects),
                exact_dimensions=field_dims)

    if plot_colour_bar_flags is None:
        plot_colour_bar_flags = numpy.full(num_fields, False, dtype=bool)

    error_checking.assert_is_boolean_numpy_array(plot_colour_bar_flags)
    error_checking.assert_is_numpy_array(
        plot_colour_bar_flags, exact_dimensions=field_dims)

    colour_bar_objects = [None] * num_fields

    for k, this_field_name in enumerate(field_names):
        this_axes_object = axes_objects[k]
        this_field_matrix = data_matrix[..., k]

        # Deep copies keep the caller's colour objects from being modified.
        this_colour_map_object, this_colour_norm_object = (
            plot_2d_grid_without_coords(
                field_matrix=this_field_matrix,
                field_name=this_field_name,
                axes_object=this_axes_object,
                annotation_string=panel_names[k],
                font_size=panel_name_font_size,
                plot_grid_lines=plot_grid_lines,
                colour_map_object=copy.deepcopy(colour_map_objects[k]),
                colour_norm_object=copy.deepcopy(colour_norm_objects[k]),
                refl_opacity=refl_opacity))

        if plot_colour_bar_flags[k]:
            colour_bar_objects[k] = plotting_utils.plot_colour_bar(
                axes_object_or_matrix=this_axes_object,
                data_matrix=this_field_matrix,
                colour_map_object=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation_string='horizontal',
                font_size=colour_bar_font_size,
                fraction_of_axis_length=colour_bar_length,
                extend_min=this_field_name in SHEAR_VORT_DIV_NAMES,
                extend_max=True)

    return colour_bar_objects
示例#4
0
def plot_many_2d_grids_without_coords(field_matrix,
                                      field_name_by_panel,
                                      num_panel_rows=None,
                                      figure_object=None,
                                      axes_object_matrix=None,
                                      panel_names=None,
                                      colour_map_object_by_panel=None,
                                      colour_norm_object_by_panel=None,
                                      plot_colour_bar_by_panel=None,
                                      font_size=DEFAULT_FONT_SIZE,
                                      row_major=True):
    """Plots 2-D colour map in each panel (one per field/height pair).

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    P = number of panels (field/height pairs)

    Unless both `colour_map_object_by_panel` and `colour_norm_object_by_panel`
    are given, the default colour scheme for each radar field is used.

    If `figure_object is None`, a new paneled figure with `num_panel_rows` rows
    is created.  Otherwise `axes_object_matrix` must be given and its shape
    determines the panel layout.

    :param field_matrix: M-by-N-by-P numpy array of radar values.
    :param field_name_by_panel: length-P list of field names.
    :param num_panel_rows: Number of rows in paneled figure (different than M,
        which is number of rows in spatial grid).  Used only if
        `figure_object is None`.
    :param figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :param axes_object_matrix: See above.
    :param panel_names: length-P list of panel names.  Each name is used as
        the label of that panel's colour bar (newlines replaced with "; "), so
        it is shown only for panels whose colour bar is plotted.  Entries (or
        the whole list) may be None for no name.
    :param colour_map_object_by_panel: length-P list of `matplotlib.pyplot.cm`
        objects.  If this is None, the default will be used for each field.
    :param colour_norm_object_by_panel: length-P list of
        `matplotlib.colors.BoundaryNorm` objects.  If this is None, the default
        will be used for each field.
    :param plot_colour_bar_by_panel: length-P numpy array of Boolean flags.  If
        plot_colour_bar_by_panel[k] = True, horizontal colour bar will be
        plotted under [k]th panel.  If you want to plot colour bar for every
        panel, leave this as None.
    :param font_size: Font size.
    :param row_major: Boolean flag.  If True, panels will be filled along rows
        first, then down columns.  If False, down columns first, then along
        rows.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    :raises: ValueError: if `colour_map_object_by_panel` or
        `colour_norm_object_by_panel` has different length than number of
        panels.
    """

    error_checking.assert_is_boolean(row_major)
    error_checking.assert_is_numpy_array(field_matrix, num_dimensions=3)
    num_panels = field_matrix.shape[2]

    if panel_names is None:
        panel_names = [None] * num_panels
    if plot_colour_bar_by_panel is None:
        plot_colour_bar_by_panel = numpy.full(num_panels, True, dtype=bool)

    these_expected_dim = numpy.array([num_panels], dtype=int)
    error_checking.assert_is_numpy_array(numpy.array(panel_names),
                                         exact_dimensions=these_expected_dim)
    error_checking.assert_is_numpy_array(numpy.array(field_name_by_panel),
                                         exact_dimensions=these_expected_dim)

    error_checking.assert_is_boolean_numpy_array(plot_colour_bar_by_panel)
    error_checking.assert_is_numpy_array(plot_colour_bar_by_panel,
                                         exact_dimensions=these_expected_dim)

    # Custom colours are honoured only when BOTH maps and norms are given.
    if (colour_map_object_by_panel is None
            or colour_norm_object_by_panel is None):
        colour_map_object_by_panel = [None] * num_panels
        colour_norm_object_by_panel = [None] * num_panels

    error_checking.assert_is_list(colour_map_object_by_panel)
    error_checking.assert_is_list(colour_norm_object_by_panel)

    if len(colour_map_object_by_panel) != num_panels:
        error_string = (
            'Number of colour maps ({0:d}) should equal number of panels '
            '({1:d}).').format(len(colour_map_object_by_panel), num_panels)

        raise ValueError(error_string)

    if len(colour_norm_object_by_panel) != num_panels:
        error_string = (
            'Number of colour-normalizers ({0:d}) should equal number of panels'
            ' ({1:d}).').format(len(colour_norm_object_by_panel), num_panels)

        raise ValueError(error_string)

    if figure_object is None:
        error_checking.assert_is_integer(num_panel_rows)
        error_checking.assert_is_geq(num_panel_rows, 1)
        error_checking.assert_is_leq(num_panel_rows, num_panels)

        num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(num_rows=num_panel_rows,
                                                 num_columns=num_panel_columns,
                                                 shared_x_axis=False,
                                                 shared_y_axis=False,
                                                 keep_aspect_ratio=True))
    else:
        error_checking.assert_is_numpy_array(axes_object_matrix,
                                             num_dimensions=2)

        num_panel_rows = axes_object_matrix.shape[0]
        num_panel_columns = axes_object_matrix.shape[1]

    order_string = 'C' if row_major else 'F'

    for k in range(num_panels):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, (num_panel_rows, num_panel_columns), order=order_string)
        this_axes_object = axes_object_matrix[this_panel_row, this_panel_column]

        # No in-panel annotation: the panel name is used as the colour-bar
        # label below instead.
        this_colour_map_object, this_colour_norm_object = (
            plot_2d_grid_without_coords(
                field_matrix=field_matrix[..., k],
                field_name=field_name_by_panel[k],
                axes_object=this_axes_object,
                annotation_string=None,
                font_size=font_size,
                colour_map_object=colour_map_object_by_panel[k],
                colour_norm_object=colour_norm_object_by_panel[k]))

        if not plot_colour_bar_by_panel[k]:
            continue

        this_extend_min_flag = field_name_by_panel[k] in SHEAR_VORT_DIV_NAMES

        this_colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=this_axes_object,
            data_matrix=field_matrix[..., k],
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=this_extend_min_flag,
            extend_max=True,
            fraction_of_axis_length=0.75,
            font_size=font_size)

        # `panel_names` defaults to a list of None (see above), so guard
        # before calling string methods on it.
        if panel_names[k] is not None:
            this_colour_bar_object.set_label(
                panel_names[k].replace('\n', '; '),
                fontsize=font_size,
                fontweight='bold')

    # Hide leftover axes in the grid that do not hold a panel.
    for k in range(num_panels, num_panel_rows * num_panel_columns):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, (num_panel_rows, num_panel_columns), order=order_string)

        axes_object_matrix[this_panel_row, this_panel_column].axis('off')

    return figure_object, axes_object_matrix
示例#5
0
def plot_rel_curve_many_heights(mean_target_matrix,
                                mean_prediction_matrix,
                                heights_m_agl,
                                min_value_to_plot,
                                max_value_to_plot,
                                axes_object,
                                colour_map_object=DEFAULT_HEIGHT_CMAP_OBJECT):
    """Overlays one reliability curve per height on a single set of axes.

    All curves must describe the same variable; only the height differs.
    Curves are coloured by height on a logarithmic scale, and a vertical
    colour bar labelled "Height (km AGL)" is added.

    B = number of forecast bins
    H = number of heights

    :param mean_target_matrix: H-by-B numpy array of mean target (actual)
        values; row j belongs to height j.
    :param mean_prediction_matrix: H-by-B numpy array of mean predicted values
        (same shape as `mean_target_matrix`).
    :param heights_m_agl: length-H numpy array of heights (metres above ground
        level).
    :param min_value_to_plot: Lower limit of both x- and y-axes.
    :param max_value_to_plot: Upper limit of both x- and y-axes (must exceed
        `min_value_to_plot`).
    :param axes_object: Axes to draw on (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    :param colour_map_object: Colour map (instance of `matplotlib.pyplot.cm` or
        similar) used to colour the curves by height.
    """

    error_checking.assert_is_numpy_array(mean_target_matrix, num_dimensions=2)
    matrix_dims = numpy.array(mean_target_matrix.shape, dtype=int)
    error_checking.assert_is_numpy_array(
        mean_prediction_matrix, exact_dimensions=matrix_dims)

    num_heights = mean_target_matrix.shape[0]

    error_checking.assert_is_geq_numpy_array(heights_m_agl, 0.)
    error_checking.assert_is_numpy_array(
        heights_m_agl, exact_dimensions=numpy.array([num_heights], dtype=int))

    error_checking.assert_is_greater(max_value_to_plot, min_value_to_plot)

    heights_km_agl = heights_m_agl * METRES_TO_KM
    height_norm_object = matplotlib.colors.LogNorm(
        vmin=numpy.min(heights_km_agl), vmax=numpy.max(heights_km_agl))

    # Iterating a 2-D array yields its rows, i.e. one row per height.
    for this_height_km, these_mean_targets, these_mean_predictions in zip(
            heights_km_agl, mean_target_matrix, mean_prediction_matrix):
        _plot_reliability_curve(
            axes_object=axes_object,
            mean_predictions=these_mean_predictions,
            mean_observations=these_mean_targets,
            line_colour=colour_map_object(height_norm_object(this_height_km)),
            min_value_to_plot=min_value_to_plot,
            max_value_to_plot=max_value_to_plot)

    axes_object.set_xlim(min_value_to_plot, max_value_to_plot)
    axes_object.set_ylim(min_value_to_plot, max_value_to_plot)

    colour_bar_object = plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=heights_km_agl,
        colour_map_object=colour_map_object,
        colour_norm_object=height_norm_object,
        orientation_string='vertical',
        extend_min=False,
        extend_max=False,
        font_size=FONT_SIZE)

    # Replace the default tick labels with human-readable height labels.
    colour_bar_ticks = colour_bar_object.get_ticks()
    colour_bar_labels = profile_plotting.create_height_labels(
        tick_values_km_agl=colour_bar_ticks, use_log_scale=True)
    colour_bar_object.set_ticks(colour_bar_ticks)
    colour_bar_object.set_ticklabels(colour_bar_labels)

    colour_bar_object.set_label('Height (km AGL)', fontsize=FONT_SIZE)
示例#6
0
def plot_roc_curve(axes_object,
                   pod_by_threshold,
                   pofd_by_threshold,
                   line_colour=DEFAULT_ROC_COLOUR,
                   line_width=DEFAULT_ROC_WIDTH,
                   random_line_colour=DEFAULT_RANDOM_ROC_COLOUR,
                   random_line_width=DEFAULT_RANDOM_ROC_WIDTH):
    """Plots ROC (receiver operating characteristic) curve.

    T = number of binarization thresholds

    For the definition of a "binarization threshold" and the role they play in
    ROC curves, see `model_evaluation.get_points_in_roc_curve`.

    The background (Peirce-score contours with a vertical colour bar, plus the
    dashed ROC curve of a random predictor) is always drawn.  Thresholds where
    POD or POFD is NaN are omitted from the curve.

    :param axes_object: Instance of `matplotlib.axes._subplots.AxesSubplot`.
        Everything is drawn on these axes.
    :param pod_by_threshold: length-T numpy array of POD (probability of
        detection) values.
    :param pofd_by_threshold: length-T numpy array of POFD (probability of false
        detection) values.
    :param line_colour: Colour (in any format accepted by `matplotlib.colors`).
    :param line_width: Line width (real positive number).
    :param random_line_colour: Colour of reference line (ROC curve for a random
        predictor).
    :param random_line_width: Width of reference line.
    """

    error_checking.assert_is_numpy_array(pod_by_threshold, num_dimensions=1)
    error_checking.assert_is_geq_numpy_array(pod_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pod_by_threshold,
                                             1.,
                                             allow_nan=True)
    num_thresholds = len(pod_by_threshold)

    error_checking.assert_is_numpy_array(pofd_by_threshold,
                                         exact_dimensions=numpy.array(
                                             [num_thresholds], dtype=int))
    error_checking.assert_is_geq_numpy_array(pofd_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pofd_by_threshold,
                                             1.,
                                             allow_nan=True)

    pofd_matrix, pod_matrix = model_eval.get_pofd_pod_grid()
    peirce_score_matrix = pod_matrix - pofd_matrix

    this_colour_map_object, this_colour_norm_object = (
        _get_peirce_colour_scheme())

    # Draw on the caller's axes explicitly.  `pyplot.contourf` always draws on
    # the current axes (`pyplot.gca()`); passing `axes=` to it does not
    # redirect the drawing, so the background could land on another subplot.
    axes_object.contourf(pofd_matrix,
                         pod_matrix,
                         peirce_score_matrix,
                         LEVELS_FOR_CSI_CONTOURS,
                         cmap=this_colour_map_object,
                         norm=this_colour_norm_object,
                         vmin=0.,
                         vmax=1.)

    colour_bar_object = plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=peirce_score_matrix,
        colour_map_object=this_colour_map_object,
        colour_norm_object=this_colour_norm_object,
        orientation_string='vertical',
        extend_min=False,
        extend_max=False)

    colour_bar_object.set_label('Peirce score (POD minus POFD)')

    # Reference line: ROC curve of a random predictor.
    random_x_coords, random_y_coords = model_eval.get_random_roc_curve()
    axes_object.plot(
        random_x_coords,
        random_y_coords,
        color=plotting_utils.colour_from_numpy_to_tuple(random_line_colour),
        linestyle='dashed',
        linewidth=random_line_width)

    # Drop thresholds where either coordinate is undefined.
    real_flags = numpy.invert(numpy.logical_or(
        numpy.isnan(pofd_by_threshold), numpy.isnan(pod_by_threshold)))

    if numpy.any(real_flags):
        axes_object.plot(
            pofd_by_threshold[real_flags],
            pod_by_threshold[real_flags],
            color=plotting_utils.colour_from_numpy_to_tuple(line_colour),
            linestyle='solid',
            linewidth=line_width)

    axes_object.set_xlabel('POFD (probability of false detection)')
    axes_object.set_ylabel('POD (probability of detection)')
    axes_object.set_xlim(0., 1.)
    axes_object.set_ylim(0., 1.)
def _plot_3d_radar_difference(difference_matrix,
                              colour_map_object,
                              max_colour_percentile,
                              model_metadata_dict,
                              backwards_opt_dict,
                              output_dir_name,
                              example_index=None,
                              significance_matrix=None):
    """Plots 3-D radar differences (after minus before backwards optimization).

    One figure is written per radar field.  Each figure has one panel per
    height and a horizontal colour bar whose limits are symmetric about zero.
    Its title names the storm (or "PMM") and gives the activation before and
    after optimization.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of fields

    :param difference_matrix: M-by-N-by-H-by-F numpy array of differences
        (after minus before optimization).
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.  For field j, the colour limits are
        +/- this percentile of the absolute differences for field j.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param backwards_opt_dict: Dictionary returned by
        `backwards_optimization.read_standard_file` or
        `backwards_optimization.read_pmm_file`, containing metadata.
    :param output_dir_name: Name of output directory.  Figure(s) will be saved
        here.
    :param example_index: Index of the example to plot, used to look up its
        activations, storm ID and storm time in `backwards_opt_dict`.  Not
        used when `backwards_opt_dict` contains PMM (probability-matched
        means).
    :param significance_matrix: M-by-N-by-H-by-F numpy array of Boolean flags,
        indicating where these differences are significantly different than
        differences from another backwards optimization.  If None, no
        significance markers are drawn.
    """

    option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    heights_m_agl = option_dict[trainval_io.RADAR_HEIGHTS_KEY]
    panel_row_count = int(numpy.floor(numpy.sqrt(len(heights_m_agl))))

    # PMM files hold mean activations and no per-storm metadata.
    is_pmm = backwards_opt.MEAN_FINAL_ACTIVATION_KEY in backwards_opt_dict

    if is_pmm:
        activation_before = backwards_opt_dict[
            backwards_opt.MEAN_INITIAL_ACTIVATION_KEY]
        activation_after = backwards_opt_dict[
            backwards_opt.MEAN_FINAL_ACTIVATION_KEY]
        storm_id_string = None
        time_string = None
    else:
        activation_before = backwards_opt_dict[
            backwards_opt.INITIAL_ACTIVATIONS_KEY][example_index]
        activation_after = backwards_opt_dict[
            backwards_opt.FINAL_ACTIVATIONS_KEY][example_index]
        storm_id_string = backwards_opt_dict[
            backwards_opt.FULL_IDS_KEY][example_index]
        time_string = time_conversion.unix_sec_to_string(
            backwards_opt_dict[backwards_opt.STORM_TIMES_KEY][example_index],
            plot_input_examples.TIME_FORMAT)

    # When the model is flagged as 2D/3D, reflectivity is the only 3-D field.
    if model_metadata_dict[cnn.CONV_2D3D_KEY]:
        field_names = [radar_utils.REFL_NAME]
    else:
        field_names = option_dict[trainval_io.RADAR_FIELDS_KEY]

    for field_index, field_name in enumerate(field_names):
        field_differences = difference_matrix[..., field_index]

        # Symmetric limits, so that zero difference sits mid-scale.
        colour_limit = numpy.percentile(
            numpy.absolute(field_differences), max_colour_percentile)
        norm_object = matplotlib.colors.Normalize(
            vmin=-1 * colour_limit, vmax=colour_limit, clip=False)

        # TODO(thunderhoser): Deal with change of units.
        figure_object, axes_object_matrix = (
            radar_plotting.plot_3d_grid_without_coords(
                field_matrix=numpy.flip(field_differences, axis=0),
                field_name=field_name,
                grid_point_heights_metres=heights_m_agl,
                ground_relative=True,
                num_panel_rows=panel_row_count,
                font_size=FONT_SIZE_SANS_COLOUR_BARS,
                colour_map_object=colour_map_object,
                colour_norm_object=norm_object))

        if significance_matrix is not None:
            # Flip rows exactly like the plotted field, so that markers line
            # up with the grid cells they describe.
            significance_plotting.plot_many_2d_grids_without_coords(
                significance_matrix=numpy.flip(
                    significance_matrix[..., field_index], axis=0),
                axes_object_matrix=axes_object_matrix)

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=field_differences,
            colour_map_object=colour_map_object,
            colour_norm_object=norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

        if is_pmm:
            title_string = 'PMM'
        else:
            title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_id_string, time_string)

        title_string += (
            '; {0:s}; activation from {1:.2e} to {2:.2e}'
        ).format(field_name, activation_before, activation_after)

        figure_object.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

        output_file_name = plot_input_examples.metadata_to_radar_fig_file_name(
            output_dir_name=output_dir_name,
            pmm_flag=is_pmm,
            full_storm_id_string=storm_id_string,
            storm_time_string=time_string,
            radar_field_name=field_name)

        print('Saving figure to: "{0:s}"...'.format(output_file_name))
        figure_object.savefig(output_file_name,
                              dpi=FIGURE_RESOLUTION_DPI,
                              pad_inches=0,
                              bbox_inches='tight')
        pyplot.close(figure_object)
示例#8
0
def _plot_one_field(reflectivity_matrix_dbz, latitudes_deg, longitudes_deg,
                    add_colour_bar, panel_letter, output_file_name):
    """Plots reflectivity from one dataset on a lat-long map and saves it.

    The map covers the full extent of `latitudes_deg` and `longitudes_deg`.
    It shows coastlines, countries, states/provinces, parallels and
    meridians, plus the reflectivity grid.  The figure is closed after saving.

    :param reflectivity_matrix_dbz: See doc for `_read_file`.
    :param latitudes_deg: Same.
    :param longitudes_deg: Same.
    :param add_colour_bar: Boolean flag.  If True, a horizontal colour bar is
        added below the map.
    :param panel_letter: Panel letter (will be printed at top left of figure).
    :param output_file_name: Path to output file (figure will be saved here).
    """

    lat_min_deg = numpy.min(latitudes_deg)
    lat_max_deg = numpy.max(latitudes_deg)
    lng_min_deg = numpy.min(longitudes_deg)
    lng_max_deg = numpy.max(longitudes_deg)

    figure_object, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=lat_min_deg,
            max_latitude_deg=lat_max_deg,
            min_longitude_deg=lng_min_deg,
            max_longitude_deg=lng_max_deg,
            resolution_string='i'))

    # All three political/coastal border layers share the same arguments.
    for draw_borders in (plotting_utils.plot_coastlines,
                         plotting_utils.plot_countries,
                         plotting_utils.plot_states_and_provinces):
        draw_borders(basemap_object=basemap_object,
                     axes_object=axes_object,
                     line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    # Grid spacing is taken from the first two points along each axis.
    radar_plotting.plot_latlng_grid(
        field_matrix=reflectivity_matrix_dbz,
        field_name=RADAR_FIELD_NAME,
        axes_object=axes_object,
        min_grid_point_latitude_deg=lat_min_deg,
        min_grid_point_longitude_deg=lng_min_deg,
        latitude_spacing_deg=latitudes_deg[1] - latitudes_deg[0],
        longitude_spacing_deg=longitudes_deg[1] - longitudes_deg[0])

    if add_colour_bar:
        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(RADAR_FIELD_NAME))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=reflectivity_matrix_dbz,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal',
            padding=0.05,
            extend_min=False,
            extend_max=True,
            fraction_of_axis_length=1.)

    panel_label = '({0:s})'.format(panel_letter)
    plotting_utils.label_axes(axes_object=axes_object,
                              label_string=panel_label,
                              y_coord_normalized=1.03)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0, bbox_inches='tight')
    pyplot.close(figure_object)
def _plot_one_example_one_time(storm_object_table, full_id_string,
                               valid_time_unix_sec, tornado_table,
                               top_myrorss_dir_name, radar_field_name,
                               radar_height_m_asl, latitude_limits_deg,
                               longitude_limits_deg):
    """Plots one example with surrounding context at one time.

    Draws the following on a new lat-long map:

    - the MYRORSS radar field valid at `valid_time_unix_sec`;
    - storm outlines: dashed for storms with a different primary ID, solid
      for other storms with the same primary ID, and thick solid for the
      storm of interest (same secondary ID);
    - the forecast probability of the storm of interest, if the table has
      that column;
    - tornado reports with their time labels.

    The figure is neither saved nor closed here, and its handle is not
    returned.  NOTE(review): presumably the caller saves the current figure
    via `pyplot` -- confirm.

    :param storm_object_table: pandas DataFrame with storm objects at one
        time.  Columns are documented in `storm_tracking_io.write_file`.  Rows
        whose primary ID differs from that of `full_id_string` are drawn as
        unrelated storms.  If column `FORECAST_PROBABILITY_COLUMN` is present,
        the probability of the storm of interest is printed on the map.
    :param full_id_string: Full ID of storm of interest.
    :param valid_time_unix_sec: Valid time.
    :param tornado_table: pandas DataFrame created by
        `linkage._read_input_tornado_reports`.
    :param top_myrorss_dir_name: See documentation at top of file.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_limits_deg: See doc for `_get_plotting_limits`.  Element 0
        is the minimum plotting latitude; element 1 is the maximum.
    :param longitude_limits_deg: Same, but for longitude.
    """

    min_plot_latitude_deg = latitude_limits_deg[0]
    max_plot_latitude_deg = latitude_limits_deg[1]
    min_plot_longitude_deg = longitude_limits_deg[0]
    max_plot_longitude_deg = longitude_limits_deg[1]

    # Locate the raw MYRORSS file for this field, height and valid time.
    # raise_error_if_missing=True is passed, so a missing file presumably
    # raises inside find_raw_file.
    radar_file_name = myrorss_and_mrms_io.find_raw_file(
        top_directory_name=top_myrorss_dir_name,
        spc_date_string=time_conversion.time_to_spc_date_string(
            valid_time_unix_sec),
        unix_time_sec=valid_time_unix_sec,
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_field_name,
        height_m_asl=radar_height_m_asl,
        raise_error_if_missing=True)

    print('Reading data from: "{0:s}"...'.format(radar_file_name))

    radar_metadata_dict = myrorss_and_mrms_io.read_metadata_from_raw_file(
        netcdf_file_name=radar_file_name,
        data_source=radar_utils.MYRORSS_SOURCE_ID)

    # Read the sparse-grid data, then expand it to a full grid.
    sparse_grid_table = (myrorss_and_mrms_io.read_data_from_sparse_grid_file(
        netcdf_file_name=radar_file_name,
        field_name_orig=radar_metadata_dict[
            myrorss_and_mrms_io.FIELD_NAME_COLUMN_ORIG],
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        sentinel_values=radar_metadata_dict[radar_utils.SENTINEL_VALUE_COLUMN])
                         )

    radar_matrix, grid_point_latitudes_deg, grid_point_longitudes_deg = (
        radar_s2f.sparse_to_full_grid(sparse_grid_table=sparse_grid_table,
                                      metadata_dict=radar_metadata_dict))

    # Flip the rows and reverse the latitude vector together, so that row i of
    # the matrix still corresponds to grid_point_latitudes_deg[i].
    # NOTE(review): presumably this makes latitude increase with row index, as
    # plot_latlng_grid (which is given the minimum latitude) expects; confirm.
    radar_matrix = numpy.flip(radar_matrix, axis=0)
    grid_point_latitudes_deg = grid_point_latitudes_deg[::-1]

    # [1:] discards the figure handle and keeps (axes, basemap).
    axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg,
            resolution_string='i')[1:])

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR)

    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR)

    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)

    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    # Grid spacing is taken from the first two grid points along each axis.
    radar_plotting.plot_latlng_grid(
        field_matrix=radar_matrix,
        field_name=radar_field_name,
        axes_object=axes_object,
        min_grid_point_latitude_deg=numpy.min(grid_point_latitudes_deg),
        min_grid_point_longitude_deg=numpy.min(grid_point_longitudes_deg),
        latitude_spacing_deg=numpy.diff(grid_point_latitudes_deg[:2])[0],
        longitude_spacing_deg=numpy.diff(grid_point_longitudes_deg[:2])[0])

    colour_map_object, colour_norm_object = (
        radar_plotting.get_default_colour_scheme(radar_field_name))

    plotting_utils.plot_colour_bar(axes_object_or_matrix=axes_object,
                                   data_matrix=radar_matrix,
                                   colour_map_object=colour_map_object,
                                   colour_norm_object=colour_norm_object,
                                   orientation_string='horizontal',
                                   extend_min=False,
                                   extend_max=True,
                                   fraction_of_axis_length=0.8)

    # Split the full ID of the storm of interest into its primary and
    # secondary parts; these drive the three outline styles below.
    first_list, second_list = temporal_tracking.full_to_partial_ids(
        [full_id_string])
    primary_id_string = first_list[0]
    secondary_id_string = second_list[0]

    # Plot outlines of unrelated storms (with different primary IDs).
    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.PRIMARY_ID_COLUMN] != primary_id_string]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_width=2,
        line_colour='k',
        line_style='dashed')

    # Plot outlines of related storms (with the same primary ID).
    this_storm_object_table = storm_object_table.loc[
        (storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] ==
         primary_id_string) & (storm_object_table[
             tracking_utils.SECONDARY_ID_COLUMN] != secondary_id_string)]

    this_num_storm_objects = len(this_storm_object_table.index)

    if this_num_storm_objects > 0:
        storm_plotting.plot_storm_outlines(
            storm_object_table=this_storm_object_table,
            axes_object=axes_object,
            basemap_object=basemap_object,
            line_width=2,
            line_colour='k',
            line_style='solid')

        # Mark each related storm with the letter "P" at its centroid.
        for j in range(len(this_storm_object_table)):
            axes_object.text(
                this_storm_object_table[
                    tracking_utils.CENTROID_LONGITUDE_COLUMN].values[j],
                this_storm_object_table[
                    tracking_utils.CENTROID_LATITUDE_COLUMN].values[j],
                'P',
                fontsize=FONT_SIZE,
                color=FONT_COLOUR,
                fontweight='bold',
                horizontalalignment='center',
                verticalalignment='center')

    # Plot outline of storm of interest (same secondary ID).
    this_storm_object_table = storm_object_table.loc[storm_object_table[
        tracking_utils.SECONDARY_ID_COLUMN] == secondary_id_string]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table,
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_width=4,
        line_colour='k',
        line_style='solid')

    this_num_storm_objects = len(this_storm_object_table.index)

    # The probability label is drawn only if the storm of interest is present
    # and the table carries forecast probabilities.
    plot_forecast = (this_num_storm_objects > 0 and FORECAST_PROBABILITY_COLUMN
                     in list(this_storm_object_table))

    if plot_forecast:
        this_polygon_object_latlng = this_storm_object_table[
            tracking_utils.LATLNG_POLYGON_COLUMN].values[0]

        # Anchor the label at the minimum y-coordinate of the outline
        # (presumably the southernmost latitude, assuming x = longitude and
        # y = latitude -- confirm) and hang it downward (verticalalignment
        # 'top'), so that it sits just below the storm.
        this_latitude_deg = numpy.min(
            numpy.array(this_polygon_object_latlng.exterior.xy[1]))

        this_longitude_deg = this_storm_object_table[
            tracking_utils.CENTROID_LONGITUDE_COLUMN].values[0]

        label_string = 'Prob = {0:.3f}\nat {1:s}'.format(
            this_storm_object_table[FORECAST_PROBABILITY_COLUMN].values[0],
            time_conversion.unix_sec_to_string(valid_time_unix_sec,
                                               TORNADO_TIME_FORMAT))

        bounding_box_dict = {
            'facecolor':
            plotting_utils.colour_from_numpy_to_tuple(
                PROBABILITY_BACKGROUND_COLOUR),
            'alpha':
            PROBABILITY_BACKGROUND_OPACITY,
            'edgecolor':
            'k',
            'linewidth':
            1
        }

        # A very large zorder keeps the label on top of every other artist.
        axes_object.text(this_longitude_deg,
                         this_latitude_deg,
                         label_string,
                         fontsize=FONT_SIZE,
                         color=plotting_utils.colour_from_numpy_to_tuple(
                             PROBABILITY_FONT_COLOUR),
                         fontweight='bold',
                         bbox=bounding_box_dict,
                         horizontalalignment='center',
                         verticalalignment='top',
                         zorder=1e10)

    tornado_latitudes_deg = tornado_table[linkage.EVENT_LATITUDE_COLUMN].values
    tornado_longitudes_deg = tornado_table[
        linkage.EVENT_LONGITUDE_COLUMN].values

    tornado_times_unix_sec = tornado_table[linkage.EVENT_TIME_COLUMN].values
    tornado_time_strings = [
        time_conversion.unix_sec_to_string(t, TORNADO_TIME_FORMAT)
        for t in tornado_times_unix_sec
    ]

    # Tornado reports as standalone markers (no connecting line).
    axes_object.plot(tornado_longitudes_deg,
                     tornado_latitudes_deg,
                     linestyle='None',
                     marker=TORNADO_MARKER_TYPE,
                     markersize=TORNADO_MARKER_SIZE,
                     markeredgewidth=TORNADO_MARKER_EDGE_WIDTH,
                     markerfacecolor=plotting_utils.colour_from_numpy_to_tuple(
                         TORNADO_MARKER_COLOUR),
                     markeredgecolor=plotting_utils.colour_from_numpy_to_tuple(
                         TORNADO_MARKER_COLOUR))

    num_tornadoes = len(tornado_latitudes_deg)

    # Offset each time label 0.02 deg east and south of its marker, anchored
    # at the label's top-left corner.
    for j in range(num_tornadoes):
        axes_object.text(tornado_longitudes_deg[j] + 0.02,
                         tornado_latitudes_deg[j] - 0.02,
                         tornado_time_strings[j],
                         fontsize=FONT_SIZE,
                         color=FONT_COLOUR,
                         fontweight='bold',
                         horizontalalignment='left',
                         verticalalignment='top')
示例#10
0
def _plot_2d3d_radar_scan(list_of_predictor_matrices,
                          model_metadata_dict,
                          allow_whitespace,
                          title_string=None):
    """Plots 3-D reflectivity and 2-D azimuthal shear for one example.

    Two figures are produced: one with a panel per reflectivity height, and
    one with a panel per azimuthal-shear field.

    - If `allow_whitespace` is True, the plotting helpers create the figures.
      Each figure also gets a horizontal colour bar and, if `title_string` is
      given, a title.
    - If False, figures with zero spacing between panels are created here,
      and no colour bar or title is added.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
        Element 0 holds reflectivity (only channel 0 is plotted); element 1
        holds azimuthal shear, with one channel per field.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-2 list of figure handles (instances of
        `matplotlib.figure.Figure`).  The first is for reflectivity; the second
        is for azimuthal shear.
    :return: axes_object_matrices: length-2 list (the first is for reflectivity;
        the second is for azimuthal shear).  Each element is a 2-D numpy
        array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    az_shear_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    refl_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    num_az_shear_fields = len(az_shear_field_names)
    num_refl_heights = len(refl_heights_m_agl)

    # Near-square panel layout for the reflectivity heights.
    this_num_panel_rows = int(numpy.floor(numpy.sqrt(num_refl_heights)))
    this_num_panel_columns = int(
        numpy.ceil(float(num_refl_heights) / this_num_panel_rows))

    if allow_whitespace:
        # Let the plotting helper create its own figure.
        refl_figure_object = None
        refl_axes_object_matrix = None
    else:
        refl_figure_object, refl_axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=this_num_panel_rows,
                num_columns=this_num_panel_columns,
                horizontal_spacing=0.,
                vertical_spacing=0.,
                shared_x_axis=False,
                shared_y_axis=False,
                keep_aspect_ratio=True))

    refl_figure_object, refl_axes_object_matrix = (
        radar_plotting.plot_3d_grid_without_coords(
            field_matrix=numpy.flip(list_of_predictor_matrices[0][..., 0],
                                    axis=0),
            field_name=radar_utils.REFL_NAME,
            grid_point_heights_metres=refl_heights_m_agl,
            ground_relative=True,
            num_panel_rows=this_num_panel_rows,
            figure_object=refl_figure_object,
            axes_object_matrix=refl_axes_object_matrix,
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        this_colour_map_object, this_colour_norm_object = (
            radar_plotting.get_default_colour_scheme(radar_utils.REFL_NAME))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=refl_axes_object_matrix,
            data_matrix=list_of_predictor_matrices[0],
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

        if title_string is not None:
            this_title_string = '{0:s}; {1:s}'.format(title_string,
                                                      radar_utils.REFL_NAME)

            # Title the figure explicitly.  pyplot.suptitle targets whatever
            # figure is current, which need not be this one.
            refl_figure_object.suptitle(this_title_string,
                                        fontsize=TITLE_FONT_SIZE)

    if allow_whitespace:
        shear_figure_object = None
        shear_axes_object_matrix = None
    else:
        shear_figure_object, shear_axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=1,
                num_columns=num_az_shear_fields,
                horizontal_spacing=0.,
                vertical_spacing=0.,
                shared_x_axis=False,
                shared_y_axis=False,
                keep_aspect_ratio=True))

    # Per-panel colour bars are suppressed.  A single shared colour bar is
    # added below when whitespace is allowed.
    shear_figure_object, shear_axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(list_of_predictor_matrices[1], axis=0),
            field_name_by_panel=az_shear_field_names,
            panel_names=az_shear_field_names,
            num_panel_rows=1,
            figure_object=shear_figure_object,
            axes_object_matrix=shear_axes_object_matrix,
            plot_colour_bar_by_panel=numpy.full(num_az_shear_fields,
                                                False,
                                                dtype=bool),
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        this_colour_map_object, this_colour_norm_object = (
            radar_plotting.get_default_colour_scheme(
                radar_utils.LOW_LEVEL_SHEAR_NAME))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=shear_axes_object_matrix,
            data_matrix=list_of_predictor_matrices[1],
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

        if title_string is not None:
            shear_figure_object.suptitle(title_string,
                                         fontsize=TITLE_FONT_SIZE)

    figure_objects = [refl_figure_object, shear_figure_object]
    axes_object_matrices = [refl_axes_object_matrix, shear_axes_object_matrix]
    return figure_objects, axes_object_matrices
def _plot_forecast_one_time(
        gridded_forecast_dict, time_index, min_plot_latitude_deg,
        max_plot_latitude_deg, min_plot_longitude_deg, max_plot_longitude_deg,
        output_dir_name, tornado_dir_name=None):
    """Plots gridded forecast at one time and saves it as a PNG.

    Probabilities are drawn on a map in the RAP 130-grid projection, along
    with borders and a colour bar.  Optionally, the start points of tornadoes
    that began within the forecast's valid window are also drawn.  The figure
    is then saved, closed, and trimmed of whitespace in place.

    :param gridded_forecast_dict: Dictionary returned by
        `prediction_io.read_gridded_predictions`.
    :param time_index: Will plot the [i]th gridded forecast, where
        i = `time_index`.
    :param min_plot_latitude_deg: See documentation at top of file.  Lat-long
        limits are passed to the map helper only if all four are not None.
    :param max_plot_latitude_deg: Same.
    :param min_plot_longitude_deg: Same.
    :param max_plot_longitude_deg: Same.
    :param output_dir_name: Name of output directory.  Figure will be saved
        here, named
        "gridded_forecast_init-<init time>_lead-<min>-<max>sec.png".
    :param tornado_dir_name: See documentation at top of file.  If None, no
        tornado reports are plotted.
    """

    init_time_unix_sec = gridded_forecast_dict[prediction_io.INIT_TIMES_KEY][
        time_index
    ]
    min_lead_time_seconds = gridded_forecast_dict[
        prediction_io.MIN_LEAD_TIME_KEY
    ]
    max_lead_time_seconds = gridded_forecast_dict[
        prediction_io.MAX_LEAD_TIME_KEY
    ]

    # The forecast is valid from init + min lead to init + max lead.
    first_valid_time_unix_sec = init_time_unix_sec + min_lead_time_seconds
    last_valid_time_unix_sec = init_time_unix_sec + max_lead_time_seconds

    tornado_latitudes_deg = numpy.array([])
    tornado_longitudes_deg = numpy.array([])

    if tornado_dir_name is not None:
        # Tornado files are per year, so read every year the window touches.
        first_year = int(
            time_conversion.unix_sec_to_string(first_valid_time_unix_sec, '%Y')
        )
        last_year = int(
            time_conversion.unix_sec_to_string(last_valid_time_unix_sec, '%Y')
        )

        for this_year in range(first_year, last_year + 1):
            this_file_name = tornado_io.find_processed_file(
                directory_name=tornado_dir_name, year=this_year)

            print('Reading tornado reports from: "{0:s}"...'.format(
                this_file_name))

            this_tornado_table = tornado_io.read_processed_file(this_file_name)

            # Keep only tornadoes that started within the valid window.
            this_tornado_table = this_tornado_table.loc[
                (this_tornado_table[tornado_io.START_TIME_COLUMN]
                 >= first_valid_time_unix_sec)
                & (this_tornado_table[tornado_io.START_TIME_COLUMN]
                   <= last_valid_time_unix_sec)
                ]

            tornado_latitudes_deg = numpy.concatenate((
                tornado_latitudes_deg,
                this_tornado_table[tornado_io.START_LAT_COLUMN].values
            ))

            tornado_longitudes_deg = numpy.concatenate((
                tornado_longitudes_deg,
                this_tornado_table[tornado_io.START_LNG_COLUMN].values
            ))

        print('\n')

    custom_area = all(
        x is not None for x in
        (min_plot_latitude_deg, max_plot_latitude_deg, min_plot_longitude_deg,
         max_plot_longitude_deg)
    )

    if custom_area:
        latlng_limit_dict = {
            plotting_utils.MIN_LATITUDE_KEY: min_plot_latitude_deg,
            plotting_utils.MAX_LATITUDE_KEY: max_plot_latitude_deg,
            plotting_utils.MIN_LONGITUDE_KEY: min_plot_longitude_deg,
            plotting_utils.MAX_LONGITUDE_KEY: max_plot_longitude_deg
        }
    else:
        latlng_limit_dict = None

    # Keep the figure handle, so that saving and closing act on this figure
    # rather than on whatever pyplot considers current.
    figure_object, axes_object, basemap_object = (
        plotting_utils.create_map_with_nwp_proj(
            model_name=nwp_model_utils.RAP_MODEL_NAME,
            grid_name=nwp_model_utils.NAME_OF_130GRID, xy_limit_dict=None,
            latlng_limit_dict=latlng_limit_dict, resolution_string='i'
        )
    )

    # if not custom_area:
    #     min_plot_latitude_deg = basemap_object.llcrnrlat
    #     max_plot_latitude_deg = basemap_object.urcrnrlat
    #     min_plot_longitude_deg = basemap_object.llcrnrlon
    #     max_plot_longitude_deg = basemap_object.urcrnrlon

    # Shift between the grid's x-y coordinates and the basemap's.
    x_offset_metres, y_offset_metres = _get_projection_offsets(
        basemap_object=basemap_object, pyproj_object=PYPROJ_OBJECT,
        test_latitudes_deg=TEST_LATITUDES_DEG,
        test_longitudes_deg=TEST_LONGITUDES_DEG)

    probability_matrix = gridded_forecast_dict[
        prediction_io.XY_PROBABILITIES_KEY
    ][time_index]

    # If necessary, convert from sparse to dense matrix.
    if not isinstance(probability_matrix, numpy.ndarray):
        probability_matrix = probability_matrix.toarray()

    x_coords_metres = (
        gridded_forecast_dict[prediction_io.GRID_X_COORDS_KEY] + x_offset_metres
    )
    y_coords_metres = (
        gridded_forecast_dict[prediction_io.GRID_Y_COORDS_KEY] + y_offset_metres
    )

    # Grid spacing is taken from the first two points along each axis.
    probability_plotting.plot_xy_grid(
        probability_matrix=probability_matrix,
        x_min_metres=numpy.min(x_coords_metres),
        y_min_metres=numpy.min(y_coords_metres),
        x_spacing_metres=numpy.diff(x_coords_metres[:2])[0],
        y_spacing_metres=numpy.diff(y_coords_metres[:2])[0],
        axes_object=axes_object, basemap_object=basemap_object)

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)

    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)

    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR)

    colour_map_object, colour_norm_object = (
        probability_plotting.get_default_colour_map()
    )

    plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=probability_matrix,
        colour_map_object=colour_map_object,
        colour_norm_object=colour_norm_object, orientation_string='horizontal',
        extend_min=True, extend_max=True, fraction_of_axis_length=0.8)

    if len(tornado_latitudes_deg) > 0:
        # Project tornado start points into the basemap's x-y space.
        tornado_x_coords_metres, tornado_y_coords_metres = basemap_object(
            tornado_longitudes_deg, tornado_latitudes_deg)

        axes_object.plot(
            tornado_x_coords_metres, tornado_y_coords_metres, linestyle='None',
            marker=TORNADO_MARKER_TYPE, markersize=TORNADO_MARKER_SIZE,
            markeredgewidth=TORNADO_MARKER_EDGE_WIDTH,
            markerfacecolor=plotting_utils.colour_from_numpy_to_tuple(
                TORNADO_MARKER_COLOUR),
            markeredgecolor=plotting_utils.colour_from_numpy_to_tuple(
                TORNADO_MARKER_COLOUR)
        )

    init_time_string = time_conversion.unix_sec_to_string(
        init_time_unix_sec, FILE_NAME_TIME_FORMAT
    )

    # first_valid_time_string = time_conversion.unix_sec_to_string(
    #     first_valid_time_unix_sec, FILE_NAME_TIME_FORMAT
    # )
    # last_valid_time_string = time_conversion.unix_sec_to_string(
    #     last_valid_time_unix_sec, FILE_NAME_TIME_FORMAT
    # )
    # title_string = 'Forecast init {0:s}, valid {1:s} to {2:s}'.format(
    #     init_time_string, first_valid_time_string, last_valid_time_string
    # )
    # pyplot.title(title_string, fontsize=TITLE_FONT_SIZE)

    output_file_name = (
        '{0:s}/gridded_forecast_init-{1:s}_lead-{2:06d}-{3:06d}sec.png'
    ).format(
        output_dir_name, init_time_string, min_lead_time_seconds,
        max_lead_time_seconds
    )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close(figure_object)

    imagemagick_utils.trim_whitespace(input_file_name=output_file_name,
                                      output_file_name=output_file_name)
示例#12
0
def _plot_3d_radar_scan(list_of_predictor_matrices,
                        model_metadata_dict,
                        allow_whitespace,
                        title_string=None):
    """Plots a 3-D radar scan for one example, one figure per radar field.

    J = number of panel rows in image
    K = number of panel columns in image
    F = number of radar fields

    Each figure has one panel per radar height, laid out in a near-square grid
    with J = floor(sqrt(num_heights)) rows and as many columns as needed.

    :param list_of_predictor_matrices: List created by
        `testing_io.read_specific_examples`, except that the first axis (example
        dimension) is removed.  Only the first element (radar matrix, with
        radar fields along the last axis) is used here.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param allow_whitespace: See documentation at top of file.  If False, each
        figure is pre-built with zero spacing between panels and no colour bar
        or title is added.  If True, a colour bar (and the title, if given) is
        added to each figure.
    :param title_string: Title (may be None).

    :return: figure_objects: length-F list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-F list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_options = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    field_names = training_options[trainval_io.RADAR_FIELDS_KEY]
    heights_m_agl = training_options[trainval_io.RADAR_HEIGHTS_KEY]

    num_heights = len(heights_m_agl)
    num_rows = int(numpy.floor(numpy.sqrt(num_heights)))
    num_columns = int(numpy.ceil(float(num_heights) / num_rows))

    full_radar_matrix = list_of_predictor_matrices[0]
    figure_objects = []
    axes_object_matrices = []

    for field_index, field_name in enumerate(field_names):
        field_matrix = numpy.flip(full_radar_matrix[..., field_index], axis=0)

        if allow_whitespace:
            # Let the radar-plotting routine create its own figure.
            figure_object = None
            axes_object_matrix = None
        else:
            # Pre-build a tightly packed figure (no space between panels).
            figure_object, axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=num_rows, num_columns=num_columns,
                    horizontal_spacing=0., vertical_spacing=0.,
                    shared_x_axis=False, shared_y_axis=False,
                    keep_aspect_ratio=True)
            )

        figure_object, axes_object_matrix = (
            radar_plotting.plot_3d_grid_without_coords(
                field_matrix=field_matrix, field_name=field_name,
                grid_point_heights_metres=heights_m_agl,
                ground_relative=True, num_panel_rows=num_rows,
                figure_object=figure_object,
                axes_object_matrix=axes_object_matrix,
                font_size=FONT_SIZE_SANS_COLOUR_BARS)
        )

        figure_objects.append(figure_object)
        axes_object_matrices.append(axes_object_matrix)

        if not allow_whitespace:
            continue

        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(field_name)
        )

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=field_matrix, colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal', extend_min=True, extend_max=True)

        if title_string is not None:
            pyplot.suptitle('{0:s}; {1:s}'.format(title_string, field_name),
                            fontsize=TITLE_FONT_SIZE)

    return figure_objects, axes_object_matrices
示例#13
0
def _plot_one_example_one_time(
        storm_object_table, full_id_string, valid_time_unix_sec,
        tornado_table, top_myrorss_dir_name, radar_field_name,
        radar_height_m_asl, latitude_limits_deg, longitude_limits_deg):
    """Plots one example with surrounding context at one time.

    On a new equidistant-cylindrical map, this draws:

    - one MYRORSS radar field, read from the file nearest to the valid time;
    - storm outlines, styled by relationship to the storm of interest:
      different primary ID -> dashed, width `AUXILIARY_STORM_WIDTH`;
      same primary ID but different secondary ID -> solid, width
      `AUXILIARY_STORM_WIDTH`, labelled "P" at the centroid;
      same secondary ID (storm of interest) -> solid, width
      `MAIN_STORM_WIDTH`;
    - the forecast probability for the storm of interest, as the axes title
      (only if `FORECAST_PROBABILITY_COLUMN` is present);
    - every tornado in `tornado_table`, via `_plot_one_tornado`.

    :param storm_object_table: pandas DataFrame with storm objects at one
        time.  May include storms whose primary ID differs from that of the
        storm of interest (these are plotted as unrelated storms).  Columns are
        documented in `storm_tracking_io.write_file`; may also contain
        `FORECAST_PROBABILITY_COLUMN`.
    :param full_id_string: Full ID of storm of interest.
    :param valid_time_unix_sec: Valid time.
    :param tornado_table: pandas DataFrame created by
        `linkage._read_input_tornado_reports`.
    :param top_myrorss_dir_name: See documentation at top of file.
    :param radar_field_name: Same.
    :param radar_height_m_asl: Same.
    :param latitude_limits_deg: See doc for `_get_plotting_limits`.
    :param longitude_limits_deg: Same.
    :raises: whatever `myrorss_and_mrms_io.find_raw_file_inexact_time` raises
        when no radar file is found (called with
        `raise_error_if_missing=True`).
    """

    min_plot_latitude_deg = latitude_limits_deg[0]
    max_plot_latitude_deg = latitude_limits_deg[1]
    min_plot_longitude_deg = longitude_limits_deg[0]
    max_plot_longitude_deg = longitude_limits_deg[1]

    radar_file_name = myrorss_and_mrms_io.find_raw_file_inexact_time(
        top_directory_name=top_myrorss_dir_name,
        desired_time_unix_sec=valid_time_unix_sec,
        spc_date_string=time_conversion.time_to_spc_date_string(
            valid_time_unix_sec),
        data_source=radar_utils.MYRORSS_SOURCE_ID,
        field_name=radar_field_name, height_m_asl=radar_height_m_asl,
        max_time_offset_sec=
        myrorss_and_mrms_io.DEFAULT_MAX_TIME_OFFSET_FOR_NON_SHEAR_SEC,
        raise_error_if_missing=True)

    print('Reading data from: "{0:s}"...'.format(radar_file_name))

    radar_metadata_dict = myrorss_and_mrms_io.read_metadata_from_raw_file(
        netcdf_file_name=radar_file_name,
        data_source=radar_utils.MYRORSS_SOURCE_ID)

    sparse_grid_table = (
        myrorss_and_mrms_io.read_data_from_sparse_grid_file(
            netcdf_file_name=radar_file_name,
            field_name_orig=radar_metadata_dict[
                myrorss_and_mrms_io.FIELD_NAME_COLUMN_ORIG],
            data_source=radar_utils.MYRORSS_SOURCE_ID,
            sentinel_values=radar_metadata_dict[
                radar_utils.SENTINEL_VALUE_COLUMN]
        )
    )

    radar_matrix, grid_point_latitudes_deg, grid_point_longitudes_deg = (
        radar_s2f.sparse_to_full_grid(
            sparse_grid_table=sparse_grid_table,
            metadata_dict=radar_metadata_dict)
    )

    # Reverse the latitude axis of the data and of the coordinate vector
    # together, so that rows of `radar_matrix` still match
    # `grid_point_latitudes_deg`.
    radar_matrix = numpy.flip(radar_matrix, axis=0)
    grid_point_latitudes_deg = grid_point_latitudes_deg[::-1]

    # `[1:]` drops the first returned element (presumably the figure handle).
    axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=min_plot_latitude_deg,
            max_latitude_deg=max_plot_latitude_deg,
            min_longitude_deg=min_plot_longitude_deg,
            max_longitude_deg=max_plot_longitude_deg, resolution_string='h'
        )[1:]
    )

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=plotting_utils.DEFAULT_COUNTRY_COLOUR)

    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object)

    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object)

    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS, line_width=0)

    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS, line_width=0)

    radar_plotting.plot_latlng_grid(
        field_matrix=radar_matrix, field_name=radar_field_name,
        axes_object=axes_object,
        min_grid_point_latitude_deg=numpy.min(grid_point_latitudes_deg),
        min_grid_point_longitude_deg=numpy.min(grid_point_longitudes_deg),
        latitude_spacing_deg=numpy.diff(grid_point_latitudes_deg[:2])[0],
        longitude_spacing_deg=numpy.diff(grid_point_longitudes_deg[:2])[0]
    )

    colour_map_object, colour_norm_object = (
        radar_plotting.get_default_colour_scheme(radar_field_name)
    )

    plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=radar_matrix,
        colour_map_object=colour_map_object,
        colour_norm_object=colour_norm_object, orientation_string='horizontal',
        padding=0.05, extend_min=False, extend_max=True,
        fraction_of_axis_length=0.8)

    first_list, second_list = temporal_tracking.full_to_partial_ids(
        [full_id_string]
    )
    primary_id_string = first_list[0]
    secondary_id_string = second_list[0]

    # Plot outlines of unrelated storms (with different primary IDs).
    this_storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] !=
        primary_id_string
    ]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table, axes_object=axes_object,
        basemap_object=basemap_object, line_width=AUXILIARY_STORM_WIDTH,
        line_colour='k', line_style='dashed')

    # Plot outlines of related storms (same primary ID, other secondary ID).
    this_storm_object_table = storm_object_table.loc[
        (storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] ==
         primary_id_string) &
        (storm_object_table[tracking_utils.SECONDARY_ID_COLUMN] !=
         secondary_id_string)
    ]

    if len(this_storm_object_table.index) > 0:
        storm_plotting.plot_storm_outlines(
            storm_object_table=this_storm_object_table, axes_object=axes_object,
            basemap_object=basemap_object, line_width=AUXILIARY_STORM_WIDTH,
            line_colour='k', line_style='solid'
        )

        centroid_longitudes_deg = this_storm_object_table[
            tracking_utils.CENTROID_LONGITUDE_COLUMN
        ].values
        centroid_latitudes_deg = this_storm_object_table[
            tracking_utils.CENTROID_LATITUDE_COLUMN
        ].values

        # NOTE(review): labels are placed at raw (lon, lat), whereas outlines
        # go through `basemap_object`.  This is only consistent if the map's
        # projection coordinates are degrees -- confirm.
        for this_longitude_deg, this_latitude_deg in zip(
                centroid_longitudes_deg, centroid_latitudes_deg):
            axes_object.text(
                this_longitude_deg, this_latitude_deg, 'P',
                fontsize=MAIN_FONT_SIZE, color=FONT_COLOUR, fontweight='bold',
                horizontalalignment='center', verticalalignment='center'
            )

    # Plot outline of storm of interest (same secondary ID).
    this_storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.SECONDARY_ID_COLUMN] ==
        secondary_id_string
    ]

    storm_plotting.plot_storm_outlines(
        storm_object_table=this_storm_object_table, axes_object=axes_object,
        basemap_object=basemap_object, line_width=MAIN_STORM_WIDTH,
        line_colour='k', line_style='solid')

    plot_forecast = (
        len(this_storm_object_table.index) > 0 and
        FORECAST_PROBABILITY_COLUMN in list(this_storm_object_table)
    )

    if plot_forecast:
        title_string = 'Prob = {0:.3f} at {1:s}'.format(
            this_storm_object_table[FORECAST_PROBABILITY_COLUMN].values[0],
            time_conversion.unix_sec_to_string(
                valid_time_unix_sec, TORNADO_TIME_FORMAT)
        )

        axes_object.set_title(title_string, fontsize=TITLE_FONT_SIZE)

    # Plot each tornado separately, with its reports in chronological order.
    tornado_id_strings = tornado_table[tornado_io.TORNADO_ID_COLUMN].values

    for this_tornado_id_string in numpy.unique(tornado_id_strings):
        these_rows = numpy.where(
            tornado_id_strings == this_tornado_id_string
        )[0]

        this_tornado_table = tornado_table.iloc[these_rows].sort_values(
            linkage.EVENT_TIME_COLUMN, axis=0, ascending=True, inplace=False
        )
        _plot_one_tornado(
            tornado_table=this_tornado_table, axes_object=axes_object
        )
示例#14
0
def _plot_one_score(score_matrix, colour_map_object, min_colour_value,
                    max_colour_value, colour_bar_label, is_score_bias,
                    best_model_index, output_file_name):
    """Plots one score over the hyperparameter grid and saves the figure.

    The figure has one panel per (dense-layer count, data-augmentation flag)
    pair, all stacked in a single column.  In each panel, x is the L2-weight
    index and y is the dropout-rate index.  The best model is marked with
    `BEST_MODEL_MARKER_TYPE`; every model with a NaN score is marked with
    `CORRUPT_MODEL_MARKER_TYPE`.

    :param score_matrix: 4-D numpy array of scores, where the first axis
        represents dropout rate; second represents L2 weight; third represents
        num dense layers; and fourth is data augmentation (yes or no).
    :param colour_map_object: See documentation at top of file.  Replaced by
        the bias colour map when `is_score_bias` is True.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param colour_bar_label: Label string for colour bar.
    :param is_score_bias: Boolean flag.  If True, score to be plotted is
        frequency bias, and the colour scheme comes from
        `_get_bias_colour_scheme`.
    :param best_model_index: Linear index (into the flattened
        `score_matrix`) of best model.
    :param output_file_name: Path to output file (figure will be saved here).
    """

    colour_norm_object = None
    if is_score_bias:
        colour_map_object, colour_norm_object = _get_bias_colour_scheme(
            max_value=max_colour_value)

    num_layer_counts = len(DENSE_LAYER_COUNTS)
    num_aug_flags = len(DATA_AUGMENTATION_FLAGS)

    figure_object, panel_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_layer_counts * num_aug_flags, num_columns=1,
        horizontal_spacing=0.15, vertical_spacing=0.15,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True)

    # Panels were created as a single column; index them by
    # (dense-layer-count index, data-augmentation index).
    panel_matrix = numpy.reshape(panel_matrix,
                                 (num_layer_counts, num_aug_flags))

    x_tick_labels = [
        '{0:.1f}'.format(this_log_weight)
        for this_log_weight in numpy.log10(L2_WEIGHTS)
    ]
    y_tick_labels = [
        '{0:.3f}'.format(this_rate) for this_rate in DROPOUT_RATES
    ]

    best_indices = numpy.unravel_index(best_model_index, score_matrix.shape)

    for k, m in numpy.ndindex(num_layer_counts, num_aug_flags):
        this_axes_object = panel_matrix[k, m]

        model_eval.plot_hyperparam_grid(
            score_matrix=score_matrix[..., k, m],
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            axes_object=this_axes_object)

        this_axes_object.set_xticklabels(
            x_tick_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
        this_axes_object.set_yticklabels(
            y_tick_labels, fontsize=TICK_LABEL_FONT_SIZE)
        this_axes_object.set_ylabel('Dropout rate',
                                    fontsize=TICK_LABEL_FONT_SIZE)

        # Only the last (bottom) panel keeps its x-axis ticks and label.
        is_last_panel = (
            k == num_layer_counts - 1 and m == num_aug_flags - 1
        )
        if is_last_panel:
            this_axes_object.set_xlabel(r'L$_2$ weight (log$_{10}$)')
        else:
            this_axes_object.set_xticks([], [])

        plural_suffix = 's' if DENSE_LAYER_COUNTS[k] > 1 else ''
        aug_string = 'on' if DATA_AUGMENTATION_FLAGS[m] else 'off'
        this_axes_object.set_title(
            '{0:d} dense layer{1:s}, DA {2:s}'.format(
                DENSE_LAYER_COUNTS[k], plural_suffix, aug_string))

    def _mark_model(dropout_index, l2_index, layer_index, aug_index,
                    marker_type, marker_size, marker_width):
        # Marks one grid cell: x = L2-weight index, y = dropout-rate index.
        panel_matrix[layer_index, aug_index].plot(
            l2_index, dropout_index, linestyle='None', marker=marker_type,
            markersize=marker_size, markerfacecolor=MARKER_COLOUR,
            markeredgecolor=MARKER_COLOUR, markeredgewidth=marker_width)

    best_i, best_j, best_k, best_m = best_indices[:4]
    _mark_model(best_i, best_j, best_k, best_m,
                marker_type=BEST_MODEL_MARKER_TYPE,
                marker_size=BEST_MODEL_MARKER_SIZE,
                marker_width=BEST_MODEL_MARKER_WIDTH)

    # Models with a NaN score are treated as corrupt.
    corrupt_index_arrays = numpy.unravel_index(
        numpy.flatnonzero(numpy.isnan(score_matrix)), score_matrix.shape)

    for this_i, this_j, this_k, this_m in zip(*corrupt_index_arrays):
        _mark_model(this_i, this_j, this_k, this_m,
                    marker_type=CORRUPT_MODEL_MARKER_TYPE,
                    marker_size=CORRUPT_MODEL_MARKER_SIZE,
                    marker_width=CORRUPT_MODEL_MARKER_WIDTH)

    if not is_score_bias:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=panel_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            min_value=min_colour_value,
            max_value=max_colour_value,
            orientation_string='vertical',
            extend_min=True,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)
    else:
        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=panel_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='vertical',
            extend_min=False,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)

        # Re-label the existing ticks with one decimal place.
        tick_values = colour_bar_object.get_ticks()
        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(
            ['{0:.1f}'.format(this_value) for this_value in tick_values])

    colour_bar_object.set_label(colour_bar_label)
    print('Saving figure to: "{0:s}"...'.format(output_file_name))

    figure_object.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0, bbox_inches='tight')
    pyplot.close(figure_object)
示例#15
0
def plot_performance_diagram(axes_object,
                             pod_by_threshold,
                             success_ratio_by_threshold,
                             line_colour=PERF_DIAGRAM_COLOUR,
                             plot_background=True):
    """Plots performance diagram.

    T = number of binarization thresholds

    For the definition of a "binarization threshold" and the role they play in
    performance diagrams, see
    `model_evaluation.get_points_in_performance_diagram`.

    All drawing (background contours, contour labels, curve, axis labels and
    limits) is done on `axes_object`, not on the pyplot "current" axes.

    :param axes_object: Instance of `matplotlib.axes._subplots.AxesSubplot`.
    :param pod_by_threshold: length-T numpy array of POD (probability of
        detection) values.  NaN allowed.
    :param success_ratio_by_threshold: length-T numpy array of success ratios.
        NaN allowed.
    :param line_colour: Line colour.
    :param plot_background: Boolean flag.  If True, will plot background
        (filled CSI contours with colour bar, and labelled dashed
        frequency-bias contours).
    :return: line_handle: Line handle for the performance-diagram curve, or
        None if every threshold has NaN POD or success ratio (then no curve is
        drawn).  Thresholds with NaN in either value are omitted from the
        curve.
    """

    error_checking.assert_is_numpy_array(pod_by_threshold, num_dimensions=1)
    error_checking.assert_is_geq_numpy_array(pod_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pod_by_threshold,
                                             1.,
                                             allow_nan=True)

    num_thresholds = len(pod_by_threshold)
    expected_dim = numpy.array([num_thresholds], dtype=int)

    error_checking.assert_is_numpy_array(success_ratio_by_threshold,
                                         exact_dimensions=expected_dim)
    error_checking.assert_is_geq_numpy_array(success_ratio_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(success_ratio_by_threshold,
                                             1.,
                                             allow_nan=True)

    error_checking.assert_is_boolean(plot_background)

    if plot_background:
        success_ratio_matrix, pod_matrix = model_eval.get_sr_pod_grid()
        csi_matrix = model_eval.csi_from_sr_and_pod(
            success_ratio_array=success_ratio_matrix, pod_array=pod_matrix)
        frequency_bias_matrix = model_eval.frequency_bias_from_sr_and_pod(
            success_ratio_array=success_ratio_matrix, pod_array=pod_matrix)

        this_colour_map_object, this_colour_norm_object = (
            _get_csi_colour_scheme())

        # Call the Axes method directly.  `pyplot.contourf` draws on
        # `pyplot.gca()`, and an `axes=` keyword does not redirect it, so the
        # background would land on the wrong axes whenever `axes_object` is
        # not the current axes.
        axes_object.contourf(success_ratio_matrix,
                             pod_matrix,
                             csi_matrix,
                             CSI_LEVELS,
                             cmap=this_colour_map_object,
                             norm=this_colour_norm_object,
                             vmin=0.,
                             vmax=1.)

        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=csi_matrix,
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='vertical',
            extend_min=False,
            extend_max=False,
            fraction_of_axis_length=0.8)

        colour_bar_object.set_label('CSI (critical success index)')

        # One (identical) colour per frequency-bias contour level.
        bias_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
            FREQ_BIAS_COLOUR)
        bias_colours_2d_tuple = (bias_colour_tuple,) * len(FREQ_BIAS_LEVELS)

        bias_contour_object = axes_object.contour(success_ratio_matrix,
                                                  pod_matrix,
                                                  frequency_bias_matrix,
                                                  FREQ_BIAS_LEVELS,
                                                  colors=bias_colours_2d_tuple,
                                                  linewidths=FREQ_BIAS_WIDTH,
                                                  linestyles='dashed')

        axes_object.clabel(bias_contour_object,
                           inline=True,
                           inline_spacing=FREQ_BIAS_PADDING,
                           fmt=FREQ_BIAS_STRING_FORMAT,
                           fontsize=FONT_SIZE)

    # Only thresholds where both coordinates are finite can be plotted.
    real_flags = numpy.invert(numpy.logical_or(
        numpy.isnan(success_ratio_by_threshold),
        numpy.isnan(pod_by_threshold)))

    if numpy.any(real_flags):
        line_handle = axes_object.plot(
            success_ratio_by_threshold[real_flags],
            pod_by_threshold[real_flags],
            color=plotting_utils.colour_from_numpy_to_tuple(line_colour),
            linestyle='solid',
            linewidth=PERF_DIAGRAM_WIDTH)[0]
    else:
        line_handle = None

    axes_object.set_xlabel('Success ratio (1 - FAR)')
    axes_object.set_ylabel('POD (probability of detection)')
    axes_object.set_xlim(0., 1.)
    axes_object.set_ylim(0., 1.)

    return line_handle
def _plot_storm_outlines_one_time(storm_object_table,
                                  valid_time_unix_sec,
                                  warning_table,
                                  axes_object,
                                  basemap_object,
                                  storm_outline_colour,
                                  storm_outline_opacity,
                                  include_secondary_ids,
                                  output_dir_name,
                                  primary_id_to_track_colour=None,
                                  radar_matrix=None,
                                  radar_field_name=None,
                                  radar_latitudes_deg=None,
                                  radar_longitudes_deg=None,
                                  radar_colour_map_object=None):
    """Plots storm outlines (and may underlay radar data) at one time step.

    M = number of rows in radar grid
    N = number of columns in radar grid
    K = number of storm objects

    If `primary_id_to_track_colour is None`, all storm tracks will be the same
    colour.

    :param storm_object_table: See doc for `storm_plotting.plot_storm_outlines`.
    :param valid_time_unix_sec: Will plot storm outlines only at this time.
        Will plot tracks up to and including this time.
    :param warning_table: None or a pandas table with the following columns.
    warning_table.start_time_unix_sec: Start time.
    warning_table.end_time_unix_sec: End time.
    warning_table.polygon_object_latlng: Polygon (instance of
        `shapely.geometry.Polygon`) with lat-long coordinates of warning
        boundary.

    :param axes_object: See doc for `storm_plotting.plot_storm_outlines`.
    :param basemap_object: Same.
    :param storm_outline_colour: Same.
    :param storm_outline_opacity: Same.
    :param include_secondary_ids: Same.
    :param output_dir_name: See documentation at top of file.
    :param primary_id_to_track_colour: Dictionary created by
        `_assign_colours_to_storms`.  If this is None, all storm tracks will be
        the same colour.
    :param radar_matrix: M-by-N numpy array of radar values.  If
        `radar_matrix is None`, radar data will simply not be plotted.
    :param radar_field_name: [used only if `radar_matrix is not None`]
        See documentation at top of file.
    :param radar_latitudes_deg: [used only if `radar_matrix is not None`]
        length-M numpy array of grid-point latitudes (deg N).
    :param radar_longitudes_deg: [used only if `radar_matrix is not None`]
        length-N numpy array of grid-point longitudes (deg E).
    :param radar_colour_map_object: [used only if `radar_matrix is not None`]
        Colour map (instance of `matplotlib.pyplot.cm`).  If None, will use
        default for the given field.
    """

    # plot_storm_ids = radar_matrix is None or radar_colour_map_object is None
    plot_storm_ids = False

    min_plot_latitude_deg = basemap_object.llcrnrlat
    max_plot_latitude_deg = basemap_object.urcrnrlat
    min_plot_longitude_deg = basemap_object.llcrnrlon
    max_plot_longitude_deg = basemap_object.urcrnrlon

    plotting_utils.plot_coastlines(basemap_object=basemap_object,
                                   axes_object=axes_object,
                                   line_colour=BORDER_COLOUR)
    plotting_utils.plot_countries(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  line_colour=BORDER_COLOUR)
    plotting_utils.plot_states_and_provinces(basemap_object=basemap_object,
                                             axes_object=axes_object,
                                             line_colour=BORDER_COLOUR)
    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    if radar_matrix is not None:
        custom_colour_map = radar_colour_map_object is not None

        good_indices = numpy.where(
            numpy.logical_and(radar_latitudes_deg >= min_plot_latitude_deg,
                              radar_latitudes_deg <= max_plot_latitude_deg))[0]

        radar_latitudes_deg = radar_latitudes_deg[good_indices]
        radar_matrix = radar_matrix[good_indices, :]

        good_indices = numpy.where(
            numpy.logical_and(
                radar_longitudes_deg >= min_plot_longitude_deg,
                radar_longitudes_deg <= max_plot_longitude_deg))[0]

        radar_longitudes_deg = radar_longitudes_deg[good_indices]
        radar_matrix = radar_matrix[:, good_indices]

        latitude_spacing_deg = radar_latitudes_deg[1] - radar_latitudes_deg[0]
        longitude_spacing_deg = (radar_longitudes_deg[1] -
                                 radar_longitudes_deg[0])

        if radar_colour_map_object is None:
            colour_map_object, colour_norm_object = (
                radar_plotting.get_default_colour_scheme(radar_field_name))
        else:
            colour_map_object = radar_colour_map_object
            colour_norm_object = radar_plotting.get_default_colour_scheme(
                radar_field_name)[-1]

            this_ratio = radar_plotting._field_to_plotting_units(
                field_matrix=1., field_name=radar_field_name)

            colour_norm_object = pyplot.Normalize(
                colour_norm_object.vmin / this_ratio,
                colour_norm_object.vmax / this_ratio)

        radar_plotting.plot_latlng_grid(
            field_matrix=radar_matrix,
            field_name=radar_field_name,
            axes_object=axes_object,
            min_grid_point_latitude_deg=numpy.min(radar_latitudes_deg),
            min_grid_point_longitude_deg=numpy.min(radar_longitudes_deg),
            latitude_spacing_deg=latitude_spacing_deg,
            longitude_spacing_deg=longitude_spacing_deg,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object)

        latitude_range_deg = max_plot_latitude_deg - min_plot_latitude_deg
        longitude_range_deg = max_plot_longitude_deg - min_plot_longitude_deg

        if latitude_range_deg > longitude_range_deg:
            orientation_string = 'vertical'
        else:
            orientation_string = 'horizontal'

        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=radar_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string=orientation_string,
            padding=0.05,
            extend_min=radar_field_name in radar_plotting.SHEAR_VORT_DIV_NAMES,
            extend_max=True,
            fraction_of_axis_length=1.)

        radar_field_name_verbose = radar_utils.field_name_to_verbose(
            field_name=radar_field_name, include_units=True)
        radar_field_name_verbose = radar_field_name_verbose.replace(
            'm ASL', 'kft ASL')
        colour_bar_object.set_label(radar_field_name_verbose)

        if custom_colour_map:
            tick_values = colour_bar_object.get_ticks()
            tick_label_strings = ['{0:.1f}'.format(v) for v in tick_values]
            colour_bar_object.set_ticks(tick_values)
            colour_bar_object.set_ticklabels(tick_label_strings)

    valid_time_rows = numpy.where(storm_object_table[
        tracking_utils.VALID_TIME_COLUMN].values == valid_time_unix_sec)[0]

    this_colour = matplotlib.colors.to_rgba(storm_outline_colour,
                                            storm_outline_opacity)

    storm_plotting.plot_storm_outlines(
        storm_object_table=storm_object_table.iloc[valid_time_rows],
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_colour=this_colour)

    if plot_storm_ids:
        storm_plotting.plot_storm_ids(
            storm_object_table=storm_object_table.iloc[valid_time_rows],
            axes_object=axes_object,
            basemap_object=basemap_object,
            plot_near_centroids=False,
            include_secondary_ids=include_secondary_ids,
            font_colour=storm_plotting.DEFAULT_FONT_COLOUR)

    if warning_table is not None:
        warning_indices = numpy.where(
            numpy.logical_and(
                warning_table[WARNING_START_TIME_KEY].values <=
                valid_time_unix_sec, warning_table[WARNING_END_TIME_KEY].values
                >= valid_time_unix_sec))[0]

        for k in warning_indices:
            this_vertex_dict = polygons.polygon_object_to_vertex_arrays(
                warning_table[WARNING_LATLNG_POLYGON_KEY].values[k])
            these_latitudes_deg = this_vertex_dict[polygons.EXTERIOR_Y_COLUMN]
            these_longitudes_deg = this_vertex_dict[polygons.EXTERIOR_X_COLUMN]

            these_latitude_flags = numpy.logical_and(
                these_latitudes_deg >= min_plot_latitude_deg,
                these_latitudes_deg <= max_plot_latitude_deg)
            these_longitude_flags = numpy.logical_and(
                these_longitudes_deg >= min_plot_longitude_deg,
                these_longitudes_deg <= max_plot_longitude_deg)
            these_coord_flags = numpy.logical_and(these_latitude_flags,
                                                  these_longitude_flags)

            if not numpy.any(these_coord_flags):
                continue

            these_x_metres, these_y_metres = basemap_object(
                these_longitudes_deg, these_latitudes_deg)
            axes_object.plot(these_x_metres,
                             these_y_metres,
                             color=this_colour,
                             linestyle='dashed',
                             linewidth=storm_plotting.DEFAULT_POLYGON_WIDTH)

            axes_object.text(numpy.mean(these_x_metres),
                             numpy.mean(these_y_metres),
                             'W{0:d}'.format(k),
                             fontsize=storm_plotting.DEFAULT_FONT_SIZE,
                             fontweight='bold',
                             color=this_colour,
                             horizontalalignment='center',
                             verticalalignment='center')

            these_sec_id_strings = (
                warning_table[LINKED_SECONDARY_IDS_KEY].values[k])
            if len(these_sec_id_strings) == 0:
                continue

            these_object_indices = numpy.array([], dtype=int)

            for this_sec_id_string in these_sec_id_strings:
                these_subindices = numpy.where(
                    storm_object_table[tracking_utils.SECONDARY_ID_COLUMN].
                    values[valid_time_rows] == this_sec_id_string)[0]

                these_object_indices = numpy.concatenate(
                    (these_object_indices, valid_time_rows[these_subindices]))

            for i in these_object_indices:
                this_vertex_dict = polygons.polygon_object_to_vertex_arrays(
                    storm_object_table[
                        tracking_utils.LATLNG_POLYGON_COLUMN].values[i])

                these_x_metres, these_y_metres = basemap_object(
                    this_vertex_dict[polygons.EXTERIOR_X_COLUMN],
                    this_vertex_dict[polygons.EXTERIOR_Y_COLUMN])

                axes_object.text(numpy.mean(these_x_metres),
                                 numpy.mean(these_y_metres),
                                 'W{0:d}'.format(k),
                                 fontsize=storm_plotting.DEFAULT_FONT_SIZE,
                                 fontweight='bold',
                                 color=this_colour,
                                 horizontalalignment='center',
                                 verticalalignment='center')

    if primary_id_to_track_colour is None:
        storm_plotting.plot_storm_tracks(storm_object_table=storm_object_table,
                                         axes_object=axes_object,
                                         basemap_object=basemap_object,
                                         colour_map_object=None,
                                         constant_colour=DEFAULT_TRACK_COLOUR)
    else:
        for this_primary_id_string in primary_id_to_track_colour:
            this_storm_object_table = storm_object_table.loc[
                storm_object_table[tracking_utils.PRIMARY_ID_COLUMN] ==
                this_primary_id_string]

            if len(this_storm_object_table.index) == 0:
                continue

            storm_plotting.plot_storm_tracks(
                storm_object_table=this_storm_object_table,
                axes_object=axes_object,
                basemap_object=basemap_object,
                colour_map_object=None,
                constant_colour=primary_id_to_track_colour[
                    this_primary_id_string])

    nice_time_string = time_conversion.unix_sec_to_string(
        valid_time_unix_sec, NICE_TIME_FORMAT)

    abbrev_time_string = time_conversion.unix_sec_to_string(
        valid_time_unix_sec, FILE_NAME_TIME_FORMAT)

    pyplot.title('Storm objects at {0:s}'.format(nice_time_string))
    output_file_name = '{0:s}/storm_outlines_{1:s}.jpg'.format(
        output_dir_name, abbrev_time_string)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    pyplot.savefig(output_file_name,
                   dpi=FIGURE_RESOLUTION_DPI,
                   pad_inches=0,
                   bbox_inches='tight')
    pyplot.close()
示例#17
0
def _plot_3d_radar(training_option_dict,
                   output_dir_name,
                   pmm_flag,
                   diff_colour_map_object=None,
                   max_colour_percentile_for_diff=None,
                   full_id_strings=None,
                   storm_time_strings=None,
                   novel_radar_matrix=None,
                   novel_radar_matrix_upconv=None,
                   novel_radar_matrix_upconv_svd=None):
    """Plots results of novelty detection for 3-D radar fields.

    E = number of examples (storm objects)
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of fields

    Which fields are plotted depends on which matrices are given, checked in
    this order:

    - If `novel_radar_matrix` is given, will plot the original (not
      reconstructed) radar fields; the other two matrices are ignored.
    - Else, if both `novel_radar_matrix_upconv` and
      `novel_radar_matrix_upconv_svd` are given, will plot novelty fields
      (upconvnet reconstruction minus upconvnet/SVD reconstruction).
    - Else, if only `novel_radar_matrix_upconv` is given, will plot
      upconvnet-reconstructed fields.
    - Else, if only `novel_radar_matrix_upconv_svd` is given, will plot
      upconvnet-and-SVD-reconstructed fields.

    One figure (multi-panel, one panel per height) is saved for each example
    and each field.

    :param training_option_dict: See doc for `cnn.read_model_metadata`.  Radar
        field names and heights are read from it (keys
        `trainval_io.RADAR_FIELDS_KEY` and `trainval_io.RADAR_HEIGHTS_KEY`).
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :param pmm_flag: Boolean flag.  If True, the input matrices contain
        probability-matched means.
    :param diff_colour_map_object:
        [used only if both `novel_radar_matrix_upconv` and
        `novel_radar_matrix_upconv_svd` are given, and `novel_radar_matrix` is
        not]

        See documentation at top of file.

    :param max_colour_percentile_for_diff: Same.  For each example and field,
        this percentile of absolute values sets the limits of the symmetric
        colour scale.
    :param full_id_strings: [optional and used only if `pmm_flag = False`]
        length-E list of full storm IDs.
    :param storm_time_strings: [optional and used only if `pmm_flag = False`]
        length-E list of storm times.
    :param novel_radar_matrix: E-by-M-by-N-by-H-by-F numpy array of original
        (not reconstructed) radar fields.
    :param novel_radar_matrix_upconv: E-by-M-by-N-by-H-by-F numpy array of
        upconvnet-reconstructed radar fields.
    :param novel_radar_matrix_upconv_svd: E-by-M-by-N-by-H-by-F numpy array of
        upconvnet-and-SVD-reconstructed radar fields.
    :raises: ValueError: if none of the three radar matrices is given.
    """

    if pmm_flag:
        have_storm_ids = False
    else:
        have_storm_ids = not (full_id_strings is None
                              or storm_time_strings is None)

    plot_difference = False

    if novel_radar_matrix is not None:
        plot_type_abbrev = 'actual'
        plot_type_verbose = 'actual'
        radar_matrix_to_plot = novel_radar_matrix
    elif (novel_radar_matrix_upconv is not None
          and novel_radar_matrix_upconv_svd is not None):
        plot_difference = True
        plot_type_abbrev = 'novelty'
        plot_type_verbose = 'novelty'
        radar_matrix_to_plot = (novel_radar_matrix_upconv -
                                novel_radar_matrix_upconv_svd)
    elif novel_radar_matrix_upconv is not None:
        plot_type_abbrev = 'upconv'
        plot_type_verbose = 'upconvnet reconstruction'
        radar_matrix_to_plot = novel_radar_matrix_upconv
    elif novel_radar_matrix_upconv_svd is not None:
        plot_type_abbrev = 'upconv-svd'
        plot_type_verbose = 'upconvnet/SVD reconstruction'
        radar_matrix_to_plot = novel_radar_matrix_upconv_svd
    else:
        raise ValueError(
            'At least one of `novel_radar_matrix`, `novel_radar_matrix_upconv`,'
            ' and `novel_radar_matrix_upconv_svd` must be given.')

    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    radar_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    # Dimensions come from the matrix actually being plotted:
    # `novel_radar_matrix` is None whenever only reconstructions are given.
    num_storms = radar_matrix_to_plot.shape[0]
    num_heights = radar_matrix_to_plot.shape[-2]
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_heights)))

    for i in range(num_storms):
        if pmm_flag:
            this_title_string = 'Probability-matched mean'
            this_base_file_name = 'pmm'
        elif have_storm_ids:
            this_title_string = 'Storm "{0:s}" at {1:s}'.format(
                full_id_strings[i], storm_time_strings[i])

            this_base_file_name = '{0:s}_{1:s}'.format(
                full_id_strings[i].replace('_', '-'),
                storm_time_strings[i])
        else:
            this_title_string = 'Example {0:d}'.format(i + 1)
            this_base_file_name = 'example{0:06d}'.format(i)

        this_title_string += ' ({0:s})'.format(plot_type_verbose)

        for j, this_field_name in enumerate(radar_field_names):
            this_field_matrix = radar_matrix_to_plot[i, ..., j]

            this_file_name = '{0:s}/{1:s}_{2:s}_{3:s}.jpg'.format(
                output_dir_name, this_base_file_name, plot_type_abbrev,
                this_field_name.replace('_', '-'))

            if plot_difference:
                # Differences can be either sign, so use a colour scale that is
                # symmetric about zero.
                this_colour_map_object = diff_colour_map_object

                this_max_value = numpy.percentile(
                    numpy.absolute(this_field_matrix),
                    max_colour_percentile_for_diff)

                this_colour_norm_object = matplotlib.colors.Normalize(
                    vmin=-1 * this_max_value, vmax=this_max_value, clip=False)
            else:
                this_colour_map_object, this_colour_norm_object = (
                    radar_plotting.get_default_colour_scheme(this_field_name))

            # NOTE(review): rows are flipped before plotting, presumably so the
            # first grid row appears at the bottom -- confirm against
            # `plot_3d_grid_without_coords`.
            _, this_axes_object_matrix = (
                radar_plotting.plot_3d_grid_without_coords(
                    field_matrix=numpy.flip(this_field_matrix, axis=0),
                    field_name=this_field_name,
                    grid_point_heights_metres=radar_heights_m_agl,
                    ground_relative=True,
                    num_panel_rows=num_panel_rows,
                    font_size=FONT_SIZE_SANS_COLOUR_BARS,
                    colour_map_object=this_colour_map_object,
                    colour_norm_object=this_colour_norm_object))

            plotting_utils.plot_colour_bar(
                axes_object_or_matrix=this_axes_object_matrix,
                data_matrix=this_field_matrix,
                colour_map_object=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation_string='horizontal',
                extend_min=True,
                extend_max=True)

            pyplot.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)
            print('Saving figure to: "{0:s}"...'.format(this_file_name))
            pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
            pyplot.close()
def _plot_storm_outlines_one_time(storm_object_table,
                                  valid_time_unix_sec,
                                  axes_object,
                                  basemap_object,
                                  storm_colour,
                                  storm_opacity,
                                  include_secondary_ids,
                                  output_dir_name,
                                  radar_matrix=None,
                                  radar_field_name=None,
                                  radar_latitudes_deg=None,
                                  radar_longitudes_deg=None):
    """Plots storm outlines (and may underlay radar data) at one time step.

    M = number of rows in radar grid
    N = number of columns in radar grid

    Draws borders, parallels and meridians.  Optionally draws the radar grid,
    cropped to the map domain, with a colour bar.  Then draws outlines and IDs
    of storm objects valid at `valid_time_unix_sec`, and storm tracks.  The
    figure is saved as a JPEG in `output_dir_name`, closed, and trimmed of
    surrounding whitespace.

    :param storm_object_table: See doc for `storm_plotting.plot_storm_outlines`.
        Tracks are drawn for every row of this table.  Presumably the caller
        has already restricted it to times up to and including
        `valid_time_unix_sec` -- verify against callers.
    :param valid_time_unix_sec: Outlines and IDs are plotted only for storm
        objects valid at this time.  Also used in the title and file name.
    :param axes_object: See doc for `storm_plotting.plot_storm_outlines`.
    :param basemap_object: Same.  Its corner coordinates define the domain to
        which the radar grid is cropped.
    :param storm_colour: Colour of storm outlines (any format accepted by
        `matplotlib.colors.to_rgba`).
    :param storm_opacity: Opacity (alpha) of storm outlines.
    :param include_secondary_ids: See doc for `storm_plotting.plot_storm_ids`.
    :param output_dir_name: See documentation at top of file.
    :param radar_matrix: M-by-N numpy array of radar values.  If
        `radar_matrix is None`, radar data will simply not be plotted.
    :param radar_field_name: [used only if `radar_matrix is not None`]
        See documentation at top of file.
    :param radar_latitudes_deg: [used only if `radar_matrix is not None`]
        length-M numpy array of grid-point latitudes (deg N).
    :param radar_longitudes_deg: [used only if `radar_matrix is not None`]
        length-N numpy array of grid-point longitudes (deg E).
    """

    south_limit_deg = basemap_object.llcrnrlat
    north_limit_deg = basemap_object.urcrnrlat
    west_limit_deg = basemap_object.llcrnrlon
    east_limit_deg = basemap_object.urcrnrlon

    # All political/coastal borders share one colour.
    border_plotting_functions = (
        plotting_utils.plot_coastlines,
        plotting_utils.plot_countries,
        plotting_utils.plot_states_and_provinces
    )

    for plot_border_function in border_plotting_functions:
        plot_border_function(basemap_object=basemap_object,
                             axes_object=axes_object,
                             line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(basemap_object=basemap_object,
                                  axes_object=axes_object,
                                  num_meridians=NUM_MERIDIANS)

    if radar_matrix is not None:
        # Crop the grid to the map domain: rows by latitude, then columns by
        # longitude.
        row_flags = numpy.logical_and(radar_latitudes_deg >= south_limit_deg,
                                      radar_latitudes_deg <= north_limit_deg)
        radar_latitudes_deg = radar_latitudes_deg[row_flags]
        radar_matrix = radar_matrix[row_flags, :]

        column_flags = numpy.logical_and(radar_longitudes_deg >= west_limit_deg,
                                         radar_longitudes_deg <= east_limit_deg)
        radar_longitudes_deg = radar_longitudes_deg[column_flags]
        radar_matrix = radar_matrix[:, column_flags]

        # Grid spacing is taken from the first two points along each axis.
        lat_step_deg = radar_latitudes_deg[1] - radar_latitudes_deg[0]
        lng_step_deg = radar_longitudes_deg[1] - radar_longitudes_deg[0]

        radar_plotting.plot_latlng_grid(
            field_matrix=radar_matrix,
            field_name=radar_field_name,
            axes_object=axes_object,
            min_grid_point_latitude_deg=numpy.min(radar_latitudes_deg),
            min_grid_point_longitude_deg=numpy.min(radar_longitudes_deg),
            latitude_spacing_deg=lat_step_deg,
            longitude_spacing_deg=lng_step_deg)

        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(radar_field_name))

        # Put the colour bar along the longer side of the map.
        map_is_tall = (north_limit_deg - south_limit_deg >
                       east_limit_deg - west_limit_deg)
        bar_orientation = 'vertical' if map_is_tall else 'horizontal'

        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object,
            data_matrix=radar_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string=bar_orientation,
            extend_min=radar_field_name in radar_plotting.SHEAR_VORT_DIV_NAMES,
            extend_max=True,
            fraction_of_axis_length=0.9)

        colour_bar_object.set_label(
            radar_plotting.FIELD_NAME_TO_VERBOSE_DICT[radar_field_name])

    current_rows = numpy.where(
        storm_object_table[tracking_utils.VALID_TIME_COLUMN].values ==
        valid_time_unix_sec
    )[0]

    outline_rgba = matplotlib.colors.to_rgba(storm_colour, storm_opacity)

    storm_plotting.plot_storm_outlines(
        storm_object_table=storm_object_table.iloc[current_rows],
        axes_object=axes_object,
        basemap_object=basemap_object,
        line_colour=outline_rgba)

    storm_plotting.plot_storm_ids(
        storm_object_table=storm_object_table.iloc[current_rows],
        axes_object=axes_object,
        basemap_object=basemap_object,
        plot_near_centroids=False,
        include_secondary_ids=include_secondary_ids,
        font_colour=storm_plotting.DEFAULT_FONT_COLOUR)

    storm_plotting.plot_storm_tracks(storm_object_table=storm_object_table,
                                     axes_object=axes_object,
                                     basemap_object=basemap_object,
                                     colour_map_object=None,
                                     line_colour=TRACK_COLOUR)

    title_time_string = time_conversion.unix_sec_to_string(
        valid_time_unix_sec, NICE_TIME_FORMAT)
    file_time_string = time_conversion.unix_sec_to_string(
        valid_time_unix_sec, FILE_NAME_TIME_FORMAT)

    pyplot.title('Storm objects at {0:s}'.format(title_time_string))
    figure_file_name = '{0:s}/storm_outlines_{1:s}.jpg'.format(
        output_dir_name, file_time_string)

    print('Saving figure to: "{0:s}"...'.format(figure_file_name))
    pyplot.savefig(figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close()

    # Trim in place: the input and output file are the same.
    imagemagick_utils.trim_whitespace(input_file_name=figure_file_name,
                                      output_file_name=figure_file_name)
示例#19
0
def plot_performance_diagram(axes_object,
                             pod_by_threshold,
                             success_ratio_by_threshold,
                             line_colour=DEFAULT_PERFORMANCE_COLOUR,
                             line_width=DEFAULT_PERFORMANCE_WIDTH,
                             bias_line_colour=DEFAULT_FREQ_BIAS_COLOUR,
                             bias_line_width=DEFAULT_FREQ_BIAS_WIDTH):
    """Plots performance diagram.

    T = number of binarization thresholds

    For the definition of a "binarization threshold" and the role they play in
    performance diagrams, see
    `model_evaluation.get_points_in_performance_diagram`.

    The background consists of filled CSI contours (with a colour bar) and
    dashed, labelled frequency-bias contours.  The foreground is the curve of
    (success ratio, POD) points.  Thresholds where either value is NaN are
    skipped.  Everything is drawn on `axes_object`, with both axes limited to
    [0, 1].

    :param axes_object: Instance of `matplotlib.axes._subplots.AxesSubplot`.
    :param pod_by_threshold: length-T numpy array of POD (probability of
        detection) values.
    :param success_ratio_by_threshold: length-T numpy array of success ratios.
    :param line_colour: Colour (in any format accepted by `matplotlib.colors`).
    :param line_width: Line width (real positive number).
    :param bias_line_colour: Colour of contour lines for frequency bias.
    :param bias_line_width: Width of contour lines for frequency bias.
    """

    error_checking.assert_is_numpy_array(pod_by_threshold, num_dimensions=1)
    error_checking.assert_is_geq_numpy_array(pod_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(pod_by_threshold,
                                             1.,
                                             allow_nan=True)

    num_thresholds = len(pod_by_threshold)
    expected_dim = numpy.array([num_thresholds], dtype=int)

    error_checking.assert_is_numpy_array(success_ratio_by_threshold,
                                         exact_dimensions=expected_dim)
    error_checking.assert_is_geq_numpy_array(success_ratio_by_threshold,
                                             0.,
                                             allow_nan=True)
    error_checking.assert_is_leq_numpy_array(success_ratio_by_threshold,
                                             1.,
                                             allow_nan=True)

    success_ratio_matrix, pod_matrix = model_eval.get_sr_pod_grid()
    csi_matrix = model_eval.csi_from_sr_and_pod(success_ratio_matrix,
                                                pod_matrix)
    frequency_bias_matrix = model_eval.frequency_bias_from_sr_and_pod(
        success_ratio_matrix, pod_matrix)

    this_colour_map_object, this_colour_norm_object = _get_csi_colour_scheme()

    # Draw on `axes_object` explicitly.  The previous
    # `pyplot.contourf(..., axes=axes_object)` drew on the *current* axes,
    # because matplotlib's contour routines do not use an `axes` kwarg.
    axes_object.contourf(success_ratio_matrix,
                         pod_matrix,
                         csi_matrix,
                         LEVELS_FOR_CSI_CONTOURS,
                         cmap=this_colour_map_object,
                         norm=this_colour_norm_object,
                         vmin=0.,
                         vmax=1.)

    colour_bar_object = plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=csi_matrix,
        colour_map_object=this_colour_map_object,
        colour_norm_object=this_colour_norm_object,
        orientation_string='vertical',
        extend_min=False,
        extend_max=False)

    colour_bar_object.set_label('CSI (critical success index)')

    # One (identical) colour per frequency-bias contour level.
    bias_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
        bias_line_colour)
    bias_colours_2d_tuple = (
        (bias_colour_tuple, ) * len(LEVELS_FOR_FREQ_BIAS_CONTOURS))

    bias_contour_object = axes_object.contour(success_ratio_matrix,
                                              pod_matrix,
                                              frequency_bias_matrix,
                                              LEVELS_FOR_FREQ_BIAS_CONTOURS,
                                              colors=bias_colours_2d_tuple,
                                              linewidths=bias_line_width,
                                              linestyles='dashed')

    axes_object.clabel(bias_contour_object,
                       inline=True,
                       inline_spacing=PIXEL_PADDING_FOR_FREQ_BIAS_LABELS,
                       fmt=STRING_FORMAT_FOR_FREQ_BIAS_LABELS,
                       fontsize=FONT_SIZE)

    nan_flags = numpy.logical_or(numpy.isnan(success_ratio_by_threshold),
                                 numpy.isnan(pod_by_threshold))

    # Plot the curve only if at least one threshold has finite values.
    if not numpy.all(nan_flags):
        real_indices = numpy.where(numpy.invert(nan_flags))[0]

        axes_object.plot(
            success_ratio_by_threshold[real_indices],
            pod_by_threshold[real_indices],
            color=plotting_utils.colour_from_numpy_to_tuple(line_colour),
            linestyle='solid',
            linewidth=line_width)

    axes_object.set_xlabel('Success ratio (1 - FAR)')
    axes_object.set_ylabel('POD (probability of detection)')
    axes_object.set_xlim(0., 1.)
    axes_object.set_ylim(0., 1.)
def _plot_echo_tops(echo_top_matrix_km_asl, latitudes_deg, longitudes_deg,
                    plot_colour_bar, convective_flag_matrix=None):
    """Plots grid of 40-dBZ echo tops.

    M = number of rows in grid
    N = number of columns in grid

    Creates a new equidistant-cylindrical map spanning the grid.  Draws country
    and state/province borders plus parallels and meridians, then plots the
    echo-top grid.  The input matrix is not modified.

    :param echo_top_matrix_km_asl: M-by-N numpy array of echo tops (km above sea
        level).
    :param latitudes_deg: length-M numpy array of latitudes (deg N).  Spacing is
        taken from the first two values.
    :param longitudes_deg: length-N numpy array of longitudes (deg E).  Spacing
        is taken from the first two values.
    :param plot_colour_bar: Boolean flag.  If True, a horizontal colour bar is
        added below the map.
    :param convective_flag_matrix: M-by-N numpy array of Boolean flags,
        indicating which grid cells are convective.  If
        `convective_flag_matrix is None`, all grid cells will be plotted.  If
        `convective_flag_matrix is not None`, only convective grid cells will be
        plotted.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    :return: basemap_object: Basemap handle (instance of
        `mpl_toolkits.basemap.Basemap`).
    """

    figure_object, axes_object, basemap_object = (
        plotting_utils.create_equidist_cylindrical_map(
            min_latitude_deg=numpy.min(latitudes_deg),
            max_latitude_deg=numpy.max(latitudes_deg),
            min_longitude_deg=numpy.min(longitudes_deg),
            max_longitude_deg=numpy.max(longitudes_deg), resolution_string='h'
        )
    )

    # Coastlines are not drawn; only country and state/province borders are.
    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object
    )
    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object
    )
    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS, line_width=0
    )
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS, line_width=0
    )

    # Adding 0. makes a float copy, so NaN can be assigned without touching the
    # caller's array (and without failing on integer input).
    matrix_to_plot = echo_top_matrix_km_asl + 0.
    if convective_flag_matrix is not None:
        # Hide non-convective cells (NaN is not drawn).
        matrix_to_plot[numpy.logical_not(convective_flag_matrix)] = numpy.nan

    radar_plotting.plot_latlng_grid(
        field_matrix=matrix_to_plot, field_name=radar_utils.ECHO_TOP_40DBZ_NAME,
        axes_object=axes_object,
        min_grid_point_latitude_deg=numpy.min(latitudes_deg),
        min_grid_point_longitude_deg=numpy.min(longitudes_deg),
        latitude_spacing_deg=numpy.diff(latitudes_deg[:2])[0],
        longitude_spacing_deg=numpy.diff(longitudes_deg[:2])[0]
    )

    if not plot_colour_bar:
        return figure_object, axes_object, basemap_object

    colour_map_object, colour_norm_object = (
        radar_plotting.get_default_colour_scheme(
            radar_utils.ECHO_TOP_40DBZ_NAME)
    )

    colour_bar_object = plotting_utils.plot_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=matrix_to_plot,
        colour_map_object=colour_map_object,
        colour_norm_object=colour_norm_object, orientation_string='horizontal',
        extend_min=False, extend_max=True, fraction_of_axis_length=1.
    )

    # NOTE(review): label says kft although the input is in km; presumably
    # `radar_plotting` converts echo tops to kft for display -- confirm.
    colour_bar_object.set_label('40-dBZ echo top (kft ASL)')

    return figure_object, axes_object, basemap_object