Exemple #1
0
def _plot_2d_radar_cam(colour_map_object,
                       max_colour_percentile,
                       figure_objects,
                       axes_object_matrices,
                       model_metadata_dict,
                       output_dir_name,
                       cam_matrix=None,
                       guided_cam_matrix=None,
                       significance_matrix=None,
                       full_storm_id_string=None,
                       storm_time_unix_sec=None):
    """Plots guided or unguided class-activation map for 2-D radar data.

    If `cam_matrix` is given, it is plotted (copied to every channel);
    otherwise `guided_cam_matrix` is plotted.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    If both `full_storm_id_string` and `storm_time_unix_sec` are None, the
    output file is named as for a probability-matched mean (`pmm_flag=True`).

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param max_colour_percentile: Same.
    :param figure_objects: Same.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Same.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of class activations.
    :param guided_cam_matrix: M-by-N-by-C numpy array of guided-CAM output.
    :param significance_matrix: See doc for `_plot_3d_radar_cam`.
    :param full_storm_id_string: Same.
    :param storm_time_unix_sec: Same.
    """

    if cam_matrix is None:
        quantity_string = 'max abs value'
    else:
        quantity_string = 'max activation'

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    # For hybrid 2D/3D models the 2-D data live in the second figure, and the
    # output file is labelled 'shear' (see `radar_field_name` below).
    figure_index = 1 if conv_2d3d else 0

    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    if list_of_layer_operation_dicts is None:
        training_option_dict = model_metadata_dict[
            cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
        num_channels = len(radar_field_names)
    else:
        num_channels = len(list_of_layer_operation_dicts)

    if cam_matrix is None:
        this_matrix = guided_cam_matrix
    else:
        # The unguided CAM has no channel axis, so repeat it once per channel
        # to get one overlay per panel.
        this_matrix = numpy.expand_dims(cam_matrix, axis=-1)
        this_matrix = numpy.repeat(this_matrix, repeats=num_channels, axis=-1)

    max_contour_level = numpy.percentile(numpy.absolute(this_matrix),
                                         max_colour_percentile)

    if cam_matrix is None:
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_contour_level,
            contour_interval=max_contour_level / HALF_NUM_CONTOURS,
            row_major=False)
    else:
        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            colour_map_object=colour_map_object,
            max_contour_level=max_contour_level,
            contour_interval=max_contour_level / NUM_CONTOURS,
            row_major=False)

    if significance_matrix is not None:
        if cam_matrix is None:
            this_matrix = significance_matrix
        else:
            this_matrix = numpy.expand_dims(significance_matrix, axis=-1)
            this_matrix = numpy.repeat(this_matrix,
                                       repeats=num_channels,
                                       axis=-1)

        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            row_major=False)

    figure_object = figure_objects[figure_index]
    title_object = figure_object._suptitle

    if title_object is not None:
        # matplotlib keeps the Text object created by `suptitle` in
        # `_suptitle`, so the existing title must be read with `get_text`
        # before a suffix can be appended.  A plain string is used as is.
        if hasattr(title_object, 'get_text'):
            this_title_string = title_object.get_text()
        else:
            this_title_string = title_object

        this_title_string += ' ({0:s} = {1:.2e})'.format(
            quantity_string, max_contour_level)

        figure_object.suptitle(
            this_title_string, fontsize=plot_input_examples.TITLE_FONT_SIZE)

    output_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name,
        is_sounding=False,
        pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name='shear' if conv_2d3d else None)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)
