def test_metadata_to_file_name_fourth(self):
        """Ensures correct output from metadata_to_file_name.

        In this case, using the fourth set of metadata (non-PMM, with a radar
        height and a layer-operation dictionary supplied).
        """

        this_file_name = plot_input_examples.metadata_to_file_name(
            output_dir_name=FIGURE_DIR_NAME,
            is_sounding=False,
            pmm_flag=False,
            full_storm_id_string=FULL_STORM_ID_STRING,
            storm_time_unix_sec=STORM_TIME_UNIX_SEC,
            radar_field_name=FOURTH_RADAR_FIELD_NAME,
            radar_height_m_agl=FOURTH_RADAR_HEIGHT_M_AGL,
            layer_operation_dict=FOURTH_LAYER_OPERATION_DICT)

        # assertEqual shows both file names on failure; assertTrue(a == b)
        # would only report "False is not true".
        self.assertEqual(this_file_name, FOURTH_FIGURE_FILE_NAME)
Example no. 2
0
def _plot_2d_regions(figure_objects,
                     axes_object_matrices,
                     model_metadata_dict,
                     list_of_polygon_objects,
                     output_dir_name,
                     full_storm_id_string=None,
                     storm_time_unix_sec=None):
    """Outlines regions of interest on every channel panel of 2-D radar data.

    Each polygon is drawn on each of the first C panels of the selected axes
    matrix (panels are visited in column-major order).  Polygon vertices are
    in grid coordinates; the y-coordinates are flipped (`num_grid_rows - y`)
    before plotting.  The figure is then saved and closed.

    :param figure_objects: See doc for `_plot_3d_radar_cam`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Same.
    :param list_of_polygon_objects: List of polygons (instances of
        `shapely.geometry.Polygon`), demarcating regions of interest.
    :param output_dir_name: See doc for `_plot_3d_radar_cam`.
    :param full_storm_id_string: Same.
    :param storm_time_unix_sec: Same.
    """

    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]
    figure_index = 1 if conv_2d3d else 0

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]

    # The row count used for the y-flip is doubled when `conv_2d3d` is set.
    row_multiplier = 1 + int(conv_2d3d)
    num_grid_rows = training_option_dict[trainval_io.NUM_ROWS_KEY]
    num_grid_rows *= row_multiplier

    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    if layer_operation_dicts is not None:
        num_channels = len(layer_operation_dicts)
    else:
        num_channels = len(training_option_dict[trainval_io.RADAR_FIELDS_KEY])

    axes_object_matrix = axes_object_matrices[figure_index]

    for polygon_object in list_of_polygon_objects:
        vertex_x_coords, vertex_y_coords = polygon_object.exterior.xy
        column_coords = numpy.array(vertex_x_coords)
        row_coords = num_grid_rows - numpy.array(vertex_y_coords)

        for channel_index in range(num_channels):
            panel_row, panel_column = numpy.unravel_index(
                channel_index, axes_object_matrix.shape, order='F')

            axes_object_matrix[panel_row, panel_column].plot(
                column_coords, row_coords,
                color=plotting_utils.colour_from_numpy_to_tuple(REGION_COLOUR),
                linestyle='solid', linewidth=REGION_LINE_WIDTH)

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    output_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name='shear' if conv_2d3d else None)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object = figure_objects[figure_index]
    figure_object.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0, bbox_inches='tight')
    pyplot.close(figure_object)
