Example #1
0
def _plot_one_composite(saliency_file_name, composite_name_abbrev,
                        composite_name_verbose, colour_map_object,
                        max_colour_value, half_num_contours,
                        smoothing_radius_grid_cells, output_dir_name):
    """Plots saliency map for one composite.

    One figure is made per radar field (last axis of the mean radar matrix),
    with saliency contours drawn over it.  Each figure is saved as a JPEG
    panel; the panels are then concatenated into one image, which is resized,
    trimmed, and titled with the verbose composite name.

    :param saliency_file_name: Path to input file (will be read by
        `saliency.read_file`).
    :param composite_name_abbrev: Abbrev composite name (will be used in file
        names).
    :param composite_name_verbose: Verbose composite name (will be used in
        figure title).
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.  If NaN, the value is computed here as the
        `MAX_COLOUR_PERCENTILE`th percentile of absolute mean saliency.
    :param half_num_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    :return: max_colour_value: Max value actually used for the contours (either
        the input value or the one computed here).
    """

    mean_radar_matrix, mean_saliency_matrix, model_metadata_dict = (
        _read_one_composite(
            saliency_file_name=saliency_file_name,
            smoothing_radius_grid_cells=smoothing_radius_grid_cells))

    if numpy.isnan(max_colour_value):
        # This value becomes the max *absolute* contour level, so it must be
        # based on saliency magnitudes.  A percentile of the signed values
        # would be too small (or even negative) for a composite dominated by
        # negative saliency.
        max_colour_value = numpy.percentile(
            numpy.absolute(mean_saliency_matrix), MAX_COLOUR_PERCENTILE)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    num_fields = mean_radar_matrix.shape[-1]
    num_heights = mean_radar_matrix.shape[-2]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=[mean_radar_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_heights)

    figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    # Same interval for every field, so compute it only once.
    contour_interval = max_colour_value / half_num_contours

    for k in range(num_fields):
        # Index 0 on the first axis selects the single (composite) example.
        this_saliency_matrix = mean_saliency_matrix[0, ..., k]

        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(this_saliency_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[k],
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_colour_value,
            contour_interval=contour_interval)

    panel_file_names = [None] * num_fields

    for k in range(num_fields):
        panel_file_names[k] = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, composite_name_abbrev,
            field_names[k].replace('_', '-'))

        print('Saving figure to: "{0:s}"...'.format(panel_file_names[k]))

        figure_objects[k].savefig(panel_file_names[k],
                                  dpi=FIGURE_RESOLUTION_DPI,
                                  pad_inches=0,
                                  bbox_inches='tight')
        pyplot.close(figure_objects[k])

    main_figure_file_name = '{0:s}/{1:s}_saliency.jpg'.format(
        output_dir_name, composite_name_abbrev)

    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=main_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=num_fields,
        border_width_pixels=50)

    imagemagick_utils.resize_image(input_file_name=main_figure_file_name,
                                   output_file_name=main_figure_file_name,
                                   output_size_pixels=CONCAT_FIGURE_SIZE_PX)

    # Leave a border large enough to hold the title that is overlaid next.
    imagemagick_utils.trim_whitespace(input_file_name=main_figure_file_name,
                                      output_file_name=main_figure_file_name,
                                      border_width_pixels=TITLE_FONT_SIZE + 25)

    _overlay_text(image_file_name=main_figure_file_name,
                  x_offset_from_center_px=0,
                  y_offset_from_top_px=0,
                  text_string=composite_name_verbose)

    imagemagick_utils.trim_whitespace(input_file_name=main_figure_file_name,
                                      output_file_name=main_figure_file_name,
                                      border_width_pixels=10)

    return main_figure_file_name, max_colour_value
Example #2
0
def _find_significant_points(monte_carlo_dict, matrix_index, example_index):
    """Flags grid points where the trial value is outside Monte Carlo bounds.

    :param monte_carlo_dict: Dictionary with keys
        `monte_carlo.TRIAL_PMM_MATRICES_KEY`, `monte_carlo.MIN_MATRICES_KEY`,
        and `monte_carlo.MAX_MATRICES_KEY`, each mapping to a list of numpy
        arrays.  May also be None.
    :param matrix_index: Index into each of the above lists.
    :param example_index: Index along first axis of each numpy array.
    :return: significance_matrix: Boolean numpy array, True where the trial
        value is below the min or above the max.  If `monte_carlo_dict is
        None`, this is None.
    """

    if monte_carlo_dict is None:
        return None

    trial_matrix = monte_carlo_dict[monte_carlo.TRIAL_PMM_MATRICES_KEY][
        matrix_index][example_index, ...]
    min_matrix = monte_carlo_dict[monte_carlo.MIN_MATRICES_KEY][
        matrix_index][example_index, ...]
    max_matrix = monte_carlo_dict[monte_carlo.MAX_MATRICES_KEY][
        matrix_index][example_index, ...]

    return numpy.logical_or(
        trial_matrix < min_matrix, trial_matrix > max_matrix)


def _run(input_file_name, allow_whitespace, plot_significance,
         plot_regions_of_interest, colour_map_name, max_colour_percentile,
         top_output_dir_name):
    """Plots Grad-CAM output (class-activation maps).

    This is effectively the main method.

    Unguided maps go to subdirectory "main_gradcam" and guided maps go to
    subdirectory "guided_gradcam" of the top-level output directory.  If
    `plot_significance` is True, regions of interest are not plotted.

    :param input_file_name: See documentation at top of file.
    :param allow_whitespace: Same.
    :param plot_significance: Same.
    :param plot_regions_of_interest: Same.
    :param colour_map_name: Same.
    :param max_colour_percentile: Same.
    :param top_output_dir_name: Same.
    """

    if plot_significance:
        plot_regions_of_interest = False

    # Check input args before touching the file system, so that bad input does
    # not leave empty directories behind.
    error_checking.assert_is_geq(max_colour_percentile, 0.)
    error_checking.assert_is_leq(max_colour_percentile, 100.)

    # `pyplot.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `pyplot.get_cmap` is available in old and new versions.
    colour_map_object = pyplot.get_cmap(colour_map_name)

    unguided_cam_dir_name = '{0:s}/main_gradcam'.format(top_output_dir_name)
    guided_cam_dir_name = '{0:s}/guided_gradcam'.format(top_output_dir_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=unguided_cam_dir_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=guided_cam_dir_name)

    print('Reading data from: "{0:s}"...'.format(input_file_name))

    # Try the standard reader first; on ValueError, fall back to the reader
    # for PMM files.
    try:
        gradcam_dict = gradcam.read_standard_file(input_file_name)
    except ValueError:
        gradcam_dict = gradcam.read_pmm_file(input_file_name)
        list_of_input_matrices = gradcam_dict[gradcam.MEAN_INPUT_MATRICES_KEY]
        list_of_cam_matrices = gradcam_dict[gradcam.MEAN_CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = gradcam_dict[
            gradcam.MEAN_GUIDED_CAM_MATRICES_KEY]

        # Add a leading axis of length 1, so that the plotting loop below can
        # index these matrices by example like those from a standard file.
        for i in range(len(list_of_input_matrices)):
            list_of_input_matrices[i] = numpy.expand_dims(
                list_of_input_matrices[i], axis=0)

            if list_of_cam_matrices[i] is None:
                continue

            list_of_cam_matrices[i] = numpy.expand_dims(
                list_of_cam_matrices[i], axis=0)
            list_of_guided_cam_matrices[i] = numpy.expand_dims(
                list_of_guided_cam_matrices[i], axis=0)

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]
    else:
        list_of_input_matrices = gradcam_dict[gradcam.INPUT_MATRICES_KEY]
        list_of_cam_matrices = gradcam_dict[gradcam.CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = gradcam_dict[
            gradcam.GUIDED_CAM_MATRICES_KEY]

        full_storm_id_strings = gradcam_dict[gradcam.FULL_IDS_KEY]
        storm_times_unix_sec = gradcam_dict[gradcam.STORM_TIMES_KEY]

    pmm_flag = (full_storm_id_strings[0] is None
                and storm_times_unix_sec[0] is None)

    # Read metadata for CNN.
    model_file_name = gradcam_dict[gradcam.MODEL_FILE_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print(
        'Reading model metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    cam_monte_carlo_dict = (gradcam_dict[gradcam.CAM_MONTE_CARLO_KEY]
                            if plot_significance else None)

    guided_cam_monte_carlo_dict = (
        gradcam_dict[gradcam.GUIDED_CAM_MONTE_CARLO_KEY]
        if plot_significance else None)

    region_dict = (gradcam_dict[gradcam.REGION_DICT_KEY]
                   if plot_regions_of_interest else None)

    num_examples = list_of_input_matrices[0].shape[0]
    num_input_matrices = len(list_of_input_matrices)

    for i in range(num_examples):
        this_id_string = full_storm_id_strings[i]
        this_time_unix_sec = storm_times_unix_sec[i]

        # Plot the inputs, then overlay unguided class-activation maps.
        these_figure_objects, these_axes_object_matrices = (
            plot_input_examples.plot_one_example(
                list_of_predictor_matrices=list_of_input_matrices,
                model_metadata_dict=model_metadata_dict,
                plot_sounding=False,
                allow_whitespace=allow_whitespace,
                pmm_flag=pmm_flag,
                example_index=i,
                full_storm_id_string=this_id_string,
                storm_time_unix_sec=this_time_unix_sec))

        for j in range(num_input_matrices):
            if list_of_cam_matrices[j] is None:
                continue

            this_significance_matrix = _find_significant_points(
                monte_carlo_dict=cam_monte_carlo_dict, matrix_index=j,
                example_index=i)

            # Discount two axes (first and last) to get number of spatial
            # dimensions.
            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=unguided_cam_dir_name,
                    cam_matrix=list_of_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=this_id_string,
                    storm_time_unix_sec=this_time_unix_sec)
            elif region_dict is None:
                _plot_2d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=unguided_cam_dir_name,
                    cam_matrix=list_of_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=this_id_string,
                    storm_time_unix_sec=this_time_unix_sec)
            else:
                _plot_2d_regions(
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    list_of_polygon_objects=region_dict[
                        gradcam.POLYGON_OBJECTS_KEY][j][i],
                    output_dir_name=unguided_cam_dir_name,
                    full_storm_id_string=this_id_string,
                    storm_time_unix_sec=this_time_unix_sec)

        # Plot the inputs again on fresh figures, for the guided maps.
        these_figure_objects, these_axes_object_matrices = (
            plot_input_examples.plot_one_example(
                list_of_predictor_matrices=list_of_input_matrices,
                model_metadata_dict=model_metadata_dict,
                plot_sounding=False,
                allow_whitespace=allow_whitespace,
                pmm_flag=pmm_flag,
                example_index=i,
                full_storm_id_string=this_id_string,
                storm_time_unix_sec=this_time_unix_sec))

        for j in range(num_input_matrices):
            if list_of_guided_cam_matrices[j] is None:
                continue

            this_significance_matrix = _find_significant_points(
                monte_carlo_dict=guided_cam_monte_carlo_dict, matrix_index=j,
                example_index=i)

            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=guided_cam_dir_name,
                    guided_cam_matrix=list_of_guided_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=this_id_string,
                    storm_time_unix_sec=this_time_unix_sec)
            else:
                _plot_2d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=guided_cam_dir_name,
                    guided_cam_matrix=list_of_guided_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=this_id_string,
                    storm_time_unix_sec=this_time_unix_sec)