def _plot_3d_radar_cams(
        radar_matrix, model_metadata_dict, cam_colour_map_object,
        max_colour_prctile_for_cam, output_dir_name,
        class_activation_matrix=None, ggradcam_output_matrix=None,
        storm_ids=None, storm_times_unix_sec=None):
    """Plots class-activation maps for 3-D radar data.

    E = number of examples
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    This method will plot either `class_activation_matrix` or
    `ggradcam_output_matrix`, not both.  `class_activation_matrix` is used
    whenever it is given.

    If `storm_ids is None` and `storm_times_unix_sec is None`, will assume that
    the input matrices contain probability-matched means.

    One figure is created and saved for each example and radar field.

    :param radar_matrix: E-by-M-by-N-by-H-by-F numpy array of radar values.
    :param model_metadata_dict: Dictionary with CNN metadata (see doc for
        `cnn.read_model_metadata`).
    :param cam_colour_map_object: See documentation at top of file.
    :param max_colour_prctile_for_cam: Same.
    :param output_dir_name: Same.
    :param class_activation_matrix: E-by-M-by-N-by-H numpy array of class
        activations.
    :param ggradcam_output_matrix: E-by-M-by-N-by-H-by-F numpy array of output
        values from guided Grad-CAM.
    :param storm_ids: length-E list of storm IDs (strings).
    :param storm_times_unix_sec: length-E numpy array of storm times.
    """

    pmm_flag = storm_ids is None and storm_times_unix_sec is None

    num_examples = radar_matrix.shape[0]
    num_heights = radar_matrix.shape[-2]
    num_fields = radar_matrix.shape[-1]
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_heights)))

    if class_activation_matrix is None:
        quantity_string = 'max absolute guided Grad-CAM output'
        pathless_file_name_prefix = 'guided-gradcam'
    else:
        quantity_string = 'max class activation'
        pathless_file_name_prefix = 'gradcam'

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    for i in range(num_examples):
        for k in range(num_fields):
            this_field_name = training_option_dict[
                trainval_io.RADAR_FIELDS_KEY][k]

            _, these_axes_objects = radar_plotting.plot_3d_grid_without_coords(
                field_matrix=numpy.flip(radar_matrix[i, ..., k], axis=0),
                field_name=this_field_name,
                grid_point_heights_metres=training_option_dict[
                    trainval_io.RADAR_HEIGHTS_KEY],
                ground_relative=True, num_panel_rows=num_panel_rows,
                font_size=FONT_SIZE_SANS_COLOUR_BARS)

            if class_activation_matrix is None:
                this_matrix = ggradcam_output_matrix[i, ..., k]

                this_max_contour_level = numpy.percentile(
                    numpy.absolute(this_matrix), max_colour_prctile_for_cam)

                # Fall back to an arbitrary positive level so that the contour
                # interval below is never zero.
                if this_max_contour_level == 0:
                    this_max_contour_level = 10.

                saliency_plotting.plot_many_2d_grids_with_contours(
                    saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
                    axes_objects_2d_list=these_axes_objects,
                    colour_map_object=cam_colour_map_object,
                    max_absolute_contour_level=this_max_contour_level,
                    contour_interval=this_max_contour_level / 10)

            else:
                # The unguided CAM has no field axis, so the same map is
                # overlaid on every field.
                this_matrix = class_activation_matrix[i, ...]

                this_max_contour_level = numpy.percentile(
                    this_matrix, max_colour_prctile_for_cam)
                if this_max_contour_level == 0:
                    this_max_contour_level = 10.

                cam_plotting.plot_many_2d_grids(
                    class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
                    axes_objects_2d_list=these_axes_objects,
                    colour_map_object=cam_colour_map_object,
                    max_contour_level=this_max_contour_level,
                    contour_interval=this_max_contour_level / NUM_CONTOURS)

            this_colour_map_object, this_colour_norm_object = (
                radar_plotting.get_default_colour_scheme(this_field_name)
            )

            plotting_utils.add_colour_bar(
                axes_object_or_list=these_axes_objects,
                values_to_colour=radar_matrix[i, ..., k],
                colour_map=this_colour_map_object,
                colour_norm_object=this_colour_norm_object,
                orientation='horizontal', extend_min=True, extend_max=True)

            if pmm_flag:
                this_title_string = 'Probability-matched mean'
                this_figure_file_name = '{0:s}/{1:s}_pmm_{2:s}.jpg'.format(
                    output_dir_name, pathless_file_name_prefix,
                    this_field_name.replace('_', '-')
                )

            else:
                this_storm_time_string = time_conversion.unix_sec_to_string(
                    storm_times_unix_sec[i], TIME_FORMAT)

                this_title_string = 'Storm "{0:s}" at {1:s}'.format(
                    storm_ids[i], this_storm_time_string)

                this_figure_file_name = (
                    '{0:s}/{1:s}_{2:s}_{3:s}_{4:s}.jpg'
                ).format(
                    output_dir_name, pathless_file_name_prefix,
                    storm_ids[i].replace('_', '-'), this_storm_time_string,
                    this_field_name.replace('_', '-')
                )

            this_title_string += ' ({0:s} = {1:.3f})'.format(
                quantity_string, this_max_contour_level)
            pyplot.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)

            # Call form of `print` is valid in both Python 2 and Python 3.
            print('Saving figure to file: "{0:s}"...'.format(
                this_figure_file_name))
            pyplot.savefig(this_figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
            pyplot.close()
Exemple #3
0
def _plot_3d_radar_cam(colour_map_object,
                       max_colour_percentile,
                       figure_objects,
                       axes_object_matrices,
                       model_metadata_dict,
                       output_dir_name,
                       cam_matrix=None,
                       guided_cam_matrix=None,
                       significance_matrix=None,
                       full_storm_id_string=None,
                       storm_time_unix_sec=None):
    """Plots guided or unguided class-activation map for 3-D radar data.

    This method will plot either `cam_matrix` or `guided_cam_matrix`, not both.
    `cam_matrix` is used whenever it is given.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param figure_objects: See doc for `plot_input_examples._plot_3d_example`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of class activations.
    :param guided_cam_matrix: M-by-N-by-H-by-F numpy array of guided-CAM output.
    :param significance_matrix: Boolean numpy array with the same dimensions as
        the array being plotted (`cam_matrix` or `guided_cam_matrix`),
        indicating where differences with some other CAM are significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    if cam_matrix is None:
        quantity_string = 'max abs value'
    else:
        quantity_string = 'max activation'

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        # For hybrid 2D/3D models only the first figure (labelled
        # 'reflectivity') is handled here.
        loop_max = 1
        radar_field_names = ['reflectivity']
    else:
        loop_max = len(figure_objects)
        training_option_dict = model_metadata_dict[
            cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    for j in range(loop_max):
        if cam_matrix is None:
            this_matrix = guided_cam_matrix[..., j]
        else:
            # The unguided CAM has no field axis; the same map is overlaid on
            # every figure.
            this_matrix = cam_matrix

        this_max_contour_level = numpy.percentile(numpy.absolute(this_matrix),
                                                  max_colour_percentile)

        if cam_matrix is None:
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / HALF_NUM_CONTOURS)
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / NUM_CONTOURS)

        if significance_matrix is not None:
            if cam_matrix is None:
                this_matrix = significance_matrix[..., j]
            else:
                this_matrix = significance_matrix

            significance_plotting.plot_many_2d_grids_without_coords(
                significance_matrix=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j])

        this_figure_object = figure_objects[j]
        this_title_object = this_figure_object._suptitle

        if this_title_object is not None:
            # matplotlib keeps the Text object created by `suptitle` in
            # `_suptitle`, so the existing title must be read with `get_text`
            # before a suffix can be appended.  A plain string is used as is.
            if hasattr(this_title_object, 'get_text'):
                this_title_string = this_title_object.get_text()
            else:
                this_title_string = this_title_object

            this_title_string += ' ({0:s} = {1:.2e})'.format(
                quantity_string, this_max_contour_level)

            this_figure_object.suptitle(
                this_title_string,
                fontsize=plot_input_examples.TITLE_FONT_SIZE)

        this_file_name = plot_input_examples.metadata_to_file_name(
            output_dir_name=output_dir_name,
            is_sounding=False,
            pmm_flag=pmm_flag,
            full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j])

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        this_figure_object.savefig(this_file_name,
                                   dpi=FIGURE_RESOLUTION_DPI,
                                   pad_inches=0,
                                   bbox_inches='tight')
        pyplot.close(this_figure_object)
def _plot_2d_radar_cams(
        radar_matrix, model_metadata_dict, cam_colour_map_object,
        max_colour_prctile_for_cam, output_dir_name,
        class_activation_matrix=None, ggradcam_output_matrix=None,
        storm_ids=None, storm_times_unix_sec=None):
    """Plots class-activation maps for 2-D radar data.

    E = number of examples
    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of channels (field/height pairs)

    This method will plot either `class_activation_matrix` or
    `ggradcam_output_matrix`, not both.  `class_activation_matrix` is used
    whenever it is given.

    If `storm_ids is None` and `storm_times_unix_sec is None`, will assume that
    the input matrices contain probability-matched means.

    One figure is created and saved for each example.

    :param radar_matrix: E-by-M-by-N-by-C numpy array of radar values.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cams`.
    :param cam_colour_map_object: Same.
    :param max_colour_prctile_for_cam: Same.
    :param output_dir_name: Same.
    :param class_activation_matrix: E-by-M-by-N numpy array of class
        activations.
    :param ggradcam_output_matrix: E-by-M-by-N-by-C numpy array of output values
        from guided Grad-CAM.
    :param storm_ids: length-E list of storm IDs (strings).
    :param storm_times_unix_sec: length-E numpy array of storm times.
    """

    pmm_flag = storm_ids is None and storm_times_unix_sec is None

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    if list_of_layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]

        panel_names = (
            radar_plotting.radar_fields_and_heights_to_panel_names(
                field_names=field_name_by_panel,
                heights_m_agl=training_option_dict[
                    trainval_io.RADAR_HEIGHTS_KEY]
            )
        )

        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), True, dtype=bool)

    else:
        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                list_of_layer_operation_dicts)
        )

        # With layer operations, only every third panel (starting with the
        # third) gets a colour bar.
        # NOTE(review): presumably panels come in groups of three per field --
        # confirm against the caller.
        plot_colour_bar_by_panel = numpy.full(
            len(panel_names), False, dtype=bool)
        plot_colour_bar_by_panel[2::3] = True

    num_examples = radar_matrix.shape[0]
    num_channels = radar_matrix.shape[-1]
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_channels)))

    if class_activation_matrix is None:
        quantity_string = 'max absolute guided Grad-CAM output'
        pathless_file_name_prefix = 'guided-gradcam'
    else:
        quantity_string = 'max class activation'
        pathless_file_name_prefix = 'gradcam'

    for i in range(num_examples):
        _, these_axes_objects = (
            radar_plotting.plot_many_2d_grids_without_coords(
                field_matrix=numpy.flip(radar_matrix[i, ...], axis=0),
                field_name_by_panel=field_name_by_panel,
                panel_names=panel_names, num_panel_rows=num_panel_rows,
                plot_colour_bar_by_panel=plot_colour_bar_by_panel,
                font_size=FONT_SIZE_WITH_COLOUR_BARS, row_major=False)
        )

        if class_activation_matrix is None:
            this_matrix = ggradcam_output_matrix[i, ...]

            this_max_contour_level = numpy.percentile(
                numpy.absolute(this_matrix), max_colour_prctile_for_cam)

            # Fall back to an arbitrary positive level so that the contour
            # interval below is never zero.
            if this_max_contour_level == 0:
                this_max_contour_level = 10.

            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_objects_2d_list=these_axes_objects,
                colour_map_object=cam_colour_map_object,
                max_absolute_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / 10, row_major=False)

        else:
            # The unguided CAM has no channel axis, so repeat it once per
            # channel to get one overlay per panel.
            this_matrix = numpy.expand_dims(
                class_activation_matrix[i, ...], axis=-1)
            this_matrix = numpy.repeat(
                this_matrix, repeats=num_channels, axis=-1)

            this_max_contour_level = numpy.percentile(
                this_matrix, max_colour_prctile_for_cam)
            if this_max_contour_level == 0:
                this_max_contour_level = 10.

            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_objects_2d_list=these_axes_objects,
                colour_map_object=cam_colour_map_object,
                max_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / NUM_CONTOURS,
                row_major=False)

        if pmm_flag:
            this_title_string = 'Probability-matched mean'
            this_figure_file_name = '{0:s}/{1:s}_pmm_radar.jpg'.format(
                output_dir_name, pathless_file_name_prefix)

        else:
            this_storm_time_string = time_conversion.unix_sec_to_string(
                storm_times_unix_sec[i], TIME_FORMAT)

            this_title_string = 'Storm "{0:s}" at {1:s}'.format(
                storm_ids[i], this_storm_time_string)

            this_figure_file_name = '{0:s}/{1:s}_{2:s}_{3:s}_radar.jpg'.format(
                output_dir_name, pathless_file_name_prefix,
                storm_ids[i].replace('_', '-'), this_storm_time_string)

        this_title_string += ' ({0:s} = {1:.3f})'.format(
            quantity_string, this_max_contour_level)
        pyplot.suptitle(this_title_string, fontsize=TITLE_FONT_SIZE)

        # Call form of `print` is valid in both Python 2 and Python 3.
        print('Saving figure to file: "{0:s}"...'.format(
            this_figure_file_name))
        pyplot.savefig(this_figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
Exemple #5
0
def _plot_one_composite(gradcam_file_name, monte_carlo_file_name,
                        composite_name_abbrev, composite_name_verbose,
                        colour_map_object, min_colour_value, max_colour_value,
                        num_contours, smoothing_radius_grid_cells,
                        monte_carlo_max_fdr, output_dir_name):
    """Plots class-activation map for one composite.

    The class-activation map is contoured in log10 space on top of the mean
    radar fields.  One panel image is saved per radar field, and the panels
    are then concatenated into a single titled image.

    :param gradcam_file_name: Path to input file (will be read by
        `gradcam.read_file`).
    :param monte_carlo_file_name: Path to Monte Carlo file (will be read by
        `_read_monte_carlo_file`).
    :param composite_name_abbrev: Abbrev composite name (will be used in file
        names).
    :param composite_name_verbose: Verbose composite name (will be used in
        figure title).
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour bar (may be NaN).
    :param max_colour_value: Max value in colour bar (may be NaN).  If either
        limit is NaN, both are derived from the 1st and 99th percentiles of
        the class-activation map.
    :param num_contours: See documentation at top of file.
    :param smoothing_radius_grid_cells: Same.
    :param monte_carlo_max_fdr: Same.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    :return: min_colour_value: Same as input, unless the limits were derived
        from the data (see `max_colour_value`).
    :return: max_colour_value: Same as input, unless the limits were derived
        from the data.
    """

    composite_tuple = _read_one_composite(
        gradcam_file_name=gradcam_file_name,
        smoothing_radius_grid_cells=smoothing_radius_grid_cells,
        monte_carlo_file_name=monte_carlo_file_name,
        monte_carlo_max_fdr=monte_carlo_max_fdr)

    mean_radar_matrix = composite_tuple[0]
    mean_class_activn_matrix = composite_tuple[1]
    significance_matrix = composite_tuple[2]
    model_metadata_dict = composite_tuple[3]

    limits_are_given = not (
        numpy.isnan(min_colour_value) or numpy.isnan(max_colour_value)
    )

    if limits_are_given:
        log10_lower_limit = numpy.log10(min_colour_value)
        log10_upper_limit = numpy.log10(max_colour_value)
    else:
        log10_lower_limit = numpy.log10(
            numpy.percentile(mean_class_activn_matrix, 1.))
        log10_upper_limit = numpy.log10(
            numpy.percentile(mean_class_activn_matrix, 99.))

        # Keep the lower limit in 10^-2...10^1 and the upper limit in
        # 10^-1...10^2.
        log10_lower_limit = min(max(log10_lower_limit, -2.), 1.)
        log10_upper_limit = min(max(log10_upper_limit, -1.), 2.)

        min_colour_value = 10**log10_lower_limit
        max_colour_value = 10**log10_upper_limit

    log10_contour_interval = (
        (log10_upper_limit - log10_lower_limit) / (num_contours - 1))
    log10_activn_matrix = numpy.log10(mean_class_activn_matrix)

    field_names = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY][
        trainval_io.RADAR_FIELDS_KEY]

    num_heights, num_fields = mean_radar_matrix.shape[-2:]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=[mean_radar_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_heights)

    figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    # The same (first-example) activation and significance maps are overlaid
    # on the figure for every radar field.
    for field_index in range(num_fields):
        flipped_activn_matrix = numpy.flip(
            log10_activn_matrix[0, ...], axis=0)

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=flipped_activn_matrix,
            axes_object_matrix=axes_object_matrices[field_index],
            colour_map_object=colour_map_object,
            min_contour_level=log10_lower_limit,
            max_contour_level=log10_upper_limit,
            contour_interval=log10_contour_interval)

        flipped_signif_matrix = numpy.flip(
            significance_matrix[0, ...], axis=0)

        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=flipped_signif_matrix,
            axes_object_matrix=axes_object_matrices[field_index])

    panel_file_names = []

    for field_index in range(num_fields):
        this_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, composite_name_abbrev,
            field_names[field_index].replace('_', '-'))
        panel_file_names.append(this_file_name)

        print('Saving figure to: "{0:s}"...'.format(this_file_name))

        this_figure_object = figure_objects[field_index]
        this_figure_object.savefig(this_file_name,
                                   dpi=FIGURE_RESOLUTION_DPI,
                                   pad_inches=0,
                                   bbox_inches='tight')
        pyplot.close(this_figure_object)

    main_figure_file_name = '{0:s}/{1:s}_gradcam.jpg'.format(
        output_dir_name, composite_name_abbrev)

    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    # Panels go side by side in one row; the result is then resized, padded
    # to make room for the title, titled and trimmed again.
    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=main_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=num_fields,
        border_width_pixels=50)

    imagemagick_utils.resize_image(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX)

    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=TITLE_FONT_SIZE + 25)

    _overlay_text(
        image_file_name=main_figure_file_name,
        x_offset_from_center_px=0,
        y_offset_from_top_px=0,
        text_string=composite_name_verbose)

    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=10)

    return main_figure_file_name, min_colour_value, max_colour_value
Exemple #6
0
def _plot_one_composite(gradcam_file_name, composite_name_abbrev,
                        composite_name_verbose, colour_map_object,
                        max_colour_value, num_contours,
                        smoothing_radius_grid_cells, output_dir_name):
    """Plots class-activation map for one composite.

    Two figures are handled: the first ("reflectivity") and the second
    ("shear").  Each gets the class-activation map contoured in log10 space,
    is saved to its own file, and the two files are then concatenated into a
    single titled image.

    :param gradcam_file_name: Path to input file (will be read by
        `gradcam.read_file`).
    :param composite_name_abbrev: Abbrev composite name (will be used in file
        names).
    :param composite_name_verbose: Verbose composite name (will be used in
        figure title).
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.
    :param num_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    """

    composite_tuple = _read_one_composite(
        gradcam_file_name=gradcam_file_name,
        smoothing_radius_grid_cells=smoothing_radius_grid_cells)

    mean_predictor_matrices = composite_tuple[0]
    mean_cam_matrices = composite_tuple[1]
    model_metadata_dict = composite_tuple[2]

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    num_refl_heights = len(training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_predictor_matrices,
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True,
        plot_sounding=False,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_refl_heights)

    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    log10_upper_limit = numpy.log10(max_colour_value)
    log10_contour_interval = (
        (log10_upper_limit - MIN_COLOUR_VALUE_LOG10) / (num_contours - 1))

    # First figure: the CAM is used as is (first example only).
    refl_cam_matrix = numpy.log10(
        numpy.flip(mean_cam_matrices[0][0, ...], axis=0)
    )

    cam_plotting.plot_many_2d_grids(
        class_activation_matrix_3d=refl_cam_matrix,
        axes_object_matrix=axes_object_matrices[0],
        colour_map_object=colour_map_object,
        min_contour_level=MIN_COLOUR_VALUE_LOG10,
        max_contour_level=log10_upper_limit,
        contour_interval=log10_contour_interval,
        row_major=True)

    # Second figure: the CAM has no channel axis, so it is repeated once per
    # channel of the second predictor matrix.
    shear_cam_matrix = numpy.flip(mean_cam_matrices[1][0, ...], axis=0)
    num_shear_channels = mean_predictor_matrices[1].shape[-1]
    shear_cam_matrix = numpy.expand_dims(shear_cam_matrix, axis=-1)
    shear_cam_matrix = numpy.repeat(
        a=shear_cam_matrix, axis=-1, repeats=num_shear_channels)
    shear_cam_matrix = numpy.log10(shear_cam_matrix)

    cam_plotting.plot_many_2d_grids(
        class_activation_matrix_3d=shear_cam_matrix,
        axes_object_matrix=axes_object_matrices[1],
        colour_map_object=colour_map_object,
        min_contour_level=MIN_COLOUR_VALUE_LOG10,
        max_contour_level=log10_upper_limit,
        contour_interval=log10_contour_interval,
        row_major=False)

    # Save the reflectivity figure, then the shear figure.
    file_name_templates = (
        '{0:s}/{1:s}_reflectivity.jpg',
        '{0:s}/{1:s}_shear.jpg'
    )
    panel_file_names = []

    for figure_index, this_template in enumerate(file_name_templates):
        this_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][
            figure_index]
        this_file_name = this_template.format(
            output_dir_name, composite_name_abbrev)
        panel_file_names.append(this_file_name)

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        this_figure_object.savefig(this_file_name,
                                   dpi=FIGURE_RESOLUTION_DPI,
                                   pad_inches=0,
                                   bbox_inches='tight')
        pyplot.close(this_figure_object)

    main_figure_file_name = '{0:s}/{1:s}_radar.jpg'.format(
        output_dir_name, composite_name_abbrev)
    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    # Panels go side by side; the result is then resized, padded to make room
    # for the title, titled and trimmed again.
    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=main_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=2,
        border_width_pixels=50,
        extra_args_string='-gravity south')

    imagemagick_utils.resize_image(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX)

    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=TITLE_FONT_SIZE + 25)

    _overlay_text(
        image_file_name=main_figure_file_name,
        x_offset_from_center_px=0,
        y_offset_from_top_px=0,
        text_string=composite_name_verbose)

    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=10)

    return main_figure_file_name
def _plot_2d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 2-D radar data.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    Exactly one figure is drawn on, saved and closed: `figure_objects[1]` if
    `model_metadata_dict[cnn.CONV_2D3D_KEY]` is truthy, else
    `figure_objects[0]`.

    The unguided arguments (`min_unguided_value`, `max_unguided_value`,
    `num_unguided_contours`) are read only if `cam_matrix` is given.  The
    guided arguments (`max_guided_value`, `half_num_guided_contours`) are read
    only if `cam_matrix is None`.

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cam`.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        figure_index = 1
        radar_field_name = 'shear'
    else:
        figure_index = 0
        radar_field_name = None

    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    if list_of_layer_operation_dicts is None:
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
        num_channels = len(radar_field_names)
    else:
        num_channels = len(list_of_layer_operation_dicts)

    this_axes_object_matrix = axes_object_matrices[figure_index]

    # Colour bar is half as long when `conv_2d3d` is true.
    colour_bar_fraction = colour_bar_length / (1 + int(conv_2d3d))

    if cam_matrix is None:
        # Guided CAM: contours on a linear scale, symmetric about zero.
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(guided_cam_matrix, axis=0),
            axes_object_matrix=this_axes_object_matrix,
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_guided_value,
            contour_interval=max_guided_value / half_num_guided_contours,
            row_major=False
        )

        this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=this_axes_object_matrix,
            data_matrix=guided_cam_matrix,
            colour_map_object=colour_map_object, min_value=0.,
            max_value=max_guided_value, orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_fraction,
            extend_min=False, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )

        if label_colour_bars:
            this_colour_bar_object.set_label(
                'Absolute guided class activation',
                fontsize=COLOUR_BAR_FONT_SIZE
            )
    else:
        # Unguided CAM: contours are drawn in log10 space.  These limits are
        # computed here, rather than unconditionally, so that guided plots do
        # not depend on the unguided arguments being valid.
        min_unguided_value_log10 = numpy.log10(min_unguided_value)
        max_unguided_value_log10 = numpy.log10(max_unguided_value)
        contour_interval_log10 = (
            (max_unguided_value_log10 - min_unguided_value_log10) /
            (num_unguided_contours - 1)
        )

        # Replicate the single M-by-N map along a new last axis, giving one
        # copy per channel.
        this_cam_matrix_log10 = numpy.log10(
            numpy.expand_dims(cam_matrix, axis=-1)
        )
        this_cam_matrix_log10 = numpy.repeat(
            this_cam_matrix_log10, repeats=num_channels, axis=-1
        )

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(
                this_cam_matrix_log10, axis=0
            ),
            axes_object_matrix=this_axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=min_unguided_value_log10,
            max_contour_level=max_unguided_value_log10,
            contour_interval=contour_interval_log10, row_major=False
        )

        this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=this_axes_object_matrix,
            data_matrix=this_cam_matrix_log10,
            colour_map_object=colour_map_object,
            min_value=min_unguided_value_log10,
            max_value=max_unguided_value_log10,
            orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_fraction,
            extend_min=True, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )

        # Tick positions are in log10 space; label them with the original
        # (non-log) values, formatted to 2 decimals and cut to 4 characters.
        these_tick_values = this_colour_bar_object.get_ticks()
        these_tick_strings = [
            '{0:.2f}'.format(10 ** v)[:4] for v in these_tick_values
        ]
        this_colour_bar_object.set_ticks(these_tick_values)
        this_colour_bar_object.set_ticklabels(these_tick_strings)

        if label_colour_bars:
            this_colour_bar_object.set_label(
                'Class activation', fontsize=COLOUR_BAR_FONT_SIZE
            )

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name
    )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_objects[figure_index].savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_objects[figure_index])
