def _plot_saliency_for_2d3d_radar(
        list_of_input_matrices, list_of_saliency_matrices,
        training_option_dict, saliency_colour_map_object,
        max_colour_value_by_example, output_dir_name, storm_ids=None,
        storm_times_unix_sec=None):
    """Plots saliency for 2-D azimuthal-shear and 3-D reflectivity fields.

    E = number of examples (storm objects)

    For each example, two figures are saved: one with reflectivity (one panel
    per height) and one with azimuthal shear (one panel per field).  Each has
    the corresponding saliency map overlain via
    `saliency_plotting.plot_many_2d_grids_with_pm_signs`, and a colour bar for
    the underlying radar field.

    If `storm_ids is None` and `storm_times_unix_sec is None`, will assume that
    the input matrices contain probability-matched means.

    :param list_of_input_matrices: See doc for
        `saliency_maps.read_standard_file`.  Element 0 is reflectivity (only
        index 0 of its last axis is plotted); element 1 is azimuthal shear.
    :param list_of_saliency_matrices: Same, but containing saliency values.
    :param training_option_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param saliency_colour_map_object: See documentation at top of file.
    :param max_colour_value_by_example: length-E numpy array with max value in
        colour scheme for each example.  Minimum value for [i]th example will be
        -1 * max_colour_value_by_example[i], since the colour scheme is
        zero-centered and divergent.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :param storm_ids: length-E list of storm IDs (strings).
    :param storm_times_unix_sec: length-E numpy array of storm times.
    """

    pmm_flag = storm_ids is None and storm_times_unix_sec is None

    reflectivity_matrix_dbz = list_of_input_matrices[0]
    reflectivity_saliency_matrix = list_of_saliency_matrices[0]
    az_shear_matrix_s01 = list_of_input_matrices[1]
    az_shear_saliency_matrix = list_of_saliency_matrices[1]

    num_examples = reflectivity_matrix_dbz.shape[0]
    num_reflectivity_heights = len(
        training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]
    )
    num_panel_rows_for_reflectivity = int(numpy.floor(
        numpy.sqrt(num_reflectivity_heights)
    ))

    az_shear_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    num_az_shear_fields = len(az_shear_field_names)

    # Per-panel colour bars are suppressed; one shared colour bar is added
    # for the whole figure below.
    plot_colour_bar_flags = numpy.full(num_az_shear_fields, False, dtype=bool)

    for i in range(num_examples):
        # Title and base file name are shared by both figures of this example.
        if pmm_flag:
            this_title_string = 'Probability-matched mean'
            this_base_file_name = '{0:s}/saliency_pmm'.format(output_dir_name)
        else:
            this_storm_time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)

            this_title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], this_storm_time_string)

            this_base_file_name = '{0:s}/saliency_{1:s}_{2:s}'.format(
                output_dir_name, storm_ids[i].replace('_', '-'),
                this_storm_time_string
            )

        this_title_string += ' (max absolute saliency = {0:.3f})'.format(
            max_colour_value_by_example[i])

        # Reflectivity figure: radar field, saliency overlay, colour bar.
        _, these_axes_objects = radar_plotting.plot_3d_grid_without_coords(
            field_matrix=numpy.flip(reflectivity_matrix_dbz[i, ..., 0], axis=0),
            field_name=radar_utils.REFL_NAME,
            grid_point_heights_metres=training_option_dict[
                trainval_io.RADAR_HEIGHTS_KEY],
            ground_relative=True,
            num_panel_rows=num_panel_rows_for_reflectivity,
            font_size=FONT_SIZE_SANS_COLOUR_BARS)

        saliency_plotting.plot_many_2d_grids_with_pm_signs(
            saliency_matrix_3d=numpy.flip(
                reflectivity_saliency_matrix[i, ..., 0], axis=0),
            axes_objects_2d_list=these_axes_objects,
            colour_map_object=saliency_colour_map_object,
            max_absolute_colour_value=max_colour_value_by_example[i])

        this_colour_map_object, this_colour_norm_object = (
            radar_plotting.get_default_colour_scheme(radar_utils.REFL_NAME)
        )

        plotting_utils.add_colour_bar(
            axes_object_or_list=these_axes_objects,
            values_to_colour=reflectivity_matrix_dbz[i, ..., 0],
            colour_map=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation='horizontal', extend_min=True, extend_max=True)

        this_file_name = '{0:s}_reflectivity.jpg'.format(this_base_file_name)
        pyplot.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)

        print('Saving figure to file: "{0:s}"...'.format(this_file_name))
        pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

        # Azimuthal-shear figure: radar field, saliency overlay, colour bar.
        _, these_axes_objects = (
            radar_plotting.plot_many_2d_grids_without_coords(
                field_matrix=numpy.flip(az_shear_matrix_s01[i, ...], axis=0),
                field_name_by_panel=az_shear_field_names,
                panel_names=az_shear_field_names, num_panel_rows=1,
                plot_colour_bar_by_panel=plot_colour_bar_flags,
                font_size=FONT_SIZE_SANS_COLOUR_BARS)
        )

        saliency_plotting.plot_many_2d_grids_with_pm_signs(
            saliency_matrix_3d=numpy.flip(
                az_shear_saliency_matrix[i, ...], axis=0),
            axes_objects_2d_list=these_axes_objects,
            colour_map_object=saliency_colour_map_object,
            max_absolute_colour_value=max_colour_value_by_example[i])

        this_colour_map_object, this_colour_norm_object = (
            radar_plotting.get_default_colour_scheme(
                radar_utils.LOW_LEVEL_SHEAR_NAME)
        )

        # The colour bar uses the radar shear colour scheme, so it must
        # describe the shear values themselves (as for reflectivity above),
        # not the saliency values.
        plotting_utils.add_colour_bar(
            axes_object_or_list=these_axes_objects,
            values_to_colour=az_shear_matrix_s01[i, ...],
            colour_map=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation='horizontal', extend_min=True, extend_max=True)

        this_file_name = '{0:s}_azimuthal-shear.jpg'.format(
            this_base_file_name)
        pyplot.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)

        print('Saving figure to file: "{0:s}"...'.format(this_file_name))
        pyplot.savefig(this_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
# Example #2
# 0
def _plot_2d_radar(model_metadata_dict,
                   output_dir_name,
                   pmm_flag,
                   diff_colour_map_object=None,
                   max_colour_percentile_for_diff=None,
                   full_id_strings=None,
                   storm_time_strings=None,
                   novel_radar_matrix=None,
                   novel_radar_matrix_upconv=None,
                   novel_radar_matrix_upconv_svd=None):
    """Plots results of novelty detection for 2-D radar fields.

    E = number of examples (storm objects)
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels (field/height pairs)

    Exactly one figure is saved per example.  What gets plotted depends on
    which matrices are given, checked in this order:

    1. `novel_radar_matrix` (actual fields);
    2. both upconv matrices -> their difference ("novelty"), drawn with
       `diff_colour_map_object` and zero-centred colour limits;
    3. `novel_radar_matrix_upconv` alone;
    4. otherwise `novel_radar_matrix_upconv_svd`.

    This method handles the 3 input matrices in the same way as
    `_plot_3d_radar`.

    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: See doc for `_plot_3d_radar`.
    :param pmm_flag: Same.
    :param diff_colour_map_object: Same.
    :param max_colour_percentile_for_diff: Same.
    :param full_id_strings: Same.
    :param storm_time_strings: Same.
    :param novel_radar_matrix: E-by-M-by-N-by-C numpy array of original
        (not reconstructed) radar fields.
    :param novel_radar_matrix_upconv: E-by-M-by-N-by-C numpy array of
        upconvnet-reconstructed radar fields.
    :param novel_radar_matrix_upconv_svd: E-by-M-by-N-by-C numpy array of
        upconvnet-and-SVD-reconstructed radar fields.
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]

    have_storm_ids = (
        not pmm_flag
        and full_id_strings is not None
        and storm_time_strings is not None
    )

    # Choose the matrix to plot and the labels used in title / file name.
    plot_difference = False

    if novel_radar_matrix is not None:
        plot_type_abbrev = plot_type_verbose = 'actual'
        radar_matrix_to_plot = novel_radar_matrix
    elif (novel_radar_matrix_upconv is not None
          and novel_radar_matrix_upconv_svd is not None):
        plot_difference = True
        plot_type_abbrev = plot_type_verbose = 'novelty'
        radar_matrix_to_plot = (
            novel_radar_matrix_upconv - novel_radar_matrix_upconv_svd
        )
    elif novel_radar_matrix_upconv is not None:
        plot_type_abbrev = 'upconv'
        plot_type_verbose = 'upconvnet reconstruction'
        radar_matrix_to_plot = novel_radar_matrix_upconv
    else:
        plot_type_abbrev = 'upconv-svd'
        plot_type_verbose = 'upconvnet/SVD reconstruction'
        radar_matrix_to_plot = novel_radar_matrix_upconv_svd

    if layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]
        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )
        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), True, dtype=bool)
    else:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                layer_operation_dicts)
        )
        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), False, dtype=bool)
        # With layer operations, only every third panel gets a colour bar.
        plot_colour_bar_by_panel[2::3] = True

    num_panels = len(panel_names)
    num_panel_rows = int(numpy.floor(
        numpy.sqrt(radar_matrix_to_plot.shape[-1])
    ))

    # For difference plots, panels in the same group share one colour scale.
    # Without layer operations each panel is its own group; with them, all
    # panels showing the same field form one group.
    if plot_difference:
        if layer_operation_dicts is None:
            panel_index_groups = [[j] for j in range(num_panels)]
        else:
            field_name_array = numpy.array(field_name_by_panel)
            panel_index_groups = [
                numpy.where(field_name_array == this_field_name)[0]
                for this_field_name in numpy.unique(field_name_array)
            ]

    for i in range(radar_matrix_to_plot.shape[0]):
        if pmm_flag:
            title_string = 'Probability-matched mean'
            file_prefix = 'pmm'
        elif have_storm_ids:
            title_string = 'Storm "{0:s}" at {1:s}'.format(
                full_id_strings[i], storm_time_strings[i])
            file_prefix = '{0:s}_{1:s}'.format(
                full_id_strings[i].replace('_', '-'), storm_time_strings[i])
        else:
            title_string = 'Example {0:d}'.format(i + 1)
            file_prefix = 'example{0:06d}'.format(i)

        title_string = '{0:s} ({1:s})'.format(title_string, plot_type_verbose)
        output_file_name = '{0:s}/{1:s}_{2:s}_radar.jpg'.format(
            output_dir_name, file_prefix, plot_type_abbrev)

        colour_map_by_panel = None
        colour_norm_by_panel = None

        if plot_difference:
            colour_map_by_panel = [diff_colour_map_object] * num_panels
            colour_norm_by_panel = [None] * num_panels

            for these_panel_indices in panel_index_groups:
                max_abs_value = numpy.percentile(
                    numpy.absolute(
                        radar_matrix_to_plot[i, ..., these_panel_indices]),
                    max_colour_percentile_for_diff)

                for j in these_panel_indices:
                    colour_norm_by_panel[j] = matplotlib.colors.Normalize(
                        vmin=-1 * max_abs_value, vmax=max_abs_value,
                        clip=False)

        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(radar_matrix_to_plot[i, ...], axis=0),
            field_name_by_panel=field_name_by_panel,
            num_panel_rows=num_panel_rows,
            panel_names=panel_names,
            colour_map_object_by_panel=colour_map_by_panel,
            colour_norm_object_by_panel=colour_norm_by_panel,
            plot_colour_bar_by_panel=plot_colour_bar_by_panel,
            font_size=FONT_SIZE_WITH_COLOUR_BARS,
            row_major=False)

        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)
        print('Saving figure to: "{0:s}"...'.format(output_file_name))
        pyplot.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
# Example #3
# 0
def _plot_comparison(
        predictor_matrix, model_metadata_dict, machine_mask_matrix_3d,
        human_mask_matrix_3d, iou_by_channel, positive_flag, output_file_name):
    """Plots human vs. machine interpretation masks over the predictor fields.

    M = number of rows in grid (physical space)
    N = number of columns in grid (physical space)
    C = number of channels

    One panel is drawn per channel.  Each panel title is extended with that
    channel's IoU, and both masks are drawn on top by
    `_plot_comparison_one_channel`.  The figure is saved and then closed.

    :param predictor_matrix: M-by-N-by-C numpy array of predictors.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param machine_mask_matrix_3d: M-by-N-by-C numpy array of Boolean flags,
        indicating where machine interpretation value is strongly positive or
        negative.
    :param human_mask_matrix_3d: Same, but for the human interpretation.
    :param iou_by_channel: length-C numpy array of IoU values (intersection over
        union) between human and machine masks.
    :param positive_flag: Boolean flag.  If True (False), masks indicate where
        interpretation value is strongly positive (negative).
    :param output_file_name: Path to output file (figure will be saved here).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]

    if layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
        panel_names = radar_plotting.fields_and_heights_to_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )
        plot_colour_bar_by_panel = numpy.full(
            len(field_name_by_panel), True, dtype=bool)
    else:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_operations_to_names(
                list_of_layer_operation_dicts=layer_operation_dicts)
        )
        plot_colour_bar_by_panel = numpy.full(
            len(field_name_by_panel), False, dtype=bool)
        # With layer operations, only every third panel gets a colour bar.
        plot_colour_bar_by_panel[2::3] = True

    num_panels = len(field_name_by_panel)
    sign_string = 'Positive' if positive_flag else 'Negative'

    # Append the IoU to each panel title (panel_names is modified in place).
    for channel_index in range(num_panels):
        panel_names[channel_index] += '\n{0:s} IoU = {1:.3f}'.format(
            sign_string, iou_by_channel[channel_index])

    _, axes_object_matrix = radar_plotting.plot_many_2d_grids_without_coords(
        field_matrix=numpy.flip(predictor_matrix, axis=0),
        field_name_by_panel=field_name_by_panel, panel_names=panel_names,
        num_panel_rows=int(numpy.floor(numpy.sqrt(num_panels))),
        plot_colour_bar_by_panel=plot_colour_bar_by_panel, font_size=14,
        row_major=False
    )[:2]

    # The machine mask is flipped to match the flipped predictor field; the
    # human mask is passed unflipped (presumably it is already in display
    # orientation -- NOTE(review): confirm against where human masks come
    # from).  numpy.flip returns a view, so flipping once is enough.
    flipped_machine_mask_matrix_3d = numpy.flip(machine_mask_matrix_3d, axis=0)

    for channel_index in range(num_panels):
        _plot_comparison_one_channel(
            human_mask_matrix_3d=human_mask_matrix_3d,
            machine_mask_matrix_3d=flipped_machine_mask_matrix_3d,
            channel_index=channel_index,
            axes_object_matrix=axes_object_matrix)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    pyplot.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI,
                   pad_inches=0., bbox_inches='tight')
    pyplot.close()