Example no. 3
0
def _plot_2d_radar_cam(colour_map_object,
                       max_colour_percentile,
                       figure_objects,
                       axes_object_matrices,
                       model_metadata_dict,
                       output_dir_name,
                       cam_matrix=None,
                       guided_cam_matrix=None,
                       significance_matrix=None,
                       full_storm_id_string=None,
                       storm_time_unix_sec=None):
    """Plots guided or unguided class-activation map for 2-D radar data.

    If `cam_matrix` is None, `guided_cam_matrix` is plotted (as saliency-style
    contours).  Otherwise `cam_matrix` is repeated along a new channel axis
    and the same map is drawn on every channel panel.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param max_colour_percentile: Same.
    :param figure_objects: Same.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Same.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of class activations.
    :param guided_cam_matrix: M-by-N-by-C numpy array of guided-CAM output.
    :param significance_matrix: Boolean numpy array.  M-by-N when
        `cam_matrix` is given (it is repeated over channels), otherwise
        M-by-N-by-C.  See doc for `_plot_3d_radar_cam`.
    :param full_storm_id_string: Same.
    :param storm_time_unix_sec: Same.
    """

    if cam_matrix is None:
        quantity_string = 'max abs value'
    else:
        quantity_string = 'max activation'

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]
    figure_index = 1 if conv_2d3d else 0

    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    if list_of_layer_operation_dicts is None:
        training_option_dict = model_metadata_dict[
            cnn.TRAINING_OPTION_DICT_KEY]
        num_channels = len(training_option_dict[trainval_io.RADAR_FIELDS_KEY])
    else:
        num_channels = len(list_of_layer_operation_dicts)

    if cam_matrix is None:
        this_matrix = guided_cam_matrix
    else:
        # Same 2-D activation map on every channel panel.
        this_matrix = numpy.expand_dims(cam_matrix, axis=-1)
        this_matrix = numpy.repeat(this_matrix, repeats=num_channels, axis=-1)

    max_contour_level = numpy.percentile(numpy.absolute(this_matrix),
                                         max_colour_percentile)

    if cam_matrix is None:
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_contour_level,
            contour_interval=max_contour_level / HALF_NUM_CONTOURS,
            row_major=False)
    else:
        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            colour_map_object=colour_map_object,
            max_contour_level=max_contour_level,
            contour_interval=max_contour_level / NUM_CONTOURS,
            row_major=False)

    if significance_matrix is not None:
        if cam_matrix is None:
            this_matrix = significance_matrix
        else:
            this_matrix = numpy.expand_dims(significance_matrix, axis=-1)
            this_matrix = numpy.repeat(this_matrix,
                                       repeats=num_channels,
                                       axis=-1)

        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(this_matrix, axis=0),
            axes_object_matrix=axes_object_matrices[figure_index],
            row_major=False)

    # `Figure._suptitle` is a `matplotlib.text.Text` (or None), not a str, so
    # its text must be extracted before appending; `Text += str` raises
    # TypeError.
    this_title_object = figure_objects[figure_index]._suptitle

    if this_title_object is not None:
        this_title_string = '{0:s} ({1:s} = {2:.2e})'.format(
            this_title_object.get_text(), quantity_string, max_contour_level)

        figure_objects[figure_index].suptitle(
            this_title_string, fontsize=plot_input_examples.TITLE_FONT_SIZE)

    output_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name,
        is_sounding=False,
        pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name='shear' if conv_2d3d else None)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_objects[figure_index].savefig(output_file_name,
                                         dpi=FIGURE_RESOLUTION_DPI,
                                         pad_inches=0,
                                         bbox_inches='tight')
    pyplot.close(figure_objects[figure_index])