def _plot_3d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 3-D radar data.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    One figure is drawn on, saved and closed per loop iteration.  If
    `model_metadata_dict[cnn.CONV_2D3D_KEY]` is truthy, only
    `figure_objects[0]` is handled (field name 'reflectivity'); otherwise every
    figure in `figure_objects` is handled, with field names taken from the
    training options.

    The unguided arguments (`min_unguided_value`, `max_unguided_value`,
    `num_unguided_contours`) are read only if `cam_matrix` is given.  The
    guided arguments (`max_guided_value`, `half_num_guided_contours`) are read
    only if `cam_matrix is None`.

    :param colour_map_object: See documentation at top of file.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_3d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-H-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        loop_max = 1
        radar_field_names = ['reflectivity']
    else:
        loop_max = len(figure_objects)
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    plot_guided = cam_matrix is None

    if plot_guided:
        guided_contour_interval = max_guided_value / half_num_guided_contours
    else:
        # Unguided contours are drawn in log10 space.  None of the following
        # depends on the loop index, so it is computed once here instead of
        # once per figure.
        min_unguided_value_log10 = numpy.log10(min_unguided_value)
        max_unguided_value_log10 = numpy.log10(max_unguided_value)
        contour_interval_log10 = (
            (max_unguided_value_log10 - min_unguided_value_log10) /
            (num_unguided_contours - 1)
        )

        cam_matrix_log10 = numpy.log10(cam_matrix)
        flipped_cam_matrix_log10 = numpy.flip(cam_matrix_log10, axis=0)

    for j in range(loop_max):
        if plot_guided:
            # Guided CAM: contours on a linear scale, symmetric about zero,
            # using the [j]th slice along the last axis.
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(
                    guided_cam_matrix[..., j], axis=0
                ),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=max_guided_value,
                contour_interval=guided_contour_interval
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=guided_cam_matrix[..., j],
                colour_map_object=colour_map_object, min_value=0.,
                max_value=max_guided_value, orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=False, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Absolute guided class activation',
                    fontsize=COLOUR_BAR_FONT_SIZE
                )
        else:
            # Unguided CAM: the same map is drawn on every figure.
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=flipped_cam_matrix_log10,
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                min_contour_level=min_unguided_value_log10,
                max_contour_level=max_unguided_value_log10,
                contour_interval=contour_interval_log10
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=cam_matrix_log10,
                colour_map_object=colour_map_object,
                min_value=min_unguided_value_log10,
                max_value=max_unguided_value_log10,
                orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=True, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            # Tick positions are in log10 space; label them with the original
            # (non-log) values, formatted to 2 decimals and cut to 4
            # characters.
            these_tick_values = this_colour_bar_object.get_ticks()
            these_tick_strings = [
                '{0:.2f}'.format(10 ** v)[:4] for v in these_tick_values
            ]
            this_colour_bar_object.set_ticks(these_tick_values)
            this_colour_bar_object.set_ticklabels(these_tick_strings)

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Class activation', fontsize=COLOUR_BAR_FONT_SIZE
                )

        this_file_name = plot_examples.metadata_to_file_name(
            output_dir_name=output_dir_name, is_sounding=False,
            pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j]
        )

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        figure_objects[j].savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_objects[j])