def _plot_saliency_for_2d_radar(
        radar_matrix, radar_saliency_matrix, model_metadata_dict,
        saliency_colour_map_object, max_colour_value_by_example,
        output_dir_name, storm_ids=None, storm_times_unix_sec=None):
    """Plots saliency (as contours) on top of 2-D radar fields.

    E = number of examples
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels (field/height pairs)

    One figure is saved per example.  Saliency contours run from
    -max_colour_value_by_example[i] to +max_colour_value_by_example[i], with
    spacing max_colour_value_by_example[i] / HALF_NUM_CONTOURS.

    If `storm_ids is None` and `storm_times_unix_sec is None`, will assume that
    the input matrices contain probability-matched means.

    :param radar_matrix: E-by-M-by-N-by-C numpy array of radar values
        (predictors).
    :param radar_saliency_matrix: E-by-M-by-N-by-C numpy array of saliency
        values.
    :param model_metadata_dict: See doc for `cnn.read_model_metadata`.
    :param saliency_colour_map_object: See doc for
        `_plot_saliency_for_2d3d_radar`.
    :param max_colour_value_by_example: Same.
    :param output_dir_name: Same.
    :param storm_ids: Same.
    :param storm_times_unix_sec: Same.
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    pmm_flag = storm_ids is None and storm_times_unix_sec is None

    if layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]
        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )
        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), True, dtype=bool)
    else:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                layer_operation_dicts)
        )
        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), False, dtype=bool)
        # With layer operations, only every third panel gets a colour bar.
        plot_colour_bar_by_panel[2::3] = True

    num_panel_rows = int(numpy.floor(numpy.sqrt(len(field_name_by_panel))))

    for i in range(radar_matrix.shape[0]):
        max_colour_value = max_colour_value_by_example[i]

        _, axes_object_matrix = (
            radar_plotting.plot_many_2d_grids_without_coords(
                field_matrix=numpy.flip(radar_matrix[i, ...], axis=0),
                field_name_by_panel=field_name_by_panel,
                panel_names=panel_names, num_panel_rows=num_panel_rows,
                plot_colour_bar_by_panel=plot_colour_bar_by_panel,
                font_size=FONT_SIZE_WITH_COLOUR_BARS, row_major=False)
        )

        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(
                radar_saliency_matrix[i, ...], axis=0),
            axes_objects_2d_list=axes_object_matrix,
            colour_map_object=saliency_colour_map_object,
            max_absolute_contour_level=max_colour_value,
            contour_interval=max_colour_value / HALF_NUM_CONTOURS,
            row_major=False)

        if pmm_flag:
            title_string = 'Probability-matched mean'
            output_file_name = '{0:s}/saliency_pmm_radar.jpg'.format(
                output_dir_name)
        else:
            storm_time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)
            title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], storm_time_string)
            output_file_name = '{0:s}/saliency_{1:s}_{2:s}_radar.jpg'.format(
                output_dir_name, storm_ids[i].replace('_', '-'),
                storm_time_string)

        title_string += ' (max absolute saliency = {0:.3f})'.format(
            max_colour_value)
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

        print('Saving figure to file: "{0:s}"...'.format(output_file_name))
        pyplot.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
# Example #5
# 0
def _bwo_save_figure(title_string, output_file_name):
    """Adds a title to the current figure, saves it, and closes it.

    :param title_string: Figure title.
    :param output_file_name: Path to output file.
    """

    pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    pyplot.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close()


def _bwo_diff_colour_norm(diff_matrix, max_colour_percentile):
    """Creates zero-centred colour normalization for a difference field.

    :param diff_matrix: numpy array of differences (after minus before).
    :param max_colour_percentile: Percentile (0...100) of absolute differences
        used as the max colour value.
    :return: colour_norm_object: `matplotlib.colors.Normalize` from -max to
        +max.
    """

    max_colour_value = numpy.percentile(
        numpy.absolute(diff_matrix), max_colour_percentile)

    return matplotlib.colors.Normalize(
        vmin=-1 * max_colour_value, vmax=max_colour_value, clip=False)


def _bwo_plot_reflectivity(reflectivity_matrix_dbz, training_option_dict,
                           num_panel_rows, colour_scheme=None):
    """Plots 3-D reflectivity (one panel per height) with one colour bar.

    :param reflectivity_matrix_dbz: numpy array of reflectivity values (or
        differences), already flipped for plotting, with height on the last
        axis.
    :param training_option_dict: See doc for `cnn.read_model_metadata`.
    :param num_panel_rows: Number of panel rows in figure.
    :param colour_scheme: Tuple of (colour map, colour normalization).  If
        None, the plotting routine uses its own defaults and the colour bar uses
        the default reflectivity scheme.
    """

    if colour_scheme is None:
        plot_kwargs = dict()
    else:
        plot_kwargs = {
            'colour_map_object': colour_scheme[0],
            'colour_norm_object': colour_scheme[1]
        }

    _, axes_object_matrix = radar_plotting.plot_3d_grid_without_coords(
        field_matrix=reflectivity_matrix_dbz,
        field_name=radar_utils.REFL_NAME,
        grid_point_heights_metres=training_option_dict[
            trainval_io.RADAR_HEIGHTS_KEY],
        ground_relative=True, num_panel_rows=num_panel_rows,
        font_size=FONT_SIZE_SANS_COLOUR_BARS, **plot_kwargs)

    if colour_scheme is None:
        colour_scheme = radar_plotting.get_default_colour_scheme(
            radar_utils.REFL_NAME)

    colour_map_object, colour_norm_object = colour_scheme

    plotting_utils.add_colour_bar(
        axes_object_or_list=axes_object_matrix,
        values_to_colour=reflectivity_matrix_dbz,
        colour_map=colour_map_object,
        colour_norm_object=colour_norm_object,
        orientation='horizontal', extend_min=True, extend_max=True)


def _bwo_plot_az_shear(az_shear_matrix_s01, az_shear_field_names,
                       colour_scheme=None):
    """Plots 2-D azimuthal shear (one panel per field) with one colour bar.

    :param az_shear_matrix_s01: numpy array of azimuthal-shear values (or
        differences), already flipped for plotting, with one entry on the last
        axis per field in `az_shear_field_names`.
    :param az_shear_field_names: List of azimuthal-shear field names (also
        used as panel names).
    :param colour_scheme: Tuple of (colour map, colour normalization), applied
        to every panel.  If None, the plotting routine uses its own defaults
        and the colour bar uses the default low-level-shear scheme.
    """

    num_fields = len(az_shear_field_names)

    if colour_scheme is None:
        plot_kwargs = dict()
    else:
        plot_kwargs = {
            'colour_map_object_by_panel': [colour_scheme[0]] * num_fields,
            # One independent copy of the norm per panel.
            'colour_norm_object_by_panel': [
                copy.deepcopy(colour_scheme[1]) for _ in range(num_fields)
            ]
        }

    # Per-panel colour bars are suppressed; one shared bar is added below.
    _, axes_object_matrix = radar_plotting.plot_many_2d_grids_without_coords(
        field_matrix=az_shear_matrix_s01,
        field_name_by_panel=az_shear_field_names, num_panel_rows=1,
        panel_names=az_shear_field_names,
        plot_colour_bar_by_panel=numpy.full(num_fields, False, dtype=bool),
        font_size=FONT_SIZE_SANS_COLOUR_BARS, **plot_kwargs)

    if colour_scheme is None:
        colour_scheme = radar_plotting.get_default_colour_scheme(
            radar_utils.LOW_LEVEL_SHEAR_NAME)

    colour_map_object, colour_norm_object = colour_scheme

    plotting_utils.add_colour_bar(
        axes_object_or_list=axes_object_matrix,
        values_to_colour=az_shear_matrix_s01,
        colour_map=colour_map_object,
        colour_norm_object=colour_norm_object,
        orientation='horizontal', extend_min=True, extend_max=True)


def _plot_bwo_for_2d3d_radar(
        list_of_optimized_matrices, training_option_dict,
        diff_colour_map_object, max_colour_percentile_for_diff,
        top_output_dir_name, pmm_flag, list_of_input_matrices=None,
        storm_ids=None, storm_times_unix_sec=None):
    """Plots BWO results for 2-D azimuthal-shear and 3-D reflectivity fields.

    E = number of examples (storm objects)
    T = number of input tensors to the model

    For each example, reflectivity and azimuthal-shear figures are saved to
    `after_optimization`.  If `list_of_input_matrices` is given, figures are
    also saved to `before_optimization` and `after_minus_before_optimization`
    (all under `top_output_dir_name`).

    :param list_of_optimized_matrices: length-T list of numpy arrays, where the
        [i]th array is the optimized version of the [i]th input matrix to the
        model.  Element 0 is reflectivity (only index 0 of its last axis is
        plotted); element 1 is azimuthal shear, whose last axis has one entry
        per field in `training_option_dict[trainval_io.RADAR_FIELDS_KEY]`.
    :param training_option_dict: See doc for `cnn.read_model_metadata`.
    :param diff_colour_map_object: Colour map used for after-minus-before
        difference plots.  See documentation at top of file.
    :param max_colour_percentile_for_diff: Percentile of absolute differences
        used as max value in the (zero-centred) difference colour scheme.
    :param top_output_dir_name: Path to top-level output directory (figures will
        be saved here).
    :param pmm_flag: Boolean flag.  If True, `list_of_optimized_matrices`
        (and `list_of_input_matrices`) contain probability-matched means.
    :param list_of_input_matrices: Same as `list_of_optimized_matrices` but with
        non-optimized input matrices.
    :param storm_ids: [optional and used only if `pmm_flag = False`]
        length-E list of storm IDs (strings).
    :param storm_times_unix_sec: [optional and used only if `pmm_flag = False`]
        length-E numpy array of storm times.
    """

    before_optimization_dir_name = '{0:s}/before_optimization'.format(
        top_output_dir_name)
    after_optimization_dir_name = '{0:s}/after_optimization'.format(
        top_output_dir_name)
    difference_dir_name = '{0:s}/after_minus_before_optimization'.format(
        top_output_dir_name)

    for this_dir_name in [before_optimization_dir_name,
                          after_optimization_dir_name, difference_dir_name]:
        file_system_utils.mkdir_recursive_if_necessary(
            directory_name=this_dir_name)

    have_storm_ids = not (
        pmm_flag or storm_ids is None or storm_times_unix_sec is None
    )

    az_shear_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    num_storms = list_of_optimized_matrices[0].shape[0]

    for i in range(num_storms):
        print('\n')

        if pmm_flag:
            this_base_title_string = 'Probability-matched mean'
            this_base_pathless_file_name = 'pmm'
        elif have_storm_ids:
            this_storm_time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)

            this_base_title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], this_storm_time_string)

            this_base_pathless_file_name = '{0:s}_{1:s}'.format(
                storm_ids[i].replace('_', '-'), this_storm_time_string)
        else:
            this_base_title_string = 'Example {0:d}'.format(i + 1)
            this_base_pathless_file_name = 'example{0:06d}'.format(i)

        # Keep every az-shear field (`[i, ...]`); indexing `[i, ..., 0]` here
        # would drop the field axis that defines one panel per field.
        this_optimized_refl_matrix_dbz = (
            list_of_optimized_matrices[0][i, ..., 0]
        )
        this_optimized_shear_matrix_s01 = list_of_optimized_matrices[1][i, ...]

        this_num_panel_rows = int(numpy.floor(
            numpy.sqrt(this_optimized_refl_matrix_dbz.shape[-1])
        ))

        # After optimization.
        this_title_string = '{0:s} (after optimization)'.format(
            this_base_title_string)

        _bwo_plot_reflectivity(
            reflectivity_matrix_dbz=numpy.flip(
                this_optimized_refl_matrix_dbz, axis=0),
            training_option_dict=training_option_dict,
            num_panel_rows=this_num_panel_rows)

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_after-optimization_reflectivity.jpg'
            ).format(after_optimization_dir_name,
                     this_base_pathless_file_name)
        )

        _bwo_plot_az_shear(
            az_shear_matrix_s01=numpy.flip(
                this_optimized_shear_matrix_s01, axis=0),
            az_shear_field_names=az_shear_field_names)

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_after-optimization_azimuthal-shear.jpg'
            ).format(after_optimization_dir_name,
                     this_base_pathless_file_name)
        )

        if list_of_input_matrices is None:
            continue

        this_input_refl_matrix_dbz = list_of_input_matrices[0][i, ..., 0]
        this_input_shear_matrix_s01 = list_of_input_matrices[1][i, ...]

        # Before optimization.
        this_title_string = '{0:s} (before optimization)'.format(
            this_base_title_string)

        _bwo_plot_reflectivity(
            reflectivity_matrix_dbz=numpy.flip(
                this_input_refl_matrix_dbz, axis=0),
            training_option_dict=training_option_dict,
            num_panel_rows=this_num_panel_rows)

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_before-optimization_reflectivity.jpg'
            ).format(before_optimization_dir_name,
                     this_base_pathless_file_name)
        )

        _bwo_plot_az_shear(
            az_shear_matrix_s01=numpy.flip(
                this_input_shear_matrix_s01, axis=0),
            az_shear_field_names=az_shear_field_names)

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_before-optimization_azimuthal-shear.jpg'
            ).format(before_optimization_dir_name,
                     this_base_pathless_file_name)
        )

        # After minus before optimization, with zero-centred colour limits.
        this_title_string = '{0:s} (after minus before optimization)'.format(
            this_base_title_string)

        this_refl_diff_matrix_dbz = numpy.flip(
            this_optimized_refl_matrix_dbz - this_input_refl_matrix_dbz,
            axis=0)

        _bwo_plot_reflectivity(
            reflectivity_matrix_dbz=this_refl_diff_matrix_dbz,
            training_option_dict=training_option_dict,
            num_panel_rows=this_num_panel_rows,
            colour_scheme=(
                diff_colour_map_object,
                _bwo_diff_colour_norm(
                    this_refl_diff_matrix_dbz, max_colour_percentile_for_diff)
            )
        )

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_optimization-diff_reflectivity.jpg'
            ).format(difference_dir_name, this_base_pathless_file_name)
        )

        this_shear_diff_matrix_s01 = numpy.flip(
            this_optimized_shear_matrix_s01 - this_input_shear_matrix_s01,
            axis=0)

        _bwo_plot_az_shear(
            az_shear_matrix_s01=this_shear_diff_matrix_s01,
            az_shear_field_names=az_shear_field_names,
            colour_scheme=(
                diff_colour_map_object,
                _bwo_diff_colour_norm(
                    this_shear_diff_matrix_s01, max_colour_percentile_for_diff)
            )
        )

        _bwo_save_figure(
            title_string=this_title_string,
            output_file_name=(
                '{0:s}/{1:s}_optimization-diff_azimuthal-shear.jpg'
            ).format(difference_dir_name, this_base_pathless_file_name)
        )
def _plot_2d_radar_cams(
        radar_matrix, model_metadata_dict, cam_colour_map_object,
        max_colour_prctile_for_cam, output_dir_name,
        class_activation_matrix=None, ggradcam_output_matrix=None,
        storm_ids=None, storm_times_unix_sec=None):
    """Plots class-activation maps (Grad-CAM or guided Grad-CAM) over 2-D radar.

    E = number of examples
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels (field/height pairs)

    Exactly one kind of map is drawn per call.  If `class_activation_matrix`
    is None, `ggradcam_output_matrix` is drawn; otherwise
    `class_activation_matrix` is drawn and `ggradcam_output_matrix` is ignored.

    When both `storm_ids` and `storm_times_unix_sec` are None, the inputs are
    treated as probability-matched means (PMM) and the figure is titled and
    named accordingly.

    One figure per example is saved to `output_dir_name`.

    :param radar_matrix: E-by-M-by-N-by-C numpy array of radar values.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cams`.
    :param cam_colour_map_object: Same.
    :param max_colour_prctile_for_cam: Same.
    :param output_dir_name: Same.
    :param class_activation_matrix: E-by-M-by-N numpy array of class
        activations.
    :param ggradcam_output_matrix: E-by-M-by-N-by-C numpy array of output values
        from guided Grad-CAM.
    :param storm_ids: length-E list of storm IDs (strings).
    :param storm_times_unix_sec: length-E numpy array of storm times.
    """

    is_pmm = storm_ids is None and storm_times_unix_sec is None
    plot_ggradcam = class_activation_matrix is None

    option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_op_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]

    if layer_op_dicts is not None:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(layer_op_dicts)
        )

        # Colour bars only on panels 2, 5, 8, ... (every third panel).
        colour_bar_flags = numpy.full(len(panel_names), False, dtype=bool)
        colour_bar_flags[2::3] = True
    else:
        field_name_by_panel = option_dict[trainval_io.RADAR_FIELDS_KEY]
        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )
        colour_bar_flags = numpy.full(len(panel_names), True, dtype=bool)

    num_channels = radar_matrix.shape[-1]
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_channels)))

    if plot_ggradcam:
        quantity_string = 'max absolute guided Grad-CAM output'
        file_prefix = 'guided-gradcam'
    else:
        quantity_string = 'max class activation'
        file_prefix = 'gradcam'

    for i in range(radar_matrix.shape[0]):
        _, axes_objects = radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(radar_matrix[i, ...], axis=0),
            field_name_by_panel=field_name_by_panel,
            panel_names=panel_names, num_panel_rows=num_panel_rows,
            plot_colour_bar_by_panel=colour_bar_flags,
            font_size=FONT_SIZE_WITH_COLOUR_BARS, row_major=False)

        if plot_ggradcam:
            example_matrix = ggradcam_output_matrix[i, ...]
            max_level = numpy.percentile(
                numpy.absolute(example_matrix), max_colour_prctile_for_cam)
        else:
            # Copy the single CAM onto every channel so that each radar panel
            # gets the same overlay.
            example_matrix = numpy.repeat(
                numpy.expand_dims(class_activation_matrix[i, ...], axis=-1),
                repeats=num_channels, axis=-1)
            max_level = numpy.percentile(
                example_matrix, max_colour_prctile_for_cam)

        # A zero percentile would give a zero contour interval; fall back to a
        # fixed level instead.
        if max_level == 0:
            max_level = 10.

        if plot_ggradcam:
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(example_matrix, axis=0),
                axes_objects_2d_list=axes_objects,
                colour_map_object=cam_colour_map_object,
                max_absolute_contour_level=max_level,
                contour_interval=max_level / 10, row_major=False)
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=numpy.flip(example_matrix, axis=0),
                axes_objects_2d_list=axes_objects,
                colour_map_object=cam_colour_map_object,
                max_contour_level=max_level,
                contour_interval=max_level / NUM_CONTOURS,
                row_major=False)

        if is_pmm:
            title_string = 'Probability-matched mean'
            figure_file_name = '{0:s}/{1:s}_pmm_radar.jpg'.format(
                output_dir_name, file_prefix)
        else:
            time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)
            title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], time_string)
            figure_file_name = '{0:s}/{1:s}_{2:s}_{3:s}_radar.jpg'.format(
                output_dir_name, file_prefix,
                storm_ids[i].replace('_', '-'), time_string)

        title_string += ' ({0:s} = {1:.3f})'.format(quantity_string, max_level)
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

        print('Saving figure to file: "{0:s}"...'.format(figure_file_name))
        pyplot.savefig(figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
예제 #7
0
def _plot_bwo_for_2d_radar(
        optimized_radar_matrix, model_metadata_dict, diff_colour_map_object,
        max_colour_percentile_for_diff, top_output_dir_name, pmm_flag,
        input_radar_matrix=None, storm_ids=None, storm_times_unix_sec=None):
    """Plots backwards-optimization (BWO) results for 2-D radar fields.

    E = number of examples (storm objects)
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels (field/height pairs)

    For each example, figures are written to three subdirectories of
    `top_output_dir_name`:

    - "after_optimization": optimized radar fields (always);
    - "before_optimization": input radar fields (only if `input_radar_matrix`
      is given);
    - "after_minus_before_optimization": optimized minus input (only if
      `input_radar_matrix` is given).  Colour limits are symmetric about zero,
      set from a percentile of absolute differences.  Without layer
      operations each panel gets its own limits; with layer operations,
      panels sharing a field name share limits.

    :param optimized_radar_matrix: E-by-M-by-N-by-C numpy array of radar values
        (predictors).
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param diff_colour_map_object: See doc for `_plot_bwo_for_2d3d_radar`.
    :param max_colour_percentile_for_diff: Same.
    :param top_output_dir_name: Same.
    :param pmm_flag: Same.
    :param input_radar_matrix: Same as `optimized_radar_matrix` but with
        non-optimized input.
    :param storm_ids: See doc for `_plot_bwo_for_2d3d_radar`.
    :param storm_times_unix_sec: Same.
    """

    before_dir_name = '{0:s}/before_optimization'.format(top_output_dir_name)
    after_dir_name = '{0:s}/after_optimization'.format(top_output_dir_name)
    diff_dir_name = '{0:s}/after_minus_before_optimization'.format(
        top_output_dir_name)

    for this_dir_name in [before_dir_name, after_dir_name, diff_dir_name]:
        file_system_utils.mkdir_recursive_if_necessary(
            directory_name=this_dir_name)

    option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_op_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]

    # Storm IDs/times are used for naming only outside PMM mode, and only when
    # both are provided.
    have_storm_ids = (
        not pmm_flag and storm_ids is not None
        and storm_times_unix_sec is not None
    )

    if layer_op_dicts is not None:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(layer_op_dicts)
        )

        # Colour bars only on panels 2, 5, 8, ... (every third panel).
        colour_bar_flags = numpy.full(len(panel_names), False, dtype=bool)
        colour_bar_flags[2::3] = True
    else:
        field_name_by_panel = option_dict[trainval_io.RADAR_FIELDS_KEY]
        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )
        colour_bar_flags = numpy.full(len(panel_names), True, dtype=bool)

    num_panels = len(panel_names)
    num_panel_rows = int(numpy.floor(
        numpy.sqrt(optimized_radar_matrix.shape[-1])
    ))

    # Panels whose difference colour limits are computed together: each panel
    # alone, or all panels with the same field name when layer ops are used.
    if layer_op_dicts is None:
        panel_index_groups = [[j] for j in range(num_panels)]
    else:
        field_name_array = numpy.array(field_name_by_panel)
        panel_index_groups = [
            numpy.where(field_name_array == this_name)[0]
            for this_name in numpy.unique(field_name_array)
        ]

    def _plot_panels(field_matrix, **extra_kwargs):
        """Plots one example's channels (flipped along axis 0) as panels."""
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(field_matrix, axis=0),
            field_name_by_panel=field_name_by_panel,
            num_panel_rows=num_panel_rows, panel_names=panel_names,
            plot_colour_bar_by_panel=colour_bar_flags,
            font_size=FONT_SIZE_WITH_COLOUR_BARS, row_major=False,
            **extra_kwargs)

    def _save_and_close(title_string, file_name):
        """Titles, saves, and closes the current figure."""
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)
        print('Saving figure to: "{0:s}"...'.format(file_name))
        pyplot.savefig(file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

    for i in range(optimized_radar_matrix.shape[0]):
        print('\n')

        if pmm_flag:
            base_title_string = 'Probability-matched mean'
            base_file_name = 'pmm'
        elif have_storm_ids:
            time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)
            base_title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], time_string)
            base_file_name = '{0:s}_{1:s}'.format(
                storm_ids[i].replace('_', '-'), time_string)
        else:
            base_title_string = 'Example {0:d}'.format(i + 1)
            base_file_name = 'example{0:06d}'.format(i)

        _plot_panels(optimized_radar_matrix[i, ...])
        _save_and_close(
            '{0:s} (after optimization)'.format(base_title_string),
            '{0:s}/{1:s}_after-optimization_radar.jpg'.format(
                after_dir_name, base_file_name)
        )

        if input_radar_matrix is None:
            continue

        _plot_panels(input_radar_matrix[i, ...])
        _save_and_close(
            '{0:s} (before optimization)'.format(base_title_string),
            '{0:s}/{1:s}_before-optimization_radar.jpg'.format(
                before_dir_name, base_file_name)
        )

        cnorm_by_panel = [None] * num_panels

        for these_indices in panel_index_groups:
            abs_diff_matrix = numpy.absolute(
                optimized_radar_matrix[i, ..., these_indices] -
                input_radar_matrix[i, ..., these_indices]
            )
            max_value = numpy.percentile(
                abs_diff_matrix, max_colour_percentile_for_diff)

            # A separate norm object per panel, all with the same limits.
            for k in these_indices:
                cnorm_by_panel[k] = matplotlib.colors.Normalize(
                    vmin=-1 * max_value, vmax=max_value, clip=False)

        _plot_panels(
            optimized_radar_matrix[i, ...] - input_radar_matrix[i, ...],
            colour_map_object_by_panel=[diff_colour_map_object] * num_panels,
            colour_norm_object_by_panel=cnorm_by_panel)
        _save_and_close(
            '{0:s} (after minus before optimization)'.format(
                base_title_string),
            '{0:s}/{1:s}_optimization-diff_radar.jpg'.format(
                diff_dir_name, base_file_name)
        )
예제 #8
0
def _plot_2d_radar_scan(list_of_predictor_matrices,
                        model_metadata_dict,
                        allow_whitespace,
                        title_string=None):
    """Plots 2-D radar scan for one example.

    J = number of panel rows in image
    K = number of panel columns in image

    If the model uses layer operations, only the operations (and matching
    radar channels) at `LAYER_OP_INDICES_TO_KEEP` are plotted.  The channel
    subset is bound to a local name; `list_of_predictor_matrices` itself is
    not modified.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.  Only
        the first element (the radar matrix) is used here.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.  If False, panels are drawn into a new
        figure with zero spacing between panels, and no title is added.
    :param title_string: Same.  Used only when `allow_whitespace` is True.
    :return: figure_objects: length-1 list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-1 list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]
    radar_matrix = list_of_predictor_matrices[0]

    if list_of_layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]

        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])
    else:
        list_of_layer_operation_dicts = [
            list_of_layer_operation_dicts[k] for k in LAYER_OP_INDICES_TO_KEEP
        ]

        # Keep the subset in a local variable.  Assigning it back into
        # `list_of_predictor_matrices` would change the caller's list, so a
        # second call with the same list would subset the already-subset
        # matrix again.
        radar_matrix = radar_matrix[..., LAYER_OP_INDICES_TO_KEEP]

        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                list_of_layer_operation_dicts=list_of_layer_operation_dicts))

    num_panels = len(field_name_by_panel)
    plot_cbar_by_panel = numpy.full(num_panels, True, dtype=bool)

    num_panel_rows = int(numpy.floor(numpy.sqrt(num_panels)))
    num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

    if allow_whitespace:
        # Let the plotting routine create its own figure.
        figure_object = None
        axes_object_matrix = None
    else:
        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(num_rows=num_panel_rows,
                                                 num_columns=num_panel_columns,
                                                 horizontal_spacing=0.,
                                                 vertical_spacing=0.,
                                                 shared_x_axis=False,
                                                 shared_y_axis=False,
                                                 keep_aspect_ratio=True))

    figure_object, axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(radar_matrix, axis=0),
            field_name_by_panel=field_name_by_panel,
            panel_names=panel_names,
            num_panel_rows=num_panel_rows,
            figure_object=figure_object,
            axes_object_matrix=axes_object_matrix,
            plot_colour_bar_by_panel=plot_cbar_by_panel,
            font_size=FONT_SIZE_WITH_COLOUR_BARS,
            row_major=False))

    if allow_whitespace and title_string is not None:
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return [figure_object], [axes_object_matrix]
예제 #9
0
def _plot_2d3d_radar_scan(list_of_predictor_matrices,
                          model_metadata_dict,
                          allow_whitespace,
                          title_string=None):
    """Plots 3-D reflectivity and 2-D azimuthal shear for one example.

    Two figures are produced: reflectivity (one panel per height), then
    azimuthal shear (one panel per shear field, in a single row).  When
    `allow_whitespace` is True, each figure also gets a horizontal colour bar
    and, if `title_string` is given, a title.  When it is False, each figure
    is created with zero spacing between panels, with no colour bar or title.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-2 list of figure handles (instances of
        `matplotlib.figure.Figure`).  The first is for reflectivity; the second
        is for azimuthal shear.
    :return: axes_object_matrices: length-2 list (the first is for reflectivity;
        the second is for azimuthal shear).  Each element is a 2-D numpy
        array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    shear_field_names = option_dict[trainval_io.RADAR_FIELDS_KEY]
    refl_heights_m_agl = option_dict[trainval_io.RADAR_HEIGHTS_KEY]
    num_shear_fields = len(shear_field_names)
    num_refl_heights = len(refl_heights_m_agl)

    refl_matrix = list_of_predictor_matrices[0]
    shear_matrix = list_of_predictor_matrices[1]

    def _blank_or_paneled_figure(num_rows, num_columns):
        """Returns (None, None) if whitespace is allowed, else a new figure."""
        if allow_whitespace:
            return None, None

        return plotting_utils.create_paneled_figure(
            num_rows=num_rows,
            num_columns=num_columns,
            horizontal_spacing=0.,
            vertical_spacing=0.,
            shared_x_axis=False,
            shared_y_axis=False,
            keep_aspect_ratio=True)

    def _add_colour_bar(axes_object_matrix, data_matrix, field_name):
        """Adds horizontal colour bar with the field's default scheme."""
        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(field_name))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=data_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

    # Reflectivity figure.
    num_refl_rows = int(numpy.floor(numpy.sqrt(num_refl_heights)))
    num_refl_columns = int(
        numpy.ceil(float(num_refl_heights) / num_refl_rows))

    refl_figure_object, refl_axes_object_matrix = _blank_or_paneled_figure(
        num_refl_rows, num_refl_columns)

    refl_figure_object, refl_axes_object_matrix = (
        radar_plotting.plot_3d_grid_without_coords(
            field_matrix=numpy.flip(refl_matrix[..., 0], axis=0),
            field_name=radar_utils.REFL_NAME,
            grid_point_heights_metres=refl_heights_m_agl,
            ground_relative=True,
            num_panel_rows=num_refl_rows,
            figure_object=refl_figure_object,
            axes_object_matrix=refl_axes_object_matrix,
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        _add_colour_bar(refl_axes_object_matrix, refl_matrix,
                        radar_utils.REFL_NAME)

        if title_string is not None:
            pyplot.suptitle(
                '{0:s}; {1:s}'.format(title_string, radar_utils.REFL_NAME),
                fontsize=TITLE_FONT_SIZE)

    # Azimuthal-shear figure.
    shear_figure_object, shear_axes_object_matrix = _blank_or_paneled_figure(
        1, num_shear_fields)

    shear_figure_object, shear_axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(shear_matrix, axis=0),
            field_name_by_panel=shear_field_names,
            panel_names=shear_field_names,
            num_panel_rows=1,
            figure_object=shear_figure_object,
            axes_object_matrix=shear_axes_object_matrix,
            plot_colour_bar_by_panel=numpy.full(
                num_shear_fields, False, dtype=bool),
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        _add_colour_bar(shear_axes_object_matrix, shear_matrix,
                        radar_utils.LOW_LEVEL_SHEAR_NAME)

        if title_string is not None:
            pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return (
        [refl_figure_object, shear_figure_object],
        [refl_axes_object_matrix, shear_axes_object_matrix]
    )
def plot_examples(list_of_predictor_matrices,
                  storm_ids,
                  storm_times_unix_sec,
                  model_metadata_dict,
                  output_dir_name,
                  storm_activations=None):
    """Plots one or more learning examples.

    E = number of examples (storm objects)

    For each example, this method saves a sounding figure (if the model uses
    soundings) and one or more radar figures.  The radar layout depends on the
    predictors:

    - 2 radar matrices (reflectivity, then azimuthal shear): one reflectivity
      figure and one azimuthal-shear figure.
    - 1 radar matrix with 2 spatial dimensions: one figure with all channels.
    - 1 radar matrix with 3 spatial dimensions: one figure per radar field.

    :param list_of_predictor_matrices: List created by
        `testing_io.read_specific_examples`.  Contains data to be plotted.  If
        the model uses soundings, the sounding matrix is the last element.
    :param storm_ids: length-E list of storm IDs.
    :param storm_times_unix_sec: length-E numpy array of storm times.
    :param model_metadata_dict: See doc for `cnn.read_model_metadata`.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :param storm_activations: length-E numpy array of storm activations (may be
        None).  Will be included in title of each figure.
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    sounding_field_names = training_option_dict[
        trainval_io.SOUNDING_FIELDS_KEY]
    plot_soundings = sounding_field_names is not None

    if plot_soundings:
        list_of_metpy_dictionaries = dl_utils.soundings_to_metpy_dictionaries(
            sounding_matrix=list_of_predictor_matrices[-1],
            field_names=sounding_field_names)
    else:
        list_of_metpy_dictionaries = None

    # Count only radar matrices.  The sounding matrix (if any) is the last
    # element, so reflectivity + azimuthal shear is recognized whether or not
    # soundings are present.
    num_radar_matrices = len(list_of_predictor_matrices) - int(plot_soundings)
    myrorss_2d3d = num_radar_matrices == 2

    num_radar_dimensions = len(list_of_predictor_matrices[0].shape) - 2
    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    # Panel metadata used by the single-matrix 2-D branch below.
    if num_radar_dimensions == 2:
        if list_of_layer_operation_dicts is None:
            field_name_by_panel = training_option_dict[
                trainval_io.RADAR_FIELDS_KEY]

            panel_names = (
                radar_plotting.radar_fields_and_heights_to_panel_names(
                    field_names=field_name_by_panel,
                    heights_m_agl=training_option_dict[
                        trainval_io.RADAR_HEIGHTS_KEY]))

            plot_colour_bar_by_panel = numpy.full(len(panel_names),
                                                  True,
                                                  dtype=bool)

        else:
            field_name_by_panel, panel_names = (
                radar_plotting.layer_ops_to_field_and_panel_names(
                    list_of_layer_operation_dicts))

            # Colour bars only on panels 2, 5, 8, ... (every third panel).
            plot_colour_bar_by_panel = numpy.full(len(panel_names),
                                                  False,
                                                  dtype=bool)
            plot_colour_bar_by_panel[2::3] = True
    else:
        field_name_by_panel = None
        panel_names = None
        plot_colour_bar_by_panel = None

    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    radar_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    # In the reflectivity + shear case, the radar fields in the metadata are
    # the azimuthal-shear fields.
    az_shear_field_names = radar_field_names
    num_az_shear_fields = len(az_shear_field_names)

    def _save_and_close(file_name, title_string=None):
        """Optionally titles the current figure, then saves and closes it."""
        if title_string is not None:
            pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

        print('Saving figure to: "{0:s}"...'.format(file_name))
        pyplot.savefig(file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

    for i in range(len(storm_ids)):
        this_time_string = time_conversion.unix_sec_to_string(
            storm_times_unix_sec[i], TIME_FORMAT)
        this_base_title_string = 'Storm "{0:s}" at {1:s}'.format(
            storm_ids[i], this_time_string)

        if storm_activations is not None:
            this_base_title_string += ' (activation = {0:.3f})'.format(
                storm_activations[i])

        this_base_file_name = '{0:s}/storm={1:s}_{2:s}'.format(
            output_dir_name, storm_ids[i].replace('_', '-'), this_time_string)

        if plot_soundings:
            sounding_plotting.plot_sounding(
                sounding_dict_for_metpy=list_of_metpy_dictionaries[i],
                title_string=this_base_title_string)

            # The sounding plot carries its own title, so no suptitle here.
            _save_and_close('{0:s}_sounding.jpg'.format(this_base_file_name))

        if myrorss_2d3d:
            this_reflectivity_matrix_dbz = numpy.flip(
                list_of_predictor_matrices[0][i, ..., 0], axis=0)

            this_num_panel_rows = int(numpy.floor(
                numpy.sqrt(this_reflectivity_matrix_dbz.shape[-1])))

            _, these_axes_objects = radar_plotting.plot_3d_grid_without_coords(
                field_matrix=this_reflectivity_matrix_dbz,
                field_name=radar_utils.REFL_NAME,
                grid_point_heights_metres=radar_heights_m_agl,
                ground_relative=True,
                num_panel_rows=this_num_panel_rows,
                font_size=FONT_SIZE_SANS_COLOUR_BARS)

            this_colour_map_object, this_colour_norm_object = (
                radar_plotting.get_default_colour_scheme(
                    radar_utils.REFL_NAME))

            plotting_utils.add_colour_bar(
                axes_object_or_list=these_axes_objects,
                values_to_colour=this_reflectivity_matrix_dbz,
                colour_map=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation='horizontal',
                extend_min=True,
                extend_max=True)

            _save_and_close(
                '{0:s}_reflectivity.jpg'.format(this_base_file_name),
                '{0:s}; {1:s}'.format(this_base_title_string,
                                      radar_utils.REFL_NAME))

            # One panel per azimuthal-shear field, so keep the whole channel
            # axis.  Assumes the shear matrix is E-by-M-by-N-by-F with
            # F = len(az_shear_field_names).
            this_az_shear_matrix_s01 = numpy.flip(
                list_of_predictor_matrices[1][i, ...], axis=0)

            _, these_axes_objects = (
                radar_plotting.plot_many_2d_grids_without_coords(
                    field_matrix=this_az_shear_matrix_s01,
                    field_name_by_panel=az_shear_field_names,
                    panel_names=az_shear_field_names,
                    num_panel_rows=1,
                    plot_colour_bar_by_panel=numpy.full(num_az_shear_fields,
                                                        False,
                                                        dtype=bool),
                    font_size=FONT_SIZE_SANS_COLOUR_BARS))

            this_colour_map_object, this_colour_norm_object = (
                radar_plotting.get_default_colour_scheme(
                    radar_utils.LOW_LEVEL_SHEAR_NAME))

            plotting_utils.add_colour_bar(
                axes_object_or_list=these_axes_objects,
                values_to_colour=this_az_shear_matrix_s01,
                colour_map=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation='horizontal',
                extend_min=True,
                extend_max=True)

            _save_and_close('{0:s}_shear.jpg'.format(this_base_file_name),
                            this_base_title_string)
            continue

        this_radar_matrix = list_of_predictor_matrices[0]

        if num_radar_dimensions == 2:
            this_num_panel_rows = int(
                numpy.floor(numpy.sqrt(this_radar_matrix.shape[-1])))

            radar_plotting.plot_many_2d_grids_without_coords(
                field_matrix=numpy.flip(this_radar_matrix[i, ...], axis=0),
                field_name_by_panel=field_name_by_panel,
                panel_names=panel_names,
                num_panel_rows=this_num_panel_rows,
                plot_colour_bar_by_panel=plot_colour_bar_by_panel,
                font_size=FONT_SIZE_WITH_COLOUR_BARS,
                row_major=False)

            _save_and_close('{0:s}.jpg'.format(this_base_file_name),
                            this_base_title_string)
            continue

        # 3-D radar: one figure per field, one panel per height.
        this_num_panel_rows = int(numpy.floor(
            numpy.sqrt(this_radar_matrix.shape[-2])))

        for j, this_field_name in enumerate(radar_field_names):
            _, these_axes_objects = radar_plotting.plot_3d_grid_without_coords(
                field_matrix=numpy.flip(this_radar_matrix[i, ..., j], axis=0),
                field_name=this_field_name,
                grid_point_heights_metres=radar_heights_m_agl,
                ground_relative=True,
                num_panel_rows=this_num_panel_rows,
                font_size=FONT_SIZE_SANS_COLOUR_BARS)

            this_colour_map_object, this_colour_norm_object = (
                radar_plotting.get_default_colour_scheme(this_field_name))

            plotting_utils.add_colour_bar(
                axes_object_or_list=these_axes_objects,
                values_to_colour=this_radar_matrix[i, ..., j],
                colour_map=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation='horizontal',
                extend_min=True,
                extend_max=True)

            _save_and_close(
                '{0:s}_{1:s}.jpg'.format(
                    this_base_file_name, this_field_name.replace('_', '-')),
                '{0:s}; {1:s}'.format(this_base_title_string, this_field_name))
def _plot_2d_radar_difference(difference_matrix,
                              colour_map_object,
                              max_colour_percentile,
                              model_metadata_dict,
                              backwards_opt_dict,
                              output_dir_name,
                              example_index=None,
                              significance_matrix=None):
    """Plots difference (after minus before optimization) for 2-D radar data.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels

    Each channel is drawn in its own panel.  For each panel, the colour scheme
    is zero-centred, running from -x to +x, where x is the
    `max_colour_percentile`th percentile of absolute differences in that
    channel.

    :param difference_matrix: M-by-N-by-C numpy array of differences (after
        minus before optimization).
    :param colour_map_object: Colour map used for every panel (presumably a
        `matplotlib.colors.Colormap`).
    :param max_colour_percentile: Percentile (passed to `numpy.percentile`,
        so in 0...100) of absolute differences used as the max value in each
        panel's colour scheme.
    :param model_metadata_dict: Dictionary with model metadata.  This method
        reads the keys `cnn.CONV_2D3D_KEY`, `cnn.TRAINING_OPTION_DICT_KEY`, and
        `cnn.LAYER_OPERATIONS_KEY`.
    :param backwards_opt_dict: Dictionary with backwards-optimization results.
        If it contains `backwards_opt.MEAN_FINAL_ACTIVATION_KEY`, it is treated
        as a probability-matched mean (PMM) composite; otherwise it is treated
        as containing one entry per example.
    :param output_dir_name: Name of output directory (figure will be saved
        here).
    :param example_index: Index of the example to plot.  Ignored for a PMM
        composite; required otherwise.
    :param significance_matrix: M-by-N-by-C numpy array of Boolean flags,
        indicating where these differences are significantly different than
        differences from another backwards optimization.  If None, no
        significance markers are drawn.
    :raises: ValueError: if `backwards_opt_dict` is not a PMM composite and
        `example_index is None`.
    """

    pmm_flag = backwards_opt.MEAN_FINAL_ACTIVATION_KEY in backwards_opt_dict

    if pmm_flag:
        initial_activation = backwards_opt_dict[
            backwards_opt.MEAN_INITIAL_ACTIVATION_KEY]
        final_activation = backwards_opt_dict[
            backwards_opt.MEAN_FINAL_ACTIVATION_KEY]

        full_storm_id_string = None
        storm_time_string = None
    else:
        # Without this check, indexing with None below fails with an opaque
        # error (or silently adds an axis to numpy arrays).
        if example_index is None:
            raise ValueError(
                'example_index must be specified when backwards_opt_dict does '
                'not contain a probability-matched mean.')

        initial_activation = backwards_opt_dict[
            backwards_opt.INITIAL_ACTIVATIONS_KEY][example_index]
        final_activation = backwards_opt_dict[
            backwards_opt.FINAL_ACTIVATIONS_KEY][example_index]

        full_storm_id_string = backwards_opt_dict[
            backwards_opt.FULL_IDS_KEY][example_index]

        storm_time_string = time_conversion.unix_sec_to_string(
            backwards_opt_dict[backwards_opt.STORM_TIMES_KEY][example_index],
            plot_input_examples.TIME_FORMAT)

    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    if conv_2d3d:
        # For hybrid 2-D/3-D models, the 2-D fields all sit at the single
        # shear height, so every field gets the same height.
        num_fields = len(training_option_dict[trainval_io.RADAR_FIELDS_KEY])
        radar_heights_m_agl = numpy.full(num_fields,
                                         radar_utils.SHEAR_HEIGHT_M_ASL,
                                         dtype=int)
    else:
        radar_heights_m_agl = training_option_dict[
            trainval_io.RADAR_HEIGHTS_KEY]

    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    if list_of_layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]

        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel, heights_m_agl=radar_heights_m_agl)
    else:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                list_of_layer_operation_dicts=list_of_layer_operation_dicts))

    num_panels = len(field_name_by_panel)
    plot_cbar_by_panel = numpy.full(num_panels, True, dtype=bool)
    cmap_object_by_panel = [colour_map_object] * num_panels

    # Zero-centred, divergent colour scheme per panel: limits are +/- the
    # given percentile of absolute differences in that channel.
    max_colour_value_by_panel = [
        numpy.percentile(numpy.absolute(difference_matrix[..., j]),
                         max_colour_percentile)
        for j in range(num_panels)
    ]
    cnorm_object_by_panel = [
        matplotlib.colors.Normalize(
            vmin=-1 * this_max_value, vmax=this_max_value, clip=False)
        for this_max_value in max_colour_value_by_panel
    ]

    num_panel_rows = int(numpy.floor(numpy.sqrt(num_panels)))

    # Rows are flipped before plotting (presumably so that the first grid row
    # ends up at the bottom of each panel).
    figure_object, axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(difference_matrix, axis=0),
            field_name_by_panel=field_name_by_panel,
            num_panel_rows=num_panel_rows,
            panel_names=panel_names,
            row_major=False,
            colour_map_object_by_panel=cmap_object_by_panel,
            colour_norm_object_by_panel=cnorm_object_by_panel,
            plot_colour_bar_by_panel=plot_cbar_by_panel,
            font_size=FONT_SIZE_WITH_COLOUR_BARS))

    if significance_matrix is not None:
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(significance_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            row_major=False)

    if pmm_flag:
        this_title_string = 'PMM'
    else:
        this_title_string = 'Storm "{0:s}" at {1:s}'.format(
            full_storm_id_string, storm_time_string)

    this_title_string += '; activation from {0:.2e} to {1:.2e}'.format(
        initial_activation, final_activation)
    figure_object.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)

    output_file_name = plot_input_examples.metadata_to_radar_fig_file_name(
        output_dir_name=output_dir_name,
        pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_string=storm_time_string,
        radar_field_name='shear' if conv_2d3d else None)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)