Example #3
0
def _plot_one_composite(saliency_file_name, monte_carlo_file_name,
                        composite_name_abbrev, composite_name_verbose,
                        colour_map_object, max_colour_value, half_num_contours,
                        smoothing_radius_grid_cells, output_dir_name):
    """Plots saliency map, with significance markers, for one composite.

    Two figures are made (saved as "<abbrev>_reflectivity.jpg" and
    "<abbrev>_shear.jpg"), each with saliency contours and significance
    markers drawn over the radar panels.  The two figures are then
    concatenated side by side into "<abbrev>_saliency.jpg", which is resized,
    trimmed, and titled with the verbose composite name.

    :param saliency_file_name: Path to saliency file (will be read by
        `saliency.read_file`).
    :param monte_carlo_file_name: Path to Monte Carlo file (will be read by
        `_read_monte_carlo_file`).
    :param composite_name_abbrev: Abbrev composite name (will be used in file
        names).
    :param composite_name_verbose: Verbose composite name (will be used in
        figure title).
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.
    :param half_num_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    """

    radar_matrices, saliency_matrices, sig_matrices, metadata_dict = (
        _read_one_composite(
            saliency_file_name=saliency_file_name,
            smoothing_radius_grid_cells=smoothing_radius_grid_cells,
            monte_carlo_file_name=monte_carlo_file_name)
    )

    training_option_dict = metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    num_refl_heights = len(training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=radar_matrices,
        model_metadata_dict=metadata_dict,
        pmm_flag=True,
        plot_sounding=False,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_refl_heights
    )

    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    # Each entry is (matrix index, index into that matrix, row-major flag).
    # The first matrix is indexed with a trailing 0 and plotted row-major; the
    # second keeps its last axis and is not.
    overlay_specs = (
        (0, (0, Ellipsis, 0), True),
        (1, (0, Ellipsis), False),
    )

    for matrix_index, index_tuple, row_major_flag in overlay_specs:
        flipped_saliency_matrix = numpy.flip(
            saliency_matrices[matrix_index][index_tuple], axis=0
        )
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=flipped_saliency_matrix,
            axes_object_matrix=axes_object_matrices[matrix_index],
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_colour_value,
            contour_interval=max_colour_value / half_num_contours,
            row_major=row_major_flag
        )

        flipped_sig_matrix = numpy.flip(
            sig_matrices[matrix_index][index_tuple], axis=0
        )
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=flipped_sig_matrix,
            axes_object_matrix=axes_object_matrices[matrix_index],
            marker_size=2,
            row_major=row_major_flag
        )

    # Figure 0 is saved as the reflectivity panel and figure 1 as the shear
    # panel.
    panel_name_templates = (
        '{0:s}/{1:s}_reflectivity.jpg',
        '{0:s}/{1:s}_shear.jpg',
    )
    panel_file_names = []

    for figure_index, this_template in enumerate(panel_name_templates):
        this_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][
            figure_index]
        this_file_name = this_template.format(
            output_dir_name, composite_name_abbrev)

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        this_figure_object.savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(this_figure_object)

        panel_file_names.append(this_file_name)

    main_figure_file_name = '{0:s}/{1:s}_saliency.jpg'.format(
        output_dir_name, composite_name_abbrev)
    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=main_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=2,
        border_width_pixels=50,
        extra_args_string='-gravity south'
    )
    imagemagick_utils.resize_image(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX
    )

    # First trim leaves a border large enough for the title; second trim
    # tightens the border once the title has been overlaid.
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=TITLE_FONT_SIZE + 25
    )
    _overlay_text(
        image_file_name=main_figure_file_name,
        x_offset_from_center_px=0,
        y_offset_from_top_px=0,
        text_string=composite_name_verbose
    )
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=10
    )

    return main_figure_file_name