Example no. 4
0
def _plot_3d_radar_cam(colour_map_object,
                       max_colour_percentile,
                       figure_objects,
                       axes_object_matrices,
                       model_metadata_dict,
                       output_dir_name,
                       cam_matrix=None,
                       guided_cam_matrix=None,
                       significance_matrix=None,
                       full_storm_id_string=None,
                       storm_time_unix_sec=None):
    """Plots guided or unguided class-activation map for 3-D radar data.

    This method will plot either `cam_matrix` or `guided_cam_matrix`, not both
    (`guided_cam_matrix` is used when `cam_matrix` is None).  One figure is
    saved per radar field; when the model is hybrid 2D/3D (`conv_2d3d`), only
    the first figure is plotted and it is named 'reflectivity'.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param figure_objects: See doc for `plot_input_examples._plot_3d_example`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of class activations.
    :param guided_cam_matrix: M-by-N-by-H-by-F numpy array of guided-CAM output.
    :param significance_matrix: Boolean numpy array with the same dimensions as
        the array being plotted (`cam_matrix` or `guided_cam_matrix`),
        indicating where differences with some other CAM are significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    if cam_matrix is None:
        quantity_string = 'max abs value'
    else:
        quantity_string = 'max activation'

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        loop_max = 1
        radar_field_names = ['reflectivity']
    else:
        loop_max = len(figure_objects)
        training_option_dict = model_metadata_dict[
            cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    for j in range(loop_max):
        # Guided CAM has a field axis; plain CAM is shared by all figures.
        if cam_matrix is None:
            this_matrix = guided_cam_matrix[..., j]
        else:
            this_matrix = cam_matrix

        this_max_contour_level = numpy.percentile(numpy.absolute(this_matrix),
                                                  max_colour_percentile)

        if cam_matrix is None:
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / HALF_NUM_CONTOURS)
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_contour_level=this_max_contour_level,
                contour_interval=this_max_contour_level / NUM_CONTOURS)

        if significance_matrix is not None:
            if cam_matrix is None:
                this_matrix = significance_matrix[..., j]
            else:
                this_matrix = significance_matrix

            significance_plotting.plot_many_2d_grids_without_coords(
                significance_matrix=numpy.flip(this_matrix, axis=0),
                axes_object_matrix=axes_object_matrices[j])

        # `Figure._suptitle` is a `matplotlib.text.Text` (or None), not a
        # str; appending a str to it directly raises TypeError.
        this_title_object = figure_objects[j]._suptitle

        if this_title_object is not None:
            this_title_string = '{0:s} ({1:s} = {2:.2e})'.format(
                this_title_object.get_text(), quantity_string,
                this_max_contour_level)

            figure_objects[j].suptitle(
                this_title_string,
                fontsize=plot_input_examples.TITLE_FONT_SIZE)

        this_file_name = plot_input_examples.metadata_to_file_name(
            output_dir_name=output_dir_name,
            is_sounding=False,
            pmm_flag=pmm_flag,
            full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j])

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        figure_objects[j].savefig(this_file_name,
                                  dpi=FIGURE_RESOLUTION_DPI,
                                  pad_inches=0,
                                  bbox_inches='tight')
        pyplot.close(figure_objects[j])
def _plot_sounding_saliency(
        saliency_matrix, colour_map_object, max_colour_value, sounding_matrix,
        saliency_dict, model_metadata_dict, output_dir_name, pmm_flag,
        example_index=None):
    """Plots a sounding and its saliency side by side in one image.

    Two temporary panels are written (the sounding itself, then the saliency
    map).  Each is whitespace-trimmed, the two are concatenated horizontally,
    the result is resized to `SOUNDING_IMAGE_SIZE_PX`, and the temporary
    panels are deleted.

    H = number of sounding heights
    F = number of sounding fields

    :param saliency_matrix: H-by-F numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Max value in colour scheme for saliency.
    :param sounding_matrix: H-by-F numpy array of actual sounding values.
    :param saliency_dict: Dictionary returned from
        `saliency_maps.read_standard_file` or `saliency_maps.read_pmm_file`.
        If it contains `saliency_maps.SOUNDING_PRESSURES_KEY`, those pressures
        are appended to the sounding; otherwise the training heights are used.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure will be saved
        here.
    :param pmm_flag: Boolean flag.  If True, will plot composite rather than one
        example (and `example_index` is forced to 0).
    :param example_index: [used only if `pmm_flag == False`]
        Will plot the [i]th example, where i = `example_index`.
    """

    def _save_trim_and_close(panel_file_name):
        # Writes the active pyplot figure, closes it, and crops whitespace.
        print('Saving figure to file: "{0:s}"...'.format(panel_file_name))
        pyplot.savefig(panel_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()

        imagemagick_utils.trim_whitespace(
            input_file_name=panel_file_name, output_file_name=panel_file_name)

    if pmm_flag:
        example_index = 0

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    sounding_field_names = training_option_dict[trainval_io.SOUNDING_FIELDS_KEY]
    sounding_heights_m_agl = training_option_dict[
        trainval_io.SOUNDING_HEIGHTS_KEY]

    # Prepend an example axis so the sounding looks like a batch of one.
    sounding_matrix = numpy.expand_dims(sounding_matrix, axis=0)

    if saliency_maps.SOUNDING_PRESSURES_KEY not in saliency_dict:
        metpy_dicts = dl_utils.soundings_to_metpy_dictionaries(
            sounding_matrix=sounding_matrix, field_names=sounding_field_names,
            height_levels_m_agl=sounding_heights_m_agl,
            storm_elevations_m_asl=numpy.array([0.]))
    else:
        all_pressures_pascals = numpy.expand_dims(
            saliency_dict[saliency_maps.SOUNDING_PRESSURES_KEY], axis=-1)
        chosen_pressures_pascals = all_pressures_pascals[
            [example_index], ...]

        sounding_matrix = numpy.concatenate(
            (sounding_matrix, chosen_pressures_pascals), axis=-1)

        metpy_dicts = dl_utils.soundings_to_metpy_dictionaries(
            sounding_matrix=sounding_matrix,
            field_names=sounding_field_names + [soundings.PRESSURE_NAME])

    sounding_dict_for_metpy = metpy_dicts[0]

    if not pmm_flag:
        full_storm_id_string = saliency_dict[saliency_maps.FULL_IDS_KEY][
            example_index]
        storm_time_unix_sec = saliency_dict[saliency_maps.STORM_TIMES_KEY][
            example_index]

        storm_time_string = time_conversion.unix_sec_to_string(
            storm_time_unix_sec, plot_input_examples.TIME_FORMAT)
        title_string = 'Storm "{0:s}" at {1:s}'.format(
            full_storm_id_string, storm_time_string)
    else:
        full_storm_id_string = None
        storm_time_unix_sec = None
        title_string = 'PMM composite'

    sounding_plotting.plot_sounding(
        sounding_dict_for_metpy=sounding_dict_for_metpy,
        title_string=title_string)

    left_panel_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name='sounding-actual')
    _save_trim_and_close(left_panel_file_name)

    saliency_plotting.plot_saliency_for_sounding(
        saliency_matrix=saliency_matrix,
        sounding_field_names=sounding_field_names,
        pressure_levels_mb=sounding_dict_for_metpy[
            soundings.PRESSURE_COLUMN_METPY],
        colour_map_object=colour_map_object,
        max_absolute_colour_value=max_colour_value)

    right_panel_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name='sounding-saliency')
    _save_trim_and_close(right_panel_file_name)

    concat_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=True, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec)

    print('Concatenating figures to: "{0:s}"...\n'.format(concat_file_name))
    imagemagick_utils.concatenate_images(
        input_file_names=[left_panel_file_name, right_panel_file_name],
        output_file_name=concat_file_name, num_panel_rows=1,
        num_panel_columns=2)

    imagemagick_utils.resize_image(
        input_file_name=concat_file_name, output_file_name=concat_file_name,
        output_size_pixels=SOUNDING_IMAGE_SIZE_PX)

    for panel_file_name in (left_panel_file_name, right_panel_file_name):
        os.remove(panel_file_name)
def _plot_2d_radar_saliency(
        saliency_matrix, colour_map_object, max_colour_value, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        significance_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Draws a 2-D radar saliency map onto existing panels and saves it.

    Uses figure index 1 (file tag 'shear') for hybrid 2D/3D models and figure
    index 0 (no field tag) otherwise.  If the figure already has a suptitle,
    the max absolute saliency is appended to it.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param saliency_matrix: M-by-N-by-C numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Max value in colour scheme for saliency.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param significance_matrix: M-by-N-by-C numpy array of Boolean flags,
        indicating where differences with some other saliency map are
        significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    if model_metadata_dict[cnn.CONV_2D3D_KEY]:
        figure_index, radar_field_name = 1, 'shear'
    else:
        figure_index, radar_field_name = 0, None

    # With layer operations, keep only the channels listed in
    # `plot_input_examples.LAYER_OP_INDICES_TO_KEEP` (presumably the ones that
    # module plots -- confirm).
    if model_metadata_dict[cnn.LAYER_OPERATIONS_KEY] is not None:
        kept_channel_indices = plot_input_examples.LAYER_OP_INDICES_TO_KEEP
        saliency_matrix = saliency_matrix[..., kept_channel_indices]

        if significance_matrix is not None:
            significance_matrix = significance_matrix[
                ..., kept_channel_indices]

    axes_object_matrix = axes_object_matrices[figure_index]

    saliency_plotting.plot_many_2d_grids_with_contours(
        saliency_matrix_3d=numpy.flip(saliency_matrix, axis=0),
        axes_object_matrix=axes_object_matrix,
        colour_map_object=colour_map_object,
        max_absolute_contour_level=max_colour_value,
        contour_interval=max_colour_value / HALF_NUM_CONTOURS,
        row_major=False)

    if significance_matrix is not None:
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(significance_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            row_major=False)

    figure_object = figure_objects[figure_index]
    existing_title = figure_object._suptitle

    if existing_title is not None:
        new_title_string = '{0:s} (max abs saliency = {1:.2e})'.format(
            existing_title.get_text(), max_colour_value)
        figure_object.suptitle(
            new_title_string, fontsize=plot_input_examples.TITLE_FONT_SIZE)

    output_file_name = plot_input_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)
def _plot_3d_radar_saliency(
        saliency_matrix, colour_map_object, max_colour_value, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        significance_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Draws a 3-D radar saliency map, one saved figure per radar field.

    For hybrid 2D/3D models (`conv_2d3d`), only the first figure is plotted and
    its file is tagged 'reflectivity'.  Otherwise every figure is plotted and
    tagged with the matching training radar field name.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param saliency_matrix: M-by-N-by-H-by-F numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Max value in colour scheme for saliency.
    :param figure_objects: See doc for
        `plot_input_examples._plot_3d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param significance_matrix: M-by-N-by-H-by-F numpy array of Boolean flags,
        indicating where differences with some other saliency map are
        significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    if not model_metadata_dict[cnn.CONV_2D3D_KEY]:
        num_figures = len(figure_objects)
        radar_field_names = model_metadata_dict[
            cnn.TRAINING_OPTION_DICT_KEY][trainval_io.RADAR_FIELDS_KEY]
    else:
        num_figures = 1
        radar_field_names = ['reflectivity']

    for field_index in range(num_figures):
        axes_object_matrix = axes_object_matrices[field_index]

        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(
                saliency_matrix[..., field_index], axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_colour_value,
            contour_interval=max_colour_value / HALF_NUM_CONTOURS)

        if significance_matrix is not None:
            significance_plotting.plot_many_2d_grids_without_coords(
                significance_matrix=numpy.flip(
                    significance_matrix[..., field_index], axis=0),
                axes_object_matrix=axes_object_matrix)

        figure_object = figure_objects[field_index]
        existing_title = figure_object._suptitle

        if existing_title is not None:
            new_title_string = '{0:s} (max abs saliency = {1:.2e})'.format(
                existing_title.get_text(), max_colour_value)
            figure_object.suptitle(
                new_title_string, fontsize=plot_input_examples.TITLE_FONT_SIZE)

        output_file_name = plot_input_examples.metadata_to_file_name(
            output_dir_name=output_dir_name, is_sounding=False,
            pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[field_index])

        print('Saving figure to: "{0:s}"...'.format(output_file_name))
        figure_object.savefig(
            output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight')
        pyplot.close(figure_object)
def _plot_sounding_saliency(
        saliency_matrix, colour_map_object, max_colour_value,
        sounding_figure_object, sounding_axes_object,
        sounding_pressures_pascals, saliency_dict, model_metadata_dict,
        add_title, output_dir_name, pmm_flag, example_index=None):
    """Saves a sounding panel and a sounding-saliency panel, then joins them.

    The already-drawn sounding figure becomes the left panel.  A saliency plot
    created here becomes the right panel.  The two are concatenated side by
    side and resized to `SOUNDING_IMAGE_SIZE_PX`, and the intermediate panel
    files are deleted.

    H = number of sounding heights
    F = number of sounding fields

    :param saliency_matrix: H-by-F numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.  If None, the `MAX_COLOUR_PERCENTILE`th
        percentile of absolute saliency is used.
    :param sounding_figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`) for sounding itself.
    :param sounding_axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`) for sounding itself.
    :param sounding_pressures_pascals: length-H numpy array of sounding
        pressures (converted to mb with `PASCALS_TO_MB` for plotting).
    :param saliency_dict: Dictionary returned by `saliency_maps.read_file`.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param add_title: Boolean flag.  If True, the max absolute saliency is
        written as the sounding axes title.
    :param output_dir_name: Name of output directory.  Figures will be saved
        here.
    :param pmm_flag: Boolean flag.  If True, plotting PMM composite rather than
        one example.
    :param example_index: [used only if `pmm_flag == False`]
        Plotting the [i]th example, where i = `example_index`.
    """

    if max_colour_value is None:
        max_colour_value = numpy.percentile(
            numpy.absolute(saliency_matrix), MAX_COLOUR_PERCENTILE)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    sounding_field_names = training_option_dict[trainval_io.SOUNDING_FIELDS_KEY]

    full_storm_id_string = None
    storm_time_unix_sec = None

    if not pmm_flag:
        full_storm_id_string = saliency_dict[
            saliency_maps.FULL_STORM_IDS_KEY][example_index]
        storm_time_unix_sec = saliency_dict[
            saliency_maps.STORM_TIMES_KEY][example_index]

    if add_title:
        sounding_axes_object.set_title(
            'Max absolute saliency = {0:.2e}'.format(max_colour_value))

    # Keyword arguments shared by all three output file names.
    file_name_kwargs = dict(
        output_dir_name=output_dir_name, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec)

    left_panel_file_name = plot_examples.metadata_to_file_name(
        is_sounding=False, radar_field_name='sounding-actual',
        **file_name_kwargs)

    print('Saving figure to file: "{0:s}"...'.format(left_panel_file_name))
    sounding_figure_object.savefig(
        left_panel_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(sounding_figure_object)

    saliency_plotting.plot_saliency_for_sounding(
        saliency_matrix=saliency_matrix,
        sounding_field_names=sounding_field_names,
        pressure_levels_mb=PASCALS_TO_MB * sounding_pressures_pascals,
        colour_map_object=colour_map_object,
        max_absolute_colour_value=max_colour_value)

    right_panel_file_name = plot_examples.metadata_to_file_name(
        is_sounding=False, radar_field_name='sounding-saliency',
        **file_name_kwargs)

    print('Saving figure to file: "{0:s}"...'.format(right_panel_file_name))
    pyplot.savefig(right_panel_file_name, dpi=FIGURE_RESOLUTION_DPI,
                   pad_inches=0, bbox_inches='tight')
    pyplot.close()

    concat_file_name = plot_examples.metadata_to_file_name(
        is_sounding=True, **file_name_kwargs)

    print('Concatenating figures to: "{0:s}"...\n'.format(concat_file_name))
    panel_file_names = [left_panel_file_name, right_panel_file_name]

    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=concat_file_name,
        num_panel_rows=1, num_panel_columns=2)

    imagemagick_utils.resize_image(
        input_file_name=concat_file_name, output_file_name=concat_file_name,
        output_size_pixels=SOUNDING_IMAGE_SIZE_PX)

    for panel_file_name in panel_file_names:
        os.remove(panel_file_name)
def _plot_2d_radar_saliency(
        saliency_matrix, colour_map_object, max_colour_value, half_num_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        significance_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Draws a 2-D radar saliency map with a colour bar and saves it.

    Uses figure index 1 (file tag 'shear') for hybrid 2D/3D models and figure
    index 0 (no field tag) otherwise.  For hybrid models the colour bar length
    is halved.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param saliency_matrix: M-by-N-by-C numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.  If None, the `MAX_COLOUR_PERCENTILE`th
        percentile of absolute saliency is used.
    :param half_num_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param significance_matrix: M-by-N-by-C numpy array of Boolean flags,
        indicating where differences with some other saliency map are
        significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    if max_colour_value is None:
        max_colour_value = numpy.percentile(
            numpy.absolute(saliency_matrix), MAX_COLOUR_PERCENTILE)

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        figure_index, radar_field_name = 1, 'shear'
    else:
        figure_index, radar_field_name = 0, None

    axes_object_matrix = axes_object_matrices[figure_index]

    saliency_plotting.plot_many_2d_grids_with_contours(
        saliency_matrix_3d=numpy.flip(saliency_matrix, axis=0),
        axes_object_matrix=axes_object_matrix,
        colour_map_object=colour_map_object,
        max_absolute_contour_level=max_colour_value,
        contour_interval=max_colour_value / half_num_contours,
        row_major=False)

    if significance_matrix is not None:
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(significance_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            row_major=False)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix,
        data_matrix=saliency_matrix,
        colour_map_object=colour_map_object, min_value=0.,
        max_value=max_colour_value, orientation_string='horizontal',
        fraction_of_axis_length=colour_bar_length / (1 + int(conv_2d3d)),
        extend_min=False, extend_max=True, font_size=COLOUR_BAR_FONT_SIZE)

    if label_colour_bars:
        colour_bar_object.set_label(
            'Absolute saliency', fontsize=COLOUR_BAR_FONT_SIZE)

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object = figure_objects[figure_index]
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight')
    pyplot.close(figure_object)
def _plot_one_example(
        radar_matrix, sounding_matrix, sounding_pressures_pascals,
        full_storm_id_string, storm_time_unix_sec, model_metadata_dict,
        output_dir_name):
    """Plots radar and sounding predictors for one storm object.

    M = number of rows in radar grid
    N = number of columns in radar grid
    H_r = number of heights in radar grid
    F_r = number of radar fields
    H_s = number of sounding heights
    F_s = number of sounding fields

    The sounding is saved to its own image.  Each radar field is first saved
    to a temporary panel image.  The panels are then joined into one row,
    resized and trimmed, and the temporary panel files are deleted.  This
    leaves a single combined radar image.

    :param radar_matrix: numpy array (1 x M x N x H_r x F_r) of radar values.
    :param sounding_matrix: numpy array (1 x H_s x F_s) of sounding values.
    :param sounding_pressures_pascals: numpy array (length H_s) of sounding
        pressures.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Valid time.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Name of output directory (figures will be saved
        here).
    :return: radar_figure_file_name: Path to the combined radar image created
        by this method.
    """

    option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    field_names = option_dict[trainval_io.RADAR_FIELDS_KEY]

    # Last axis of the radar matrix indexes fields; second-to-last, heights.
    field_count = radar_matrix.shape[-1]
    height_count = radar_matrix.shape[-2]

    handles = plot_examples.plot_one_example(
        list_of_predictor_matrices=[radar_matrix, sounding_matrix],
        model_metadata_dict=model_metadata_dict,
        pmm_flag=False, example_index=0, plot_sounding=True,
        sounding_pressures_pascals=sounding_pressures_pascals,
        allow_whitespace=True, plot_panel_names=True,
        panel_name_font_size=PANEL_NAME_FONT_SIZE,
        add_titles=False, label_colour_bars=True,
        colour_bar_length=COLOUR_BAR_LENGTH,
        colour_bar_font_size=COLOUR_BAR_FONT_SIZE,
        num_panel_rows=height_count
    )

    # Keyword arguments shared by every file-name lookup / savefig call below.
    name_kwargs = dict(
        output_dir_name=output_dir_name,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec
    )
    save_kwargs = dict(dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
                       bbox_inches='tight')

    sounding_file_name = plot_examples.metadata_to_file_name(
        is_sounding=True, pmm_flag=False, **name_kwargs)
    print('Saving figure to: "{0:s}"...'.format(sounding_file_name))

    sounding_figure = handles[plot_examples.SOUNDING_FIGURE_KEY]
    sounding_figure.savefig(sounding_file_name, **save_kwargs)
    pyplot.close(sounding_figure)

    # One temporary panel image per radar field.
    radar_figures = handles[plot_examples.RADAR_FIGURES_KEY]
    panel_file_names = []

    for field_index in range(field_count):
        this_panel_name = plot_examples.metadata_to_file_name(
            is_sounding=False, pmm_flag=False,
            radar_field_name=field_names[field_index], **name_kwargs)
        panel_file_names.append(this_panel_name)

        print('Saving figure to: "{0:s}"...'.format(this_panel_name))
        radar_figures[field_index].savefig(this_panel_name, **save_kwargs)
        pyplot.close(radar_figures[field_index])

    radar_figure_file_name = plot_examples.metadata_to_file_name(
        is_sounding=False, pmm_flag=False, **name_kwargs)
    print('Concatenating panels to: "{0:s}"...'.format(radar_figure_file_name))

    # Lay the panels out in a single row, then resize and trim the combined
    # image in place (input and output paths are the same file).
    imagemagick_utils.concatenate_images(
        input_file_names=panel_file_names,
        output_file_name=radar_figure_file_name,
        num_panel_rows=1, num_panel_columns=field_count,
        border_width_pixels=50
    )
    imagemagick_utils.resize_image(
        input_file_name=radar_figure_file_name,
        output_file_name=radar_figure_file_name,
        output_size_pixels=CONCAT_FIGURE_SIZE_PX
    )
    imagemagick_utils.trim_whitespace(
        input_file_name=radar_figure_file_name,
        output_file_name=radar_figure_file_name,
        border_width_pixels=10
    )

    # Individual panels were only intermediates for the combined image.
    for this_panel_name in panel_file_names:
        os.remove(this_panel_name)

    return radar_figure_file_name
def _plot_2d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots a class-activation map over 2-D radar data, then saves the figure.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    F = number of radar fields

    If `cam_matrix` is given, its log10 is copied onto every channel and
    contoured, with a colour bar whose tick labels show the un-logged
    values.  Otherwise `guided_cam_matrix` is contoured with a linear colour
    bar starting at zero.  The affected figure is saved and closed.

    For a composite rather than a single example (storm object), pass None
    for both `full_storm_id_string` and `storm_time_unix_sec`.

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cam`.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    # With a hybrid 2D/3D model the 2-D data are in the second figure and the
    # output file name uses field 'shear'; otherwise there is one figure.
    figure_index, radar_field_name = (1, 'shear') if conv_2d3d else (0, None)

    # Number of channel panels: one per layer operation if any were used,
    # otherwise one per raw radar field.
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    if layer_operation_dicts is not None:
        num_channels = len(layer_operation_dicts)
    else:
        option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        num_channels = len(option_dict[trainval_io.RADAR_FIELDS_KEY])

    # Unguided contour levels are spaced evenly in log10 space.
    log10_min = numpy.log10(min_unguided_value)
    log10_max = numpy.log10(max_unguided_value)
    log10_interval = (log10_max - log10_min) / (num_unguided_contours - 1)

    axes_object_matrix = axes_object_matrices[figure_index]
    # Hybrid models put two figures side by side, so halve the bar length.
    bar_fraction = colour_bar_length / (1 + int(conv_2d3d))

    if cam_matrix is not None:
        # Same M-by-N map on every channel panel: add a trailing channel axis
        # and repeat it num_channels times.
        log10_cam_matrix = numpy.repeat(
            numpy.log10(numpy.expand_dims(cam_matrix, axis=-1)),
            repeats=num_channels, axis=-1
        )

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(log10_cam_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=log10_min, max_contour_level=log10_max,
            contour_interval=log10_interval, row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=log10_cam_matrix,
            colour_map_object=colour_map_object,
            min_value=log10_min, max_value=log10_max,
            orientation_string='horizontal',
            fraction_of_axis_length=bar_fraction,
            extend_min=True, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )

        # Ticks sit at log10 positions; label them with 10**tick, formatted to
        # two decimals and then cut to four characters.
        tick_values = colour_bar_object.get_ticks()
        tick_labels = ['{0:.2f}'.format(10 ** t)[:4] for t in tick_values]
        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(tick_labels)

        bar_label = 'Class activation'
    else:
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(guided_cam_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_guided_value,
            contour_interval=max_guided_value / half_num_guided_contours,
            row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=guided_cam_matrix,
            colour_map_object=colour_map_object, min_value=0.,
            max_value=max_guided_value, orientation_string='horizontal',
            fraction_of_axis_length=bar_fraction,
            extend_min=False, extend_max=True,
            font_size=COLOUR_BAR_FONT_SIZE
        )

        bar_label = 'Absolute guided class activation'

    if label_colour_bars:
        colour_bar_object.set_label(bar_label, fontsize=COLOUR_BAR_FONT_SIZE)

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name
    )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    this_figure_object = figure_objects[figure_index]
    this_figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(this_figure_object)
def _plot_3d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 3-D radar data.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    One figure per radar field is drawn on, saved and closed.  When the model
    is a hybrid 2D/3D net (`cnn.CONV_2D3D_KEY` is truthy), only the first
    figure is handled, and it is saved under the field name 'reflectivity'.

    If `cam_matrix` is given, its log10 is contoured on every field's figure.
    The same matrix is used for each field because it has no field axis.  The
    colour bar's tick labels show the un-logged values.  Otherwise, slice j of
    `guided_cam_matrix` is contoured on figure j, with a linear colour bar
    starting at zero.

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See documentation at top of file.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_3d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-H-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        # Hybrid model: only the first figure holds 3-D data.
        loop_max = 1
        radar_field_names = ['reflectivity']
    else:
        loop_max = len(figure_objects)
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    # Unguided contour levels are spaced evenly in log10 space.
    min_unguided_value_log10 = numpy.log10(min_unguided_value)
    max_unguided_value_log10 = numpy.log10(max_unguided_value)
    contour_interval_log10 = (
        (max_unguided_value_log10 - min_unguided_value_log10) /
        (num_unguided_contours - 1)
    )

    # The unguided CAM has no field axis, so the identical log-scaled matrix
    # (and its copy flipped along axis 0, the grid rows) is drawn for every
    # field.  Compute both once here rather than once per loop iteration.
    if cam_matrix is None:
        cam_matrix_log10 = None
        flipped_cam_matrix_log10 = None
    else:
        cam_matrix_log10 = numpy.log10(cam_matrix)
        flipped_cam_matrix_log10 = numpy.flip(cam_matrix_log10, axis=0)

    for j in range(loop_max):
        if cam_matrix is None:
            # Guided CAM: the last axis indexes fields, so take field j.
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(
                    guided_cam_matrix[..., j], axis=0
                ),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=max_guided_value,
                contour_interval=max_guided_value / half_num_guided_contours
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=guided_cam_matrix[..., j],
                colour_map_object=colour_map_object, min_value=0.,
                max_value=max_guided_value, orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=False, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Absolute guided class activation',
                    fontsize=COLOUR_BAR_FONT_SIZE
                )
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=flipped_cam_matrix_log10,
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                min_contour_level=min_unguided_value_log10,
                max_contour_level=max_unguided_value_log10,
                contour_interval=contour_interval_log10
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=cam_matrix_log10,
                colour_map_object=colour_map_object,
                min_value=min_unguided_value_log10,
                max_value=max_unguided_value_log10,
                orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=True, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            # Ticks sit at log10 positions; label them with 10**tick,
            # formatted to two decimals and then cut to four characters.
            these_tick_values = this_colour_bar_object.get_ticks()
            these_tick_strings = [
                '{0:.2f}'.format(10 ** v)[:4] for v in these_tick_values
            ]
            this_colour_bar_object.set_ticks(these_tick_values)
            this_colour_bar_object.set_ticklabels(these_tick_strings)

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Class activation', fontsize=COLOUR_BAR_FONT_SIZE
                )

        this_file_name = plot_examples.metadata_to_file_name(
            output_dir_name=output_dir_name, is_sounding=False,
            pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j]
        )

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        figure_objects[j].savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_objects[j])