def _run(input_file_name, plot_soundings, allow_whitespace, plot_significance,
         colour_map_name, max_colour_percentile, output_dir_name):
    """Plots saliency maps for a CNN (convolutional neural network).

    This is effectively the main method.

    For each example, the max colour value is the `max_colour_percentile`th
    percentile of absolute saliency over all saliency matrices.  Significance
    is plotted only if `plot_significance` is True and the input file contains
    Monte Carlo results.

    :param input_file_name: See documentation at top of file.
    :param plot_soundings: Same.
    :param allow_whitespace: Same.
    :param plot_significance: Same.
    :param colour_map_name: Same.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    # Check input args before touching the file system, so that bad input does
    # not leave an empty directory behind.
    error_checking.assert_is_geq(max_colour_percentile, 0.)
    error_checking.assert_is_leq(max_colour_percentile, 100.)

    # `pyplot.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `pyplot.get_cmap` is available in old and new versions.
    colour_map_object = pyplot.get_cmap(colour_map_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    print('Reading data from: "{0:s}"...'.format(input_file_name))

    # Try the standard reader first; on ValueError, fall back to the reader
    # for PMM files.
    try:
        saliency_dict = saliency_maps.read_standard_file(input_file_name)
    except ValueError:
        saliency_dict = saliency_maps.read_pmm_file(input_file_name)
        list_of_input_matrices = saliency_dict.pop(
            saliency_maps.MEAN_INPUT_MATRICES_KEY)
        list_of_saliency_matrices = saliency_dict.pop(
            saliency_maps.MEAN_SALIENCY_MATRICES_KEY)

        # Add a leading axis of length 1, so that the code below can index
        # these matrices by example like those from a standard file.
        for i in range(len(list_of_input_matrices)):
            list_of_input_matrices[i] = numpy.expand_dims(
                list_of_input_matrices[i], axis=0
            )
            list_of_saliency_matrices[i] = numpy.expand_dims(
                list_of_saliency_matrices[i], axis=0
            )

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]
    else:
        list_of_input_matrices = saliency_dict.pop(
            saliency_maps.INPUT_MATRICES_KEY)
        list_of_saliency_matrices = saliency_dict.pop(
            saliency_maps.SALIENCY_MATRICES_KEY)

        full_storm_id_strings = saliency_dict[saliency_maps.FULL_IDS_KEY]
        storm_times_unix_sec = saliency_dict[saliency_maps.STORM_TIMES_KEY]

    pmm_flag = (
        full_storm_id_strings[0] is None and storm_times_unix_sec[0] is None
    )

    num_examples = list_of_input_matrices[0].shape[0]
    max_colour_value_by_example = numpy.full(num_examples, numpy.nan)

    for i in range(num_examples):
        # Pool saliency values from all matrices for this example.
        these_saliency_values = numpy.concatenate(
            [numpy.ravel(s[i, ...]) for s in list_of_saliency_matrices]
        )
        max_colour_value_by_example[i] = numpy.percentile(
            numpy.absolute(these_saliency_values), max_colour_percentile
        )

    model_file_name = saliency_dict[saliency_maps.MODEL_FILE_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0]
    )

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    has_soundings = (
        training_option_dict[trainval_io.SOUNDING_FIELDS_KEY] is not None
    )

    # If the model uses soundings, the last input matrix is treated as the
    # sounding matrix (it is the one passed to `_plot_sounding_saliency`), so
    # it is excluded from the radar matrices.
    num_radar_matrices = len(list_of_input_matrices) - int(has_soundings)

    monte_carlo_dict = (
        saliency_dict[saliency_maps.MONTE_CARLO_DICT_KEY]
        if plot_significance and
        saliency_maps.MONTE_CARLO_DICT_KEY in saliency_dict
        else None
    )

    for i in range(num_examples):
        if has_soundings and plot_soundings:
            _plot_sounding_saliency(
                saliency_matrix=list_of_saliency_matrices[-1][i, ...],
                colour_map_object=colour_map_object,
                max_colour_value=max_colour_value_by_example[i],
                sounding_matrix=list_of_input_matrices[-1][i, ...],
                saliency_dict=saliency_dict,
                model_metadata_dict=model_metadata_dict,
                output_dir_name=output_dir_name,
                pmm_flag=pmm_flag, example_index=i)

        this_handle_dict = plot_input_examples.plot_one_example(
            list_of_predictor_matrices=list_of_input_matrices,
            model_metadata_dict=model_metadata_dict,
            plot_sounding=False, allow_whitespace=allow_whitespace,
            pmm_flag=pmm_flag, example_index=i,
            full_storm_id_string=full_storm_id_strings[i],
            storm_time_unix_sec=storm_times_unix_sec[i]
        )

        these_figure_objects = this_handle_dict[
            plot_input_examples.RADAR_FIGURES_KEY]
        these_axes_object_matrices = this_handle_dict[
            plot_input_examples.RADAR_AXES_KEY]

        for j in range(num_radar_matrices):
            if monte_carlo_dict is None:
                this_significance_matrix = None
            else:
                # Significant where the trial value is outside the
                # [min, max] envelope of the Monte Carlo results.
                this_trial_matrix = monte_carlo_dict[
                    monte_carlo.TRIAL_PMM_MATRICES_KEY][j][i, ...]
                this_min_matrix = monte_carlo_dict[
                    monte_carlo.MIN_MATRICES_KEY][j][i, ...]
                this_max_matrix = monte_carlo_dict[
                    monte_carlo.MAX_MATRICES_KEY][j][i, ...]

                this_significance_matrix = numpy.logical_or(
                    this_trial_matrix < this_min_matrix,
                    this_trial_matrix > this_max_matrix
                )

            # Discount two axes (first and last) to get number of spatial
            # dimensions.
            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_saliency(
                    saliency_matrix=list_of_saliency_matrices[j][i, ...],
                    colour_map_object=colour_map_object,
                    max_colour_value=max_colour_value_by_example[i],
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=output_dir_name,
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i]
                )
            else:
                _plot_2d_radar_saliency(
                    saliency_matrix=list_of_saliency_matrices[j][i, ...],
                    colour_map_object=colour_map_object,
                    max_colour_value=max_colour_value_by_example[i],
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=output_dir_name,
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i]
                )
def _plot_one_example(
        radar_matrix, sounding_matrix, sounding_pressures_pascals,
        full_storm_id_string, storm_time_unix_sec, model_metadata_dict,
        output_dir_name):
    """Plots predictors (radar and sounding) for one example.

    The sounding figure and one figure per radar field are saved to the output
    directory.  The radar-field figures are then concatenated into a single
    image, after which the per-field image files are deleted.

    M = number of rows in radar grid
    N = number of columns in radar grid
    H_r = number of heights in radar grid
    F_r = number of radar fields
    H_s = number of sounding heights
    F_s = number of sounding fields

    :param radar_matrix: numpy array (1 x M x N x H_r x F_r) of radar values.
    :param sounding_matrix: numpy array (1 x H_s x F_s) of sounding values.
    :param sounding_pressures_pascals: numpy array (length H_s) of sounding
        pressures.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Valid time.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: radar_figure_file_name: Path to radar figure created by this
        method.
    """

    radar_field_names = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY][
        trainval_io.RADAR_FIELDS_KEY]

    num_radar_fields = radar_matrix.shape[-1]
    num_radar_heights = radar_matrix.shape[-2]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=[radar_matrix, sounding_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=False,
        example_index=0,
        plot_sounding=True,
        sounding_pressures_pascals=sounding_pressures_pascals,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_radar_heights
    )

    # Keyword arguments shared by every file-name lookup and every save.
    storm_kwargs = {
        'output_dir_name': output_dir_name,
        'pmm_flag': False,
        'full_storm_id_string': full_storm_id_string,
        'storm_time_unix_sec': storm_time_unix_sec
    }
    save_kwargs = {
        'dpi': FIGURE_RESOLUTION_DPI,
        'pad_inches': 0,
        'bbox_inches': 'tight'
    }

    # Sounding figure.
    sounding_file_name = plot_examples.metadata_to_file_name(
        is_sounding=True, **storm_kwargs)
    print('Saving figure to: "{0:s}"...'.format(sounding_file_name))

    sounding_figure_object = handle_dict[plot_examples.SOUNDING_FIGURE_KEY]
    sounding_figure_object.savefig(sounding_file_name, **save_kwargs)
    pyplot.close(sounding_figure_object)

    # One figure per radar field.
    radar_figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    panel_file_names = []

    for field_index in range(num_radar_fields):
        this_panel_file_name = plot_examples.metadata_to_file_name(
            is_sounding=False,
            radar_field_name=radar_field_names[field_index],
            **storm_kwargs
        )
        panel_file_names.append(this_panel_file_name)

        print('Saving figure to: "{0:s}"...'.format(this_panel_file_name))

        this_figure_object = radar_figure_objects[field_index]
        this_figure_object.savefig(this_panel_file_name, **save_kwargs)
        pyplot.close(this_figure_object)

    # Concatenate the per-field panels into a single row.
    radar_figure_file_name = plot_examples.metadata_to_file_name(
        is_sounding=False, **storm_kwargs)
    print('Concatenating panels to: "{0:s}"...'.format(radar_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=radar_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=num_radar_fields,
        border_width_pixels=50
    )
    imagemagick_utils.resize_image(
        input_file_name=radar_figure_file_name,
        output_file_name=radar_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX
    )
    imagemagick_utils.trim_whitespace(
        input_file_name=radar_figure_file_name,
        output_file_name=radar_figure_file_name,
        border_width_pixels=10
    )

    # The per-field panels are no longer needed once they are concatenated.
    for this_panel_file_name in panel_file_names:
        os.remove(this_panel_file_name)

    return radar_figure_file_name
def _run(input_file_name, colour_map_name, max_colour_value, half_num_contours,
         smoothing_radius_grid_cells, plot_soundings, allow_whitespace,
         plot_panel_names, add_titles, label_colour_bars, colour_bar_length,
         output_dir_name):
    """Plots saliency maps on top of the predictors stored in one file.

    This is effectively the main method.

    The input file may hold either many individual examples or a single
    composite (the "PMM" case, signalled by the flag returned from
    `saliency_maps.read_file`).  In the composite case the matrices are given
    a leading length-1 axis so that both cases go through the same
    per-example loop.

    :param input_file_name: See documentation at top of file.
    :param colour_map_name: Same.
    :param max_colour_value: Same.
    :param half_num_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param plot_soundings: Same.
    :param allow_whitespace: Same.
    :param plot_panel_names: Same.
    :param add_titles: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param output_dir_name: Same.
    """

    # Non-positive values are replaced with None.  A None smoothing radius
    # makes the smoothing step below a no-op.
    max_colour_value = None if max_colour_value <= 0 else max_colour_value
    smoothing_radius_grid_cells = (
        None if smoothing_radius_grid_cells <= 0
        else smoothing_radius_grid_cells
    )

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    colour_map_object = pyplot.cm.get_cmap(colour_map_name)
    error_checking.assert_is_geq(half_num_contours, 5)

    print('Reading data from: "{0:s}"...'.format(input_file_name))
    saliency_dict, pmm_flag = saliency_maps.read_file(input_file_name)

    if not pmm_flag:
        # One entry per example along the first axis of every matrix.
        predictor_matrices = saliency_dict.pop(
            saliency_maps.PREDICTOR_MATRICES_KEY
        )
        saliency_matrices = saliency_dict.pop(
            saliency_maps.SALIENCY_MATRICES_KEY
        )

        full_storm_id_strings = saliency_dict[saliency_maps.FULL_STORM_IDS_KEY]
        storm_times_unix_sec = saliency_dict[saliency_maps.STORM_TIMES_KEY]
        sounding_pressure_matrix_pa = saliency_dict[
            saliency_maps.SOUNDING_PRESSURES_KEY
        ]
    else:
        # Composite: there is no storm ID or time, and every array needs a
        # leading "example" axis of length 1.
        predictor_matrices = saliency_dict.pop(
            saliency_maps.MEAN_PREDICTOR_MATRICES_KEY
        )
        saliency_matrices = saliency_dict.pop(
            saliency_maps.MEAN_SALIENCY_MATRICES_KEY
        )

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]

        mean_pressures_pa = saliency_dict[
            saliency_maps.MEAN_SOUNDING_PRESSURES_KEY
        ]
        sounding_pressure_matrix_pa = numpy.reshape(
            mean_pressures_pa, (1, len(mean_pressures_pa))
        )

        for k, this_predictor_matrix in enumerate(predictor_matrices):
            predictor_matrices[k] = numpy.expand_dims(
                this_predictor_matrix, axis=0
            )
            saliency_matrices[k] = numpy.expand_dims(
                saliency_matrices[k], axis=0
            )

    if smoothing_radius_grid_cells is not None:
        saliency_matrices = _smooth_maps(
            saliency_matrices=saliency_matrices,
            smoothing_radius_grid_cells=smoothing_radius_grid_cells
        )

    # Model metadata is expected to sit next to the model file.
    model_file_name = saliency_dict[saliency_maps.MODEL_FILE_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.dirname(model_file_name)
    )

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    num_radar_matrices = len(predictor_matrices)

    if training_option_dict[trainval_io.SOUNDING_FIELDS_KEY] is not None:
        # The last predictor matrix holds soundings, not radar data.
        num_radar_matrices -= 1
    else:
        plot_soundings = False

    num_examples = predictor_matrices[0].shape[0]

    for i in range(num_examples):
        handle_dict = plot_examples.plot_one_example(
            list_of_predictor_matrices=predictor_matrices,
            model_metadata_dict=model_metadata_dict,
            pmm_flag=pmm_flag,
            example_index=i,
            plot_sounding=plot_soundings,
            sounding_pressures_pascals=sounding_pressure_matrix_pa[i, ...],
            allow_whitespace=allow_whitespace,
            plot_panel_names=plot_panel_names,
            add_titles=add_titles,
            label_colour_bars=label_colour_bars,
            colour_bar_length=colour_bar_length
        )

        if plot_soundings:
            _plot_sounding_saliency(
                saliency_matrix=saliency_matrices[-1][i, ...],
                colour_map_object=colour_map_object,
                max_colour_value=max_colour_value,
                sounding_figure_object=handle_dict[
                    plot_examples.SOUNDING_FIGURE_KEY
                ],
                sounding_axes_object=handle_dict[
                    plot_examples.SOUNDING_AXES_KEY
                ],
                sounding_pressures_pascals=sounding_pressure_matrix_pa[i, ...],
                saliency_dict=saliency_dict,
                model_metadata_dict=model_metadata_dict,
                add_title=add_titles,
                output_dir_name=output_dir_name,
                pmm_flag=pmm_flag,
                example_index=i
            )

        radar_figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
        radar_axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

        for j in range(num_radar_matrices):
            # Subtract the example axis and the trailing (field) axis.
            num_spatial_dims = len(predictor_matrices[j].shape) - 2

            # Both plotting helpers take exactly the same arguments.
            plot_radar_saliency = (
                _plot_3d_radar_saliency if num_spatial_dims == 3
                else _plot_2d_radar_saliency
            )

            plot_radar_saliency(
                saliency_matrix=saliency_matrices[j][i, ...],
                colour_map_object=colour_map_object,
                max_colour_value=max_colour_value,
                half_num_contours=half_num_contours,
                label_colour_bars=label_colour_bars,
                colour_bar_length=colour_bar_length,
                figure_objects=radar_figure_objects,
                axes_object_matrices=radar_axes_object_matrices,
                model_metadata_dict=model_metadata_dict,
                output_dir_name=output_dir_name,
                significance_matrix=None,
                full_storm_id_string=full_storm_id_strings[i],
                storm_time_unix_sec=storm_times_unix_sec[i]
            )
def _plot_composite(composite_file_name, composite_name_abbrev,
                    composite_name_verbose, output_dir_name):
    """Plots the radar and sounding figures for one composite.

    :param composite_file_name: Path to input file.  Will be read by
        `_read_composite`.
    :param composite_name_abbrev: Abbreviated name for composite.  Will be used
        in names of output files.
    :param composite_name_verbose: Verbose name for composite.  Will be used as
        figure title.
    :param output_dir_name: Path to output directory.  Figures will be saved
        here.
    :return: radar_figure_file_name: Path to file with radar figure for this
        composite.
    :return: sounding_figure_file_name: Path to file with sounding figure for
        this composite.
    """

    def _resize_and_add_title(image_file_name):
        """Resizes one image in place and overlays the composite name on it.

        :param image_file_name: Path to image file (overwritten in place).
        """

        imagemagick_utils.resize_image(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            output_size_pixels=CONCAT_FIGURE_SIZE_PX
        )

        # The first trim uses a wide border (presumably to leave room for the
        # title text); the second trims back to a thin border.
        imagemagick_utils.trim_whitespace(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            border_width_pixels=TITLE_FONT_SIZE + 25
        )
        _overlay_text(
            image_file_name=image_file_name,
            x_offset_from_center_px=0,
            y_offset_from_top_px=0,
            text_string=composite_name_verbose
        )
        imagemagick_utils.trim_whitespace(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            border_width_pixels=10
        )

    mean_predictor_matrices, model_metadata_dict, mean_sounding_pressures_pa = (
        _read_composite(composite_file_name)
    )

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    radar_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    num_radar_fields = len(radar_field_names)
    num_radar_heights = len(radar_heights_m_agl)

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_predictor_matrices,
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True,
        plot_sounding=True,
        sounding_pressures_pascals=mean_sounding_pressures_pa,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        sounding_font_size=SOUNDING_FONT_SIZE,
        num_panel_rows=num_radar_heights
    )

    # Sounding figure: save, then resize and add title.
    sounding_figure_file_name = '{0:s}/{1:s}_sounding.jpg'.format(
        output_dir_name, composite_name_abbrev
    )
    print('Saving figure to: "{0:s}"...'.format(sounding_figure_file_name))

    sounding_figure_object = handle_dict[plot_examples.SOUNDING_FIGURE_KEY]
    sounding_figure_object.savefig(
        sounding_figure_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(sounding_figure_object)

    _resize_and_add_title(sounding_figure_file_name)

    # Radar figures: one panel file per field, later concatenated in one row.
    radar_figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    panel_file_names = []

    for j, this_field_name in enumerate(radar_field_names):
        this_panel_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, composite_name_abbrev,
            this_field_name.replace('_', '-')
        )
        panel_file_names.append(this_panel_file_name)

        print('Saving figure to: "{0:s}"...'.format(this_panel_file_name))

        radar_figure_objects[j].savefig(
            this_panel_file_name, dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight'
        )
        pyplot.close(radar_figure_objects[j])

    radar_figure_file_name = '{0:s}/{1:s}_radar.jpg'.format(
        output_dir_name, composite_name_abbrev
    )
    print('Concatenating panels to: "{0:s}"...'.format(radar_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=radar_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=num_radar_fields,
        border_width_pixels=50
    )

    _resize_and_add_title(radar_figure_file_name)

    return radar_figure_file_name, sounding_figure_file_name
Example #8
0
def _plot_one_composite(gradcam_file_name, monte_carlo_file_name,
                        composite_name_abbrev, composite_name_verbose,
                        colour_map_object, min_colour_value, max_colour_value,
                        num_contours, smoothing_radius_grid_cells,
                        monte_carlo_max_fdr, output_dir_name):
    """Plots class-activation map for one composite.

    Class activation is contoured in log10 space.  If either colour limit is
    NaN, both limits are derived from the 1st and 99th percentiles of the
    class-activation matrix and clamped to fixed ranges.

    :param gradcam_file_name: Path to input file (will be read by
        `gradcam.read_file`).
    :param monte_carlo_file_name: Path to Monte Carlo file (will be read by
        `_read_monte_carlo_file`).
    :param composite_name_abbrev: Abbrev composite name (will be used in file
        names).
    :param composite_name_verbose: Verbose composite name (will be used in
        figure title).
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour bar (may be NaN).
    :param max_colour_value: Max value in colour bar (may be NaN).
    :param num_contours: See documentation at top of file.
    :param smoothing_radius_grid_cells: Same.
    :param monte_carlo_max_fdr: Same.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    :return: min_colour_value: Same as input, unless either input limit was
        NaN, in which case this is the limit computed here.
    :return: max_colour_value: Same as input, unless either input limit was
        NaN, in which case this is the limit computed here.
    """

    (
        mean_radar_matrix, mean_class_activn_matrix, significance_matrix,
        model_metadata_dict
    ) = _read_one_composite(
        gradcam_file_name=gradcam_file_name,
        smoothing_radius_grid_cells=smoothing_radius_grid_cells,
        monte_carlo_file_name=monte_carlo_file_name,
        monte_carlo_max_fdr=monte_carlo_max_fdr
    )

    both_limits_given = not (
        numpy.isnan(min_colour_value) or numpy.isnan(max_colour_value)
    )

    if both_limits_given:
        min_colour_value_log10 = numpy.log10(min_colour_value)
        max_colour_value_log10 = numpy.log10(max_colour_value)
    else:
        log10_of_1st_percentile = numpy.log10(
            numpy.percentile(mean_class_activn_matrix, 1.)
        )
        log10_of_99th_percentile = numpy.log10(
            numpy.percentile(mean_class_activn_matrix, 99.)
        )

        # Clamp lower limit to [-2, 1] and upper limit to [-1, 2] (log10).
        min_colour_value_log10 = min(max(log10_of_1st_percentile, -2.), 1.)
        max_colour_value_log10 = min(max(log10_of_99th_percentile, -1.), 2.)

        min_colour_value = 10 ** min_colour_value_log10
        max_colour_value = 10 ** max_colour_value_log10

    contour_interval_log10 = (
        (max_colour_value_log10 - min_colour_value_log10) /
        (num_contours - 1)
    )
    mean_activn_matrix_log10 = numpy.log10(mean_class_activn_matrix)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    # Last axis of the radar matrix is fields; second-last is heights.
    num_fields = mean_radar_matrix.shape[-1]
    num_heights = mean_radar_matrix.shape[-2]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=[mean_radar_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_heights
    )

    figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    # The same class-activation and significance grids are drawn over every
    # field's panels.  Both are flipped along their first axis before plotting.
    for this_axes_object_matrix in axes_object_matrices[:num_fields]:
        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(
                mean_activn_matrix_log10[0, ...], axis=0
            ),
            axes_object_matrix=this_axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=min_colour_value_log10,
            max_contour_level=max_colour_value_log10,
            contour_interval=contour_interval_log10
        )

        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(
                significance_matrix[0, ...], axis=0
            ),
            axes_object_matrix=this_axes_object_matrix
        )

    panel_file_names = []

    for k in range(num_fields):
        this_panel_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, composite_name_abbrev,
            field_names[k].replace('_', '-')
        )
        panel_file_names.append(this_panel_file_name)

        print('Saving figure to: "{0:s}"...'.format(this_panel_file_name))

        figure_objects[k].savefig(
            this_panel_file_name, dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight'
        )
        pyplot.close(figure_objects[k])

    main_figure_file_name = '{0:s}/{1:s}_gradcam.jpg'.format(
        output_dir_name, composite_name_abbrev
    )
    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=main_figure_file_name,
        num_panel_rows=1,
        num_panel_columns=num_fields,
        border_width_pixels=50
    )
    imagemagick_utils.resize_image(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX
    )

    # Wide border first, then the title overlay, then trim to a thin border.
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=TITLE_FONT_SIZE + 25
    )
    _overlay_text(
        image_file_name=main_figure_file_name,
        x_offset_from_center_px=0,
        y_offset_from_top_px=0,
        text_string=composite_name_verbose
    )
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=10
    )

    return main_figure_file_name, min_colour_value, max_colour_value
Example #9
0
def _run(gridrad_example_dir_name, gridrad_full_id_string, gridrad_time_string,
         myrorss_example_dir_name, myrorss_full_id_string, myrorss_time_string,
         output_dir_name):
    """Makes figure with GridRad and MYRORSS predictors.

    This is effectively the main method.

    One panel image is written per GridRad field, one for the proximity
    sounding, one for MYRORSS reflectivity and one per MYRORSS shear field.
    The panels are then concatenated into `<output_dir_name>/predictors.jpg`.

    :param gridrad_example_dir_name: See documentation at top of file.
    :param gridrad_full_id_string: Same.
    :param gridrad_time_string: Same.
    :param myrorss_example_dir_name: Same.
    :param myrorss_full_id_string: Same.
    :param myrorss_time_string: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    gridrad_time_unix_sec = time_conversion.string_to_unix_sec(
        gridrad_time_string, TIME_FORMAT)
    myrorss_time_unix_sec = time_conversion.string_to_unix_sec(
        myrorss_time_string, TIME_FORMAT)

    # Panel letters "(a)", "(b)", ... are assigned in the order in which the
    # panels are created below.
    letter_label = None
    num_gridrad_fields = len(GRIDRAD_FIELD_NAMES)

    # Slot layout: even slots (0, 2, 4, ...) hold GridRad panels, slot 1 the
    # sounding, slot 3 MYRORSS reflectivity and slots 5, 7, ... MYRORSS shear.
    # NOTE(review): the hard-coded slots and the 4 x 2 montage at the end only
    # line up for 4 GridRad fields and 2 shear fields -- confirm against
    # GRIDRAD_FIELD_NAMES and MYRORSS_SHEAR_FIELD_NAMES.
    panel_file_names = [None] * num_gridrad_fields * 2

    for j in range(num_gridrad_fields):
        these_predictor_matrices, this_metadata_dict = _read_one_example(
            top_example_dir_name=gridrad_example_dir_name,
            full_storm_id_string=gridrad_full_id_string,
            storm_time_unix_sec=gridrad_time_unix_sec,
            source_name=radar_utils.GRIDRAD_SOURCE_ID,
            radar_field_name=GRIDRAD_FIELD_NAMES[j],
            include_sounding=False)[:2]

        print(MINOR_SEPARATOR_STRING)

        this_handle_dict = plot_examples.plot_one_example(
            list_of_predictor_matrices=these_predictor_matrices,
            model_metadata_dict=this_metadata_dict,
            pmm_flag=False,
            example_index=0,
            plot_sounding=False,
            allow_whitespace=True,
            plot_panel_names=False,
            add_titles=False,
            label_colour_bars=False,
            colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
            colour_bar_length=COLOUR_BAR_LENGTH)

        this_title_string = radar_plotting.fields_and_heights_to_names(
            field_names=[GRIDRAD_FIELD_NAMES[j]],
            heights_m_agl=RADAR_HEIGHTS_M_AGL[[0]],
            include_units=True)[0]

        # Collapse the panel name onto one line and lower-case its first
        # letter, since it now follows the data-source name.
        this_title_string = this_title_string.replace('\n', ' ').replace(
            ' km AGL', ' km')
        this_title_string = 'GridRad {0:s}{1:s}'.format(
            this_title_string[0].lower(), this_title_string[1:])

        this_figure_object = this_handle_dict[
            plot_examples.RADAR_FIGURES_KEY][0]
        this_axes_object = this_handle_dict[
            plot_examples.RADAR_AXES_KEY][0][0, 0]

        this_figure_object.suptitle('')
        this_axes_object.set_title(this_title_string, fontsize=TITLE_FONT_SIZE)

        if letter_label is None:
            letter_label = 'a'
        else:
            letter_label = chr(ord(letter_label) + 1)

        plotting_utils.label_axes(
            axes_object=this_axes_object,
            label_string='({0:s})'.format(letter_label),
            font_size=PANEL_LETTER_FONT_SIZE,
            x_coord_normalized=X_LABEL_COORD_NORMALIZED,
            y_coord_normalized=Y_LABEL_COORD_NORMALIZED)

        panel_file_names[j * 2] = '{0:s}/gridrad_{1:s}.jpg'.format(
            output_dir_name, GRIDRAD_FIELD_NAMES[j].replace('_', '-'))

        print('Saving figure to: "{0:s}"...'.format(panel_file_names[j * 2]))
        this_figure_object.savefig(
            panel_file_names[j * 2], dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight')
        pyplot.close(this_figure_object)

        print(SEPARATOR_STRING)

    num_myrorss_shear_fields = len(MYRORSS_SHEAR_FIELD_NAMES)

    for j in range(num_myrorss_shear_fields):
        # The sounding and the reflectivity panel are produced only once, on
        # the first pass through this loop.
        is_first_field = (j == 0)

        (these_predictor_matrices, this_metadata_dict,
         these_pressures_pascals) = _read_one_example(
             top_example_dir_name=myrorss_example_dir_name,
             full_storm_id_string=myrorss_full_id_string,
             storm_time_unix_sec=myrorss_time_unix_sec,
             source_name=radar_utils.MYRORSS_SOURCE_ID,
             radar_field_name=MYRORSS_SHEAR_FIELD_NAMES[j],
             include_sounding=is_first_field)

        print(MINOR_SEPARATOR_STRING)

        this_handle_dict = plot_examples.plot_one_example(
            list_of_predictor_matrices=these_predictor_matrices,
            model_metadata_dict=this_metadata_dict,
            pmm_flag=False,
            example_index=0,
            plot_sounding=is_first_field,
            sounding_pressures_pascals=these_pressures_pascals,
            allow_whitespace=True,
            plot_panel_names=False,
            add_titles=False,
            label_colour_bars=False,
            colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
            colour_bar_length=COLOUR_BAR_LENGTH,
            sounding_font_size=SOUNDING_FONT_SIZE)

        if is_first_field:
            # Sounding panel.
            this_axes_object = this_handle_dict[
                plot_examples.SOUNDING_AXES_KEY]
            this_axes_object.set_title('Proximity sounding')

            letter_label = chr(ord(letter_label) + 1)
            plotting_utils.label_axes(
                axes_object=this_axes_object,
                label_string='({0:s})'.format(letter_label),
                font_size=PANEL_LETTER_FONT_SIZE,
                x_coord_normalized=X_LABEL_COORD_NORMALIZED,
                y_coord_normalized=Y_LABEL_COORD_NORMALIZED)

            this_figure_object = this_handle_dict[
                plot_examples.SOUNDING_FIGURE_KEY]
            panel_file_names[1] = '{0:s}/sounding.jpg'.format(output_dir_name)

            print('Saving figure to: "{0:s}"...'.format(panel_file_names[1]))
            this_figure_object.savefig(
                panel_file_names[1], dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight')
            pyplot.close(this_figure_object)

            # MYRORSS reflectivity panel (first radar figure).
            this_title_string = radar_plotting.fields_and_heights_to_names(
                field_names=[radar_utils.REFL_NAME],
                heights_m_agl=RADAR_HEIGHTS_M_AGL[[0]],
                include_units=True)[0]

            this_title_string = this_title_string.replace('\n', ' ').replace(
                ' km AGL', ' km')
            this_title_string = 'MYRORSS {0:s}{1:s}'.format(
                this_title_string[0].lower(), this_title_string[1:])

            this_figure_object = this_handle_dict[
                plot_examples.RADAR_FIGURES_KEY][0]
            this_axes_object = this_handle_dict[
                plot_examples.RADAR_AXES_KEY][0][0, 0]

            this_figure_object.suptitle('')
            this_axes_object.set_title(
                this_title_string, fontsize=TITLE_FONT_SIZE)

            letter_label = chr(ord(letter_label) + 1)
            plotting_utils.label_axes(
                axes_object=this_axes_object,
                label_string='({0:s})'.format(letter_label),
                font_size=PANEL_LETTER_FONT_SIZE,
                x_coord_normalized=X_LABEL_COORD_NORMALIZED,
                y_coord_normalized=Y_LABEL_COORD_NORMALIZED)

            panel_file_names[3] = '{0:s}/myrorss_{1:s}.jpg'.format(
                output_dir_name, radar_utils.REFL_NAME.replace('_', '-'))

            print('Saving figure to: "{0:s}"...'.format(panel_file_names[3]))
            this_figure_object.savefig(
                panel_file_names[3], dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight')
            pyplot.close(this_figure_object)

        # NOTE(review): when j > 0, the figure at index 0 of RADAR_FIGURES_KEY
        # is neither saved nor closed -- confirm whether it should be closed.

        # Shear panel (second radar figure).  Only the first line of the
        # panel name is kept for the title.
        this_title_string = radar_plotting.fields_and_heights_to_names(
            field_names=[MYRORSS_SHEAR_FIELD_NAMES[j]],
            heights_m_agl=RADAR_HEIGHTS_M_AGL[[0]],
            include_units=True)[0]

        this_title_string = this_title_string.split('\n')[0]
        this_title_string = 'MYRORSS {0:s}{1:s}'.format(
            this_title_string[0].lower(), this_title_string[1:])

        this_figure_object = this_handle_dict[
            plot_examples.RADAR_FIGURES_KEY][1]
        this_axes_object = this_handle_dict[
            plot_examples.RADAR_AXES_KEY][1][0, 0]

        this_figure_object.suptitle('')
        this_axes_object.set_title(this_title_string, fontsize=TITLE_FONT_SIZE)

        letter_label = chr(ord(letter_label) + 1)
        plotting_utils.label_axes(
            axes_object=this_axes_object,
            label_string='({0:s})'.format(letter_label),
            font_size=PANEL_LETTER_FONT_SIZE,
            x_coord_normalized=X_LABEL_COORD_NORMALIZED,
            y_coord_normalized=Y_LABEL_COORD_NORMALIZED)

        shear_panel_index = 5 + j * 2
        panel_file_names[shear_panel_index] = (
            '{0:s}/myrorss_{1:s}.jpg'.format(
                output_dir_name,
                MYRORSS_SHEAR_FIELD_NAMES[j].replace('_', '-'))
        )

        print('Saving figure to: "{0:s}"...'.format(
            panel_file_names[shear_panel_index]))
        this_figure_object.savefig(
            panel_file_names[shear_panel_index], dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight')
        pyplot.close(this_figure_object)

        # Print the separator between fields only, not after the last one.
        # (The loop index never reaches `num_myrorss_shear_fields` itself, so
        # the comparison must be against the last valid index.)
        if j != num_myrorss_shear_fields - 1:
            print(SEPARATOR_STRING)

    concat_file_name = '{0:s}/predictors.jpg'.format(output_dir_name)
    print('Concatenating panels to: "{0:s}"...'.format(concat_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=concat_file_name,
        num_panel_rows=4,
        num_panel_columns=2)

    imagemagick_utils.resize_image(
        input_file_name=concat_file_name,
        output_file_name=concat_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX)
def _run(main_bwo_file_name, unconstrained_bwo_file_name, output_dir_name):
    """Makes figure with backwards-optimization results.

    This is effectively the main method.

    The final figure has 2 rows and 3 columns.  The top row holds radar images
    and the bottom row holds soundings.  The columns are the original example,
    the result of constrained optimization and the result of unconstrained
    optimization, in that order.

    :param main_bwo_file_name: See documentation at top of file.
    :param unconstrained_bwo_file_name: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    main_bwo_dict, model_metadata_dict = _read_bwo_file(main_bwo_file_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    mean_before_matrices = main_bwo_dict[backwards_opt.MEAN_INPUT_MATRICES_KEY]
    mean_after_matrices = main_bwo_dict[backwards_opt.MEAN_OUTPUT_MATRICES_KEY]
    mean_sounding_pressures_pa = main_bwo_dict[
        backwards_opt.MEAN_SOUNDING_PRESSURES_KEY
    ]

    num_radar_heights = mean_before_matrices[0].shape[-2]

    # Plotting options shared by all three calls to `plot_one_example`.
    shared_plotting_kwargs = dict(
        pmm_flag=True,
        plot_sounding=True,
        allow_whitespace=True,
        plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False,
        label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        sounding_font_size=SOUNDING_FONT_SIZE,
        num_panel_rows=num_radar_heights,
        plot_radar_diffs=False
    )

    # Plot radar and sounding before optimization.
    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_before_matrices,
        model_metadata_dict=model_metadata_dict,
        sounding_pressures_pascals=mean_sounding_pressures_pa,
        **shared_plotting_kwargs
    )

    sounding_before_file_name = '{0:s}/sounding_before.jpg'.format(
        output_dir_name
    )
    _write_sounding_figure(
        figure_object=handle_dict[plot_examples.SOUNDING_FIGURE_KEY],
        title_string='(d) Original sounding',
        output_file_name=sounding_before_file_name
    )

    radar_before_file_name = _write_radar_figures(
        figure_objects=handle_dict[plot_examples.RADAR_FIGURES_KEY],
        field_names=radar_field_names,
        composite_name='before',
        concat_title_string='(a) Original radar image',
        output_dir_name=output_dir_name
    )

    print(SEPARATOR_STRING)

    # Plot radar and sounding after constrained optimization.
    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_after_matrices,
        model_metadata_dict=model_metadata_dict,
        sounding_pressures_pascals=mean_sounding_pressures_pa,
        **shared_plotting_kwargs
    )

    sounding_after_file_name = '{0:s}/sounding_after.jpg'.format(
        output_dir_name
    )
    _write_sounding_figure(
        figure_object=handle_dict[plot_examples.SOUNDING_FIGURE_KEY],
        title_string='(e) Sounding after constrained BWO',
        output_file_name=sounding_after_file_name
    )

    radar_after_file_name = _write_radar_figures(
        figure_objects=handle_dict[plot_examples.RADAR_FIGURES_KEY],
        field_names=radar_field_names,
        composite_name='after',
        concat_title_string='(b) Radar after constrained BWO',
        output_dir_name=output_dir_name
    )

    print(SEPARATOR_STRING)

    # Plot radar and sounding after unconstrained optimization.  Model
    # metadata, output matrices and sounding pressures are all taken from the
    # second file from here on.
    unconstrained_bwo_dict, model_metadata_dict = _read_bwo_file(
        unconstrained_bwo_file_name
    )

    mean_after_matrices = unconstrained_bwo_dict[
        backwards_opt.MEAN_OUTPUT_MATRICES_KEY
    ]
    mean_sounding_pressures_pa = unconstrained_bwo_dict[
        backwards_opt.MEAN_SOUNDING_PRESSURES_KEY
    ]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_after_matrices,
        model_metadata_dict=model_metadata_dict,
        sounding_pressures_pascals=mean_sounding_pressures_pa,
        **shared_plotting_kwargs
    )

    sounding_unconstrained_file_name = (
        '{0:s}/sounding_after_unconstrained.jpg'.format(output_dir_name)
    )
    _write_sounding_figure(
        figure_object=handle_dict[plot_examples.SOUNDING_FIGURE_KEY],
        title_string='(f) Sounding after unconstrained BWO',
        output_file_name=sounding_unconstrained_file_name
    )

    radar_unconstrained_file_name = _write_radar_figures(
        figure_objects=handle_dict[plot_examples.RADAR_FIGURES_KEY],
        field_names=radar_field_names,
        composite_name='after_unconstrained',
        concat_title_string='(c) Radar after unconstrained BWO',
        output_dir_name=output_dir_name
    )

    print(SEPARATOR_STRING)

    # Panels (a)-(c) are the radar images; panels (d)-(f) are the soundings.
    panel_file_names = [
        radar_before_file_name,
        radar_after_file_name,
        radar_unconstrained_file_name,
        sounding_before_file_name,
        sounding_after_file_name,
        sounding_unconstrained_file_name
    ]

    figure_file_name = '{0:s}/bwo_concat.jpg'.format(output_dir_name)
    print('Concatenating panels to: "{0:s}"...'.format(figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=figure_file_name,
        border_width_pixels=100,
        num_panel_rows=2,
        num_panel_columns=3,
        extra_args_string='-gravity Center'
    )
    imagemagick_utils.trim_whitespace(
        input_file_name=figure_file_name,
        output_file_name=figure_file_name,
        border_width_pixels=10
    )
Example #11
0
def _plot_composite(composite_file_name, composite_name_abbrev,
                    composite_name_verbose, plot_saliency, output_dir_name):
    """Creates the radar and sounding figures for a single composite.

    The reflectivity and shear figures are saved separately, then concatenated
    into one titled radar image.  The sounding is saved as its own titled
    image.

    :param composite_file_name: Path to input file (read by
        `_read_composite`).
    :param composite_name_abbrev: Abbreviated composite name, used in names of
        output files.
    :param composite_name_verbose: Verbose composite name, overlaid on each
        final figure as the title.
    :param plot_saliency: See documentation at top of file.  If truthy,
        saliency contours are drawn over the radar panels.
    :param output_dir_name: Path to output directory (figures are saved here).
    :return: radar_figure_file_name: Path to concatenated radar figure.
    :return: sounding_figure_file_name: Path to sounding figure.
    """

    def _save_and_close(figure_object, image_file_name):
        """Writes one figure to disk, then closes it.

        :param figure_object: Figure handle (must have `savefig` method).
        :param image_file_name: Path to output file.
        """

        print('Saving figure to: "{0:s}"...'.format(image_file_name))
        figure_object.savefig(
            image_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight')
        pyplot.close(figure_object)

    def _resize_and_add_title(image_file_name):
        """Resizes image in place, then overlays the composite title.

        The first trim leaves a border large enough to hold the title text;
        the second trim cuts the border back down once the title is there.

        :param image_file_name: Path to image file (overwritten in place).
        """

        imagemagick_utils.resize_image(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            output_size_pixels=CONCAT_FIGURE_SIZE_PX)
        imagemagick_utils.trim_whitespace(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            border_width_pixels=TITLE_FONT_SIZE + 25)
        _overlay_text(
            image_file_name=image_file_name, x_offset_from_center_px=0,
            y_offset_from_top_px=0, text_string=composite_name_verbose)
        imagemagick_utils.trim_whitespace(
            input_file_name=image_file_name,
            output_file_name=image_file_name,
            border_width_pixels=10)

    composite_tuple = _read_composite(
        pickle_file_name=composite_file_name, read_saliency=plot_saliency)
    (mean_predictor_matrices, model_metadata_dict,
     mean_sounding_pressures_pa, mean_saliency_matrices) = composite_tuple

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    num_refl_heights = len(training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_predictor_matrices,
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True, plot_sounding=True,
        sounding_pressures_pascals=mean_sounding_pressures_pa,
        allow_whitespace=True, plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE, add_titles=False,
        label_colour_bars=True, colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        sounding_font_size=SOUNDING_FONT_SIZE,
        num_panel_rows=num_refl_heights)

    if plot_saliency:
        axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

        # Both sets of panels share one contour scale, based on the 99th
        # percentile of absolute saliency over the first two matrices.
        pooled_saliency_values = numpy.concatenate(
            [numpy.ravel(mean_saliency_matrices[k]) for k in (0, 1)])
        max_contour_value = numpy.percentile(
            numpy.absolute(pooled_saliency_values), 99)
        contour_interval = max_contour_value / 10

        def _overlay_saliency(saliency_matrix_3d, axes_object_matrix,
                              row_major):
            """Draws saliency contours over one set of radar panels.

            :param saliency_matrix_3d: Saliency values for these panels
                (flipped along the first axis before plotting).
            :param axes_object_matrix: Axes handles for these panels.
            :param row_major: Passed through to the plotting method.
            """

            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(saliency_matrix_3d, axis=0),
                axes_object_matrix=axes_object_matrix,
                colour_map_object=SALIENCY_COLOUR_MAP_OBJECT,
                max_absolute_contour_level=max_contour_value,
                contour_interval=contour_interval,
                row_major=row_major)

        _overlay_saliency(
            saliency_matrix_3d=mean_saliency_matrices[0][0, ..., 0],
            axes_object_matrix=axes_object_matrices[0],
            row_major=True)
        _overlay_saliency(
            saliency_matrix_3d=mean_saliency_matrices[1][0, ...],
            axes_object_matrix=axes_object_matrices[1],
            row_major=False)

    # Sounding figure.
    sounding_figure_object = handle_dict[plot_examples.SOUNDING_FIGURE_KEY]
    sounding_figure_file_name = '{0:s}/{1:s}_sounding.jpg'.format(
        output_dir_name, composite_name_abbrev)

    _save_and_close(sounding_figure_object, sounding_figure_file_name)
    _resize_and_add_title(sounding_figure_file_name)

    # Reflectivity and shear figures (saved separately, concatenated below).
    refl_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][0]
    refl_figure_file_name = '{0:s}/{1:s}_reflectivity.jpg'.format(
        output_dir_name, composite_name_abbrev)
    _save_and_close(refl_figure_object, refl_figure_file_name)

    shear_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][1]
    shear_figure_file_name = '{0:s}/{1:s}_shear.jpg'.format(
        output_dir_name, composite_name_abbrev)
    _save_and_close(shear_figure_object, shear_figure_file_name)

    radar_figure_file_name = '{0:s}/{1:s}_radar.jpg'.format(
        output_dir_name, composite_name_abbrev)
    print('Concatenating panels to: "{0:s}"...'.format(radar_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=[refl_figure_file_name, shear_figure_file_name],
        output_file_name=radar_figure_file_name,
        num_panel_rows=1, num_panel_columns=2, border_width_pixels=50,
        extra_args_string='-gravity south')
    _resize_and_add_title(radar_figure_file_name)

    return radar_figure_file_name, sounding_figure_file_name
Example #12
0
def _plot_one_composite(gradcam_file_name, composite_name_abbrev,
                        composite_name_verbose, colour_map_object,
                        max_colour_value, num_contours,
                        smoothing_radius_grid_cells, output_dir_name):
    """Draws class-activation maps over radar panels for a single composite.

    Class activation is contoured on a log10 scale.  The reflectivity and
    shear figures are saved separately, then concatenated into one titled
    image.

    :param gradcam_file_name: Path to input file (passed to
        `_read_one_composite`; per the original documentation, read by
        `gradcam.read_file`).
    :param composite_name_abbrev: Abbreviated composite name, used in names of
        output files.
    :param composite_name_verbose: Verbose composite name, overlaid on the
        final figure as the title.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.
    :param num_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param output_dir_name: Name of output directory (figures are saved here).
    :return: main_figure_file_name: Path to main image file created by this
        method.
    """

    def _overlay_cam(cam_matrix_3d, axes_object_matrix, row_major):
        """Plots log10 of class activation over one set of radar panels.

        :param cam_matrix_3d: Class-activation values in linear units.
        :param axes_object_matrix: Axes handles for these panels.
        :param row_major: Passed through to the plotting method.
        """

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.log10(cam_matrix_3d),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=MIN_COLOUR_VALUE_LOG10,
            max_contour_level=max_colour_value_log10,
            contour_interval=contour_interval_log10,
            row_major=row_major)

    def _save_and_close(figure_object, image_file_name):
        """Writes one figure to disk, then closes it.

        :param figure_object: Figure handle (must have `savefig` method).
        :param image_file_name: Path to output file.
        """

        print('Saving figure to: "{0:s}"...'.format(image_file_name))
        figure_object.savefig(
            image_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight')
        pyplot.close(figure_object)

    composite_tuple = _read_one_composite(
        gradcam_file_name=gradcam_file_name,
        smoothing_radius_grid_cells=smoothing_radius_grid_cells)
    mean_predictor_matrices, mean_cam_matrices, model_metadata_dict = (
        composite_tuple)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    num_refl_heights = len(training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_predictor_matrices,
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True, plot_sounding=False,
        allow_whitespace=True, plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE, add_titles=False,
        label_colour_bars=True, colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=num_refl_heights)

    axes_object_matrices = handle_dict[plot_examples.RADAR_AXES_KEY]

    # Contour levels are evenly spaced in log10 space.
    max_colour_value_log10 = numpy.log10(max_colour_value)
    log10_colour_range = max_colour_value_log10 - MIN_COLOUR_VALUE_LOG10
    contour_interval_log10 = log10_colour_range / (num_contours - 1)

    _overlay_cam(
        cam_matrix_3d=numpy.flip(mean_cam_matrices[0][0, ...], axis=0),
        axes_object_matrix=axes_object_matrices[0],
        row_major=True)

    # The second class-activation map has no channel axis, so copy it once
    # per channel of the second predictor matrix.
    second_cam_matrix = numpy.flip(mean_cam_matrices[1][0, ...], axis=0)
    second_num_channels = mean_predictor_matrices[1].shape[-1]
    second_cam_matrix = numpy.repeat(
        numpy.expand_dims(second_cam_matrix, axis=-1),
        repeats=second_num_channels, axis=-1)

    _overlay_cam(
        cam_matrix_3d=second_cam_matrix,
        axes_object_matrix=axes_object_matrices[1],
        row_major=False)

    refl_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][0]
    refl_figure_file_name = '{0:s}/{1:s}_reflectivity.jpg'.format(
        output_dir_name, composite_name_abbrev)
    _save_and_close(refl_figure_object, refl_figure_file_name)

    shear_figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][1]
    shear_figure_file_name = '{0:s}/{1:s}_shear.jpg'.format(
        output_dir_name, composite_name_abbrev)
    _save_and_close(shear_figure_object, shear_figure_file_name)

    main_figure_file_name = '{0:s}/{1:s}_radar.jpg'.format(
        output_dir_name, composite_name_abbrev)
    print('Concatenating panels to: "{0:s}"...'.format(main_figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=[refl_figure_file_name, shear_figure_file_name],
        output_file_name=main_figure_file_name,
        num_panel_rows=1, num_panel_columns=2, border_width_pixels=50,
        extra_args_string='-gravity south')
    imagemagick_utils.resize_image(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX)

    # Leave a border big enough for the title, add the title, then trim the
    # border back down.
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=TITLE_FONT_SIZE + 25)
    _overlay_text(
        image_file_name=main_figure_file_name, x_offset_from_center_px=0,
        y_offset_from_top_px=0, text_string=composite_name_verbose)
    imagemagick_utils.trim_whitespace(
        input_file_name=main_figure_file_name,
        output_file_name=main_figure_file_name,
        border_width_pixels=10)

    return main_figure_file_name
def _run(bwo_file_name, output_dir_name):
    """Makes figure with backwards-optimization results for MYRORSS model.

    This is effectively the main method.  The final figure has 2 rows and 2
    columns, with panels in this order: (a) synthetic radar image, (b) radar
    difference, (c) original sounding, (d) synthetic sounding.

    :param bwo_file_name: See documentation at top of file.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    bwo_dictionary, model_metadata_dict = _read_bwo_file(bwo_file_name)

    mean_before_matrices = bwo_dictionary[backwards_opt.MEAN_INPUT_MATRICES_KEY]
    mean_after_matrices = bwo_dictionary[backwards_opt.MEAN_OUTPUT_MATRICES_KEY]
    mean_sounding_pressures_pa = bwo_dictionary[
        backwards_opt.MEAN_SOUNDING_PRESSURES_KEY]

    num_radar_heights = mean_before_matrices[0].shape[-2]

    # Options shared by all three calls to `plot_one_example`.
    shared_option_dict = {
        'model_metadata_dict': model_metadata_dict,
        'pmm_flag': True,
        'allow_whitespace': True,
        'plot_panel_names': True,
        'panel_name_font_size': PANEL_NAME_FONT_SIZE,
        'add_titles': False,
        'label_colour_bars': True,
        'colour_bar_length': COLOUR_BAR_LENGTH,
        'colour_bar_font_size': COLOUR_BAR_FONT_SIZE,
        'num_panel_rows': num_radar_heights
    }

    # Extra options for the two calls that also plot a sounding.
    sounding_option_dict = {
        'plot_sounding': True,
        'sounding_pressures_pascals': mean_sounding_pressures_pa,
        'sounding_font_size': SOUNDING_FONT_SIZE,
        'plot_radar_diffs': False
    }
    sounding_option_dict.update(shared_option_dict)

    # Plot sounding before optimization.
    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_before_matrices,
        **sounding_option_dict)

    sounding_before_file_name = '{0:s}/sounding_before.jpg'.format(
        output_dir_name)

    _write_sounding_figure(
        figure_object=handle_dict[plot_examples.SOUNDING_FIGURE_KEY],
        title_string='(c) Original sounding',
        output_file_name=sounding_before_file_name)
    print(SEPARATOR_STRING)

    # Plot radar and sounding after optimization.
    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_after_matrices,
        **sounding_option_dict)

    sounding_after_file_name = '{0:s}/sounding_after.jpg'.format(
        output_dir_name)

    _write_sounding_figure(
        figure_object=handle_dict[plot_examples.SOUNDING_FIGURE_KEY],
        title_string='(d) Synthetic sounding',
        output_file_name=sounding_after_file_name)

    radar_figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    radar_after_file_name = _write_radar_figures(
        refl_figure_object=radar_figure_objects[0],
        shear_figure_object=radar_figure_objects[1],
        composite_name='after',
        concat_title_string='(a) Synthetic radar image',
        output_dir_name=output_dir_name)

    print(SEPARATOR_STRING)

    # Plot radar difference (after minus before).
    mean_difference_matrices = [
        this_after_matrix - this_before_matrix
        for this_after_matrix, this_before_matrix
        in zip(mean_after_matrices, mean_before_matrices)
    ]

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=mean_difference_matrices,
        plot_sounding=False, plot_radar_diffs=True,
        diff_colour_map_object=DIFF_COLOUR_MAP_OBJECT,
        max_diff_percentile=MAX_DIFF_PERCENTILE,
        **shared_option_dict)

    radar_figure_objects = handle_dict[plot_examples.RADAR_FIGURES_KEY]
    radar_diff_file_name = _write_radar_figures(
        refl_figure_object=radar_figure_objects[0],
        shear_figure_object=radar_figure_objects[1],
        composite_name='difference',
        concat_title_string='(b) Radar difference',
        output_dir_name=output_dir_name)

    print(SEPARATOR_STRING)

    # Panel order here matches the letters (a)-(d) in the panel titles.
    panel_file_names = [
        radar_after_file_name, radar_diff_file_name,
        sounding_before_file_name, sounding_after_file_name
    ]

    figure_file_name = '{0:s}/bwo_concat.jpg'.format(output_dir_name)
    print('Concatenating panels to: "{0:s}"...'.format(figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=figure_file_name,
        border_width_pixels=100,
        num_panel_rows=2,
        num_panel_columns=2,
        extra_args_string='-gravity Center')
    imagemagick_utils.trim_whitespace(
        input_file_name=figure_file_name,
        output_file_name=figure_file_name,
        border_width_pixels=10)
def _run(input_file_name, colour_map_name, min_unguided_value,
         max_unguided_value, max_guided_value, num_unguided_contours,
         half_num_guided_contours, smoothing_radius_grid_cells,
         allow_whitespace, plot_panel_names, add_titles, label_colour_bars,
         colour_bar_length, top_output_dir_name):
    """Plots Grad-CAM output (guided and unguided class-activation maps).

    This is effectively the main method.  For each example, the radar data are
    plotted twice: once with unguided maps on top (saved under
    "main_gradcam") and once with guided maps on top (saved under
    "guided_gradcam").

    :param input_file_name: See documentation at top of file.
    :param colour_map_name: Same.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param max_guided_value: Same.
    :param num_unguided_contours: Same.
    :param half_num_guided_contours: Same.
    :param smoothing_radius_grid_cells: Same.
    :param allow_whitespace: Same.
    :param plot_panel_names: Same.
    :param add_titles: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param top_output_dir_name: Same.
    """

    # Non-positive radius means "no smoothing".
    if smoothing_radius_grid_cells <= 0:
        smoothing_radius_grid_cells = None

    unguided_cam_dir_name = '{0:s}/main_gradcam'.format(top_output_dir_name)
    guided_cam_dir_name = '{0:s}/guided_gradcam'.format(top_output_dir_name)

    for this_dir_name in (unguided_cam_dir_name, guided_cam_dir_name):
        file_system_utils.mkdir_recursive_if_necessary(
            directory_name=this_dir_name)

    # Check input args.
    colour_map_object = pyplot.get_cmap(colour_map_name)
    error_checking.assert_is_greater(min_unguided_value, 0.)
    error_checking.assert_is_greater(max_unguided_value, min_unguided_value)
    error_checking.assert_is_greater(max_guided_value, 0.)
    error_checking.assert_is_geq(num_unguided_contours, 10)
    error_checking.assert_is_geq(half_num_guided_contours, 5)

    print('Reading data from: "{0:s}"...'.format(input_file_name))
    gradcam_dict, pmm_flag = gradcam.read_file(input_file_name)

    if not pmm_flag:
        predictor_matrices = gradcam_dict.pop(gradcam.PREDICTOR_MATRICES_KEY)
        cam_matrices = gradcam_dict.pop(gradcam.CAM_MATRICES_KEY)
        guided_cam_matrices = gradcam_dict.pop(gradcam.GUIDED_CAM_MATRICES_KEY)

        full_storm_id_strings = gradcam_dict[gradcam.FULL_STORM_IDS_KEY]
        storm_times_unix_sec = gradcam_dict[gradcam.STORM_TIMES_KEY]
    else:
        predictor_matrices = gradcam_dict.pop(
            gradcam.MEAN_PREDICTOR_MATRICES_KEY)
        cam_matrices = gradcam_dict.pop(gradcam.MEAN_CAM_MATRICES_KEY)
        guided_cam_matrices = gradcam_dict.pop(
            gradcam.MEAN_GUIDED_CAM_MATRICES_KEY)

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]

        # Add a leading axis of length 1, so that the composite can be indexed
        # by example, the same way as the non-composite case.
        for j, this_predictor_matrix in enumerate(predictor_matrices):
            predictor_matrices[j] = numpy.expand_dims(
                this_predictor_matrix, axis=0)

            if cam_matrices[j] is not None:
                cam_matrices[j] = numpy.expand_dims(cam_matrices[j], axis=0)
                guided_cam_matrices[j] = numpy.expand_dims(
                    guided_cam_matrices[j], axis=0)

    if smoothing_radius_grid_cells is not None:
        cam_matrices, guided_cam_matrices = _smooth_maps(
            cam_matrices=cam_matrices, guided_cam_matrices=guided_cam_matrices,
            smoothing_radius_grid_cells=smoothing_radius_grid_cells)

    # Read metadata for CNN.
    model_file_name = gradcam_dict[gradcam.MODEL_FILE_KEY]
    model_directory_name = os.path.split(model_file_name)[0]
    model_metafile_name = '{0:s}/model_metadata.p'.format(model_directory_name)

    print('Reading model metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    num_examples = predictor_matrices[0].shape[0]
    num_matrices = len(predictor_matrices)

    # Options shared by every call to `plot_examples.plot_one_example`.
    radar_option_dict = {
        'list_of_predictor_matrices': predictor_matrices,
        'model_metadata_dict': model_metadata_dict,
        'pmm_flag': pmm_flag,
        'plot_sounding': False,
        'allow_whitespace': allow_whitespace,
        'plot_panel_names': plot_panel_names,
        'add_titles': add_titles,
        'label_colour_bars': label_colour_bars,
        'colour_bar_length': colour_bar_length
    }

    # Options shared by every call to the CAM-plotting methods.
    cam_option_dict = {
        'colour_map_object': colour_map_object,
        'min_unguided_value': min_unguided_value,
        'max_unguided_value': max_unguided_value,
        'num_unguided_contours': num_unguided_contours,
        'max_guided_value': max_guided_value,
        'half_num_guided_contours': half_num_guided_contours,
        'label_colour_bars': label_colour_bars,
        'colour_bar_length': colour_bar_length,
        'model_metadata_dict': model_metadata_dict
    }

    # Each pass is (maps to plot, keyword used to pass one map, output
    # directory).  Unguided maps are plotted first, then guided maps.
    plotting_passes = (
        (cam_matrices, 'cam_matrix', unguided_cam_dir_name),
        (guided_cam_matrices, 'guided_cam_matrix', guided_cam_dir_name)
    )

    for i in range(num_examples):
        for these_maps, this_map_arg_name, this_dir_name in plotting_passes:

            # Radar data are re-plotted for each pass, so that each kind of
            # map gets its own fresh set of figures.
            this_handle_dict = plot_examples.plot_one_example(
                example_index=i, **radar_option_dict)

            these_figure_objects = this_handle_dict[
                plot_examples.RADAR_FIGURES_KEY]
            these_axes_object_matrices = this_handle_dict[
                plot_examples.RADAR_AXES_KEY]

            for j in range(num_matrices):
                if these_maps[j] is None:
                    continue

                # First axis is example and last is channel; the rest are
                # counted as spatial.
                this_num_spatial_dim = len(predictor_matrices[j].shape) - 2

                if this_num_spatial_dim == 3:
                    this_plotting_function = _plot_3d_radar_cam
                else:
                    this_plotting_function = _plot_2d_radar_cam

                these_options = dict(cam_option_dict)
                these_options.update({
                    'figure_objects': these_figure_objects,
                    'axes_object_matrices': these_axes_object_matrices,
                    'output_dir_name': this_dir_name,
                    this_map_arg_name: these_maps[j][i, ...],
                    'full_storm_id_string': full_storm_id_strings[i],
                    'storm_time_unix_sec': storm_times_unix_sec[i]
                })

                this_plotting_function(**these_options)
def _plot_one_example(orig_radar_matrix, translated_radar_matrix,
                      rotated_radar_matrix, noised_radar_matrix,
                      output_dir_name, full_storm_id_string,
                      storm_time_unix_sec):
    """Plots original radar image and three augmented versions, in 4 panels.

    M = number of rows in grid
    N = number of columns in grid

    :param orig_radar_matrix: M-by-N-by-1-by-1 numpy array with original values.
    :param translated_radar_matrix: Same but with translated values.
    :param rotated_radar_matrix: Same but with rotated values.
    :param noised_radar_matrix: Same but with noised values.
    :param output_dir_name: Name of output directory (figure will be saved
        here).
    :param full_storm_id_string: Storm ID (used in name of output file).
    :param storm_time_unix_sec: Storm time (used in name of output file).
    """

    # The four images are stacked along the second-last axis and passed off
    # as four dummy heights, so that each one gets its own panel.
    all_radar_matrices = (
        orig_radar_matrix, translated_radar_matrix, rotated_radar_matrix,
        noised_radar_matrix
    )
    concat_radar_matrix = numpy.concatenate(all_radar_matrices, axis=-2)
    dummy_heights_m_agl = numpy.array([1000, 2000, 3000, 4000], dtype=int)

    model_metadata_dict = {
        cnn.TRAINING_OPTION_DICT_KEY: {
            trainval_io.SOUNDING_FIELDS_KEY: None,
            trainval_io.RADAR_FIELDS_KEY: [RADAR_FIELD_NAME],
            trainval_io.RADAR_HEIGHTS_KEY: dummy_heights_m_agl
        }
    }

    handle_dict = plot_examples.plot_one_example(
        list_of_predictor_matrices=[concat_radar_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=True, plot_sounding=False, allow_whitespace=True,
        plot_panel_names=False, add_titles=False, label_colour_bars=True,
        num_panel_rows=2)

    figure_object = handle_dict[plot_examples.RADAR_FIGURES_KEY][0]
    axes_object_matrix = handle_dict[plot_examples.RADAR_AXES_KEY][0]

    # (row, column) of each panel, paired with its title.
    panel_titles = (
        ((0, 0), '(a) Original'),
        ((0, 1), '(b) Translated'),
        ((1, 0), r'(c) Rotated 30$^{\circ}$ clockwise'),
        ((1, 1), '(d) Noised')
    )

    for this_position, this_title_string in panel_titles:
        axes_object_matrix[this_position].set_title(
            this_title_string, fontsize=TITLE_FONT_SIZE)

    storm_id_string_for_file = full_storm_id_string.replace('_', '-')
    storm_time_string = time_conversion.unix_sec_to_string(
        storm_time_unix_sec, FILE_NAME_TIME_FORMAT)

    output_file_name = '{0:s}/storm={1:s}_time={2:s}.jpg'.format(
        output_dir_name, storm_id_string_for_file, storm_time_string)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)