def _run(output_file_name):
    """Plots the Laplacian kernel used for the edge-detector test.

    This is effectively the main method.  One panel is drawn per height of
    `KERNEL_MATRIX_3D` (heights along its last axis), the first three panels
    are titled bottom/middle/top, and the figure is saved to disk.

    :param output_file_name: See documentation at top of file.
    """

    panel_count = KERNEL_MATRIX_3D.shape[-1]

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1, num_columns=panel_count,
        horizontal_spacing=0.1, vertical_spacing=0.1,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
    )

    for height_index in range(panel_count):
        _plot_kernel_one_height(
            kernel_matrix_2d=KERNEL_MATRIX_3D[..., height_index],
            axes_object=axes_object_matrix[0, height_index]
        )

    # NOTE(review): these titles assume the kernel has exactly three heights;
    # with fewer, indexing the missing panel would presumably fail.
    panel_titles = ('Bottom height', 'Middle height', 'Top height')
    for column_index, title_string in enumerate(panel_titles):
        axes_object_matrix[0, column_index].set_title(title_string)

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
# Example #2
# 0
def plot_many_1d_feature_maps(
        feature_matrix, colour_map_object, colour_norm_object=None,
        min_colour_value=None, max_colour_value=None,
        figure_width_inches=DEFAULT_FIG_WIDTH_INCHES,
        figure_height_inches=DEFAULT_FIG_HEIGHT_INCHES):
    """Plots many 1-D feature maps in one figure (one channel per column).

    N = number of points in spatial grid
    C = number of channels

    Panels are laid out in a single row with no spacing between them.  Each
    channel is drawn by `plot_2d_feature_map` as an N-by-1 image, with an
    empty annotation and font size 30.

    Side effect: sets the global matplotlib axes line width to 1 via
    `pyplot.rc`.

    :param feature_matrix: N-by-C numpy array of feature values.
    :param colour_map_object: See doc for `plot_many_2d_feature_maps`.
    :param colour_norm_object: Same.
    :param min_colour_value: Same.
    :param max_colour_value: Same.
    :param figure_width_inches: Same.
    :param figure_height_inches: Same.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    pyplot.rc('axes', linewidth=1)
    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=2)

    num_spatial_points, num_channels = feature_matrix.shape

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1, num_columns=num_channels,
        figure_width_inches=figure_width_inches,
        figure_height_inches=figure_height_inches,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=False
    )

    for channel_index in range(num_channels):
        # Present the channel as a one-column 2-D image.
        column_matrix = feature_matrix[:, channel_index].reshape(
            (num_spatial_points, 1))

        plot_2d_feature_map(
            feature_matrix=column_matrix,
            axes_object=axes_object_matrix[0, channel_index],
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            annotation_string='', font_size=30
        )

    return figure_object, axes_object_matrix
# Example #3
# 0
def _run(include_caption, output_dir_name):
    """Makes animation to explain multivariate convolution.

    This is effectively the main method.  One frame is made for each kernel
    position (i, j) over the grid of `INPUT_FEATURE_MATRIX`.  Each frame has
    three labelled panels:
    - (a) the input feature map;
    - (b) the kernel;
    - (c) the output feature map.

    Each frame is saved as a JPEG in `output_dir_name`, and all frames are then
    combined into `conv_animation.gif` in the same directory.

    :param include_caption: See documentation at top of file.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    output_feature_matrix = standalone_utils.do_2d_convolution(
        feature_matrix=INPUT_FEATURE_MATRIX, kernel_matrix=KERNEL_MATRIX,
        pad_edges=True, stride_length_px=1)

    # NOTE(review): keeps index 0 along the first and last axes -- presumably
    # the single example and single output channel; confirm against
    # `standalone_utils.do_2d_convolution`.
    output_feature_matrix = output_feature_matrix[0, ..., 0]

    num_grid_rows = INPUT_FEATURE_MATRIX.shape[0]
    num_grid_columns = INPUT_FEATURE_MATRIX.shape[1]
    image_file_names = []

    # Kernel size relative to the full input grid; used below to shrink the
    # kernel panel by the same ratio.
    kernel_width_ratio = float(KERNEL_MATRIX.shape[1]) / num_grid_columns
    kernel_height_ratio = float(KERNEL_MATRIX.shape[0]) / num_grid_rows

    for i in range(num_grid_rows):
        for j in range(num_grid_columns):
            this_figure_object, this_axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=NUM_PANEL_ROWS, num_columns=NUM_PANEL_COLUMNS,
                    horizontal_spacing=0.2, vertical_spacing=0.,
                    shared_x_axis=False, shared_y_axis=False,
                    keep_aspect_ratio=True)
            )

            # Panel (a): input feature map for kernel position (i, j).
            _plot_feature_map(
                feature_matrix_2d=INPUT_FEATURE_MATRIX[..., 0],
                kernel_row=i, kernel_column=j, is_output_map=False,
                axes_object=this_axes_object_matrix[0, 0]
            )

            plotting_utils.label_axes(
                axes_object=this_axes_object_matrix[0, 0],
                label_string='(a)',
                font_size=PANEL_LETTER_FONT_SIZE,
                y_coord_normalized=1.04, x_coord_normalized=0.1
            )

            # Panel (c): output feature map for kernel position (i, j).
            _plot_feature_map(
                feature_matrix_2d=output_feature_matrix,
                kernel_row=i, kernel_column=j, is_output_map=True,
                axes_object=this_axes_object_matrix[0, 2]
            )

            # Panel (b): resize the middle axes by the kernel-to-grid ratio.
            # Shift it right by half its new width, and place its bottom 0.1
            # above the bottom of panel (a).
            this_bbox_object = this_axes_object_matrix[0, 1].get_position()
            this_width = kernel_width_ratio * (
                this_bbox_object.x1 - this_bbox_object.x0
            )
            this_height = kernel_height_ratio * (
                this_bbox_object.y1 - this_bbox_object.y0
            )

            this_bbox_object.x0 += 0.5 * this_width
            this_bbox_object.y0 = (
                this_axes_object_matrix[0, 0].get_position().y0 + 0.1
            )
            this_bbox_object.x1 = this_bbox_object.x0 + this_width
            this_bbox_object.y1 = this_bbox_object.y0 + this_height

            this_axes_object_matrix[0, 1].set_position(this_bbox_object)

            _plot_kernel(
                kernel_matrix_2d=KERNEL_MATRIX[..., 0, 0],
                feature_matrix_2d=INPUT_FEATURE_MATRIX[..., 0],
                feature_row_at_center=i, feature_column_at_center=j,
                axes_object=this_axes_object_matrix[0, 1]
            )

            plotting_utils.label_axes(
                axes_object=this_axes_object_matrix[0, 1],
                label_string='(b)',
                font_size=PANEL_LETTER_FONT_SIZE,
                y_coord_normalized=1.04, x_coord_normalized=0.2
            )

            _plot_feature_to_kernel_lines(
                kernel_matrix_2d=KERNEL_MATRIX[..., 0, 0],
                feature_matrix_2d=INPUT_FEATURE_MATRIX[..., 0],
                feature_row_at_center=i, feature_column_at_center=j,
                kernel_axes_object=this_axes_object_matrix[0, 1],
                feature_axes_object=this_axes_object_matrix[0, 0]
            )

            plotting_utils.label_axes(
                axes_object=this_axes_object_matrix[0, 2],
                label_string='(c)',
                font_size=PANEL_LETTER_FONT_SIZE,
                y_coord_normalized=1.04, x_coord_normalized=0.1
            )

            if include_caption:
                this_figure_object.text(
                    0.5, 0.35, FIGURE_CAPTION,
                    fontsize=DEFAULT_FONT_SIZE, color='k',
                    horizontalalignment='center', verticalalignment='top')

            image_file_names.append(
                '{0:s}/conv_animation_row{1:d}_column{2:d}.jpg'.format(
                    output_dir_name, i, j)
            )

            print('Saving figure to: "{0:s}"...'.format(image_file_names[-1]))
            this_figure_object.savefig(
                image_file_names[-1], dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight'
            )
            pyplot.close(this_figure_object)

    animation_file_name = '{0:s}/conv_animation.gif'.format(output_dir_name)
    print('Creating animation: "{0:s}"...'.format(animation_file_name))

    imagemagick_utils.create_gif(
        input_file_names=image_file_names, output_file_name=animation_file_name,
        num_seconds_per_frame=0.5, resize_factor=0.5)
# Example #4
# 0
def plot_3d_grid_without_coords(field_matrix,
                                field_name,
                                grid_point_heights_metres,
                                ground_relative,
                                num_panel_rows=None,
                                figure_object=None,
                                axes_object_matrix=None,
                                font_size=DEFAULT_FONT_SIZE,
                                colour_map_object=None,
                                colour_norm_object=None):
    """Plots 3-D grid as many colour maps (one per height).

    M = number of grid rows
    N = number of grid columns
    H = number of grid heights

    To use the default colour scheme for the given radar field, leave
    `colour_map_object` and `colour_norm_object` empty.

    If `figure_object` is None, a new paneled figure with `num_panel_rows` rows
    is created.  Otherwise `axes_object_matrix` must also be given, and its
    shape defines the panel layout (`num_panel_rows` is then ignored).  Panels
    are filled row by row in increasing height index.  Panels beyond the H-th
    are turned off.

    :param field_matrix: M-by-N-by-H numpy array with values of radar field.
    :param field_name: Name of radar field (must be accepted by
        `radar_utils.check_field_name`).
    :param grid_point_heights_metres: length-H numpy array of non-negative
        heights.  Values are rounded to the nearest integer before use.
    :param ground_relative: Boolean flag.  If True, heights in
        `grid_point_heights_metres` are ground-relative (panels annotated
        "AGL").  If False, sea-level-relative (annotated "ASL").
    :param num_panel_rows: Number of rows in paneled figure (different than M,
        the number of grid rows).  Used only when `figure_object` is None, in
        which case it must be an integer in [1, H].
    :param figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :param axes_object_matrix: See above.  Must be a 2-D numpy array when
        `figure_object` is given.
    :param font_size: Font size for colour-bar ticks and panel labels.
    :param colour_map_object: See doc for `plot_latlng_grid`.
    :param colour_norm_object: Same.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    error_checking.assert_is_numpy_array(field_matrix, num_dimensions=3)
    error_checking.assert_is_geq_numpy_array(grid_point_heights_metres, 0)
    grid_point_heights_metres = numpy.round(grid_point_heights_metres).astype(
        int)

    num_heights = field_matrix.shape[2]
    these_expected_dim = numpy.array([num_heights], dtype=int)
    error_checking.assert_is_numpy_array(grid_point_heights_metres,
                                         exact_dimensions=these_expected_dim)

    error_checking.assert_is_boolean(ground_relative)

    if figure_object is None:
        error_checking.assert_is_integer(num_panel_rows)
        error_checking.assert_is_geq(num_panel_rows, 1)
        error_checking.assert_is_leq(num_panel_rows, num_heights)

        num_panel_columns = int(numpy.ceil(
            float(num_heights) / num_panel_rows))

        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(num_rows=num_panel_rows,
                                                 num_columns=num_panel_columns,
                                                 shared_x_axis=False,
                                                 shared_y_axis=False,
                                                 keep_aspect_ratio=True))
    else:
        error_checking.assert_is_numpy_array(axes_object_matrix,
                                             num_dimensions=2)

        num_panel_rows = axes_object_matrix.shape[0]
        num_panel_columns = axes_object_matrix.shape[1]

    # Same suffix for every panel, so decide it once.
    height_suffix = 'AGL' if ground_relative else 'ASL'

    for i in range(num_panel_rows):
        for j in range(num_panel_columns):
            this_height_index = i * num_panel_columns + j

            if this_height_index >= num_heights:
                axes_object_matrix[i, j].axis('off')
                continue

            this_annotation_string = '{0:.1f} km {1:s}'.format(
                grid_point_heights_metres[this_height_index] * METRES_TO_KM,
                height_suffix)

            plot_2d_grid_without_coords(
                field_matrix=field_matrix[..., this_height_index],
                field_name=field_name,
                axes_object=axes_object_matrix[i, j],
                annotation_string=this_annotation_string,
                colour_map_object=colour_map_object,
                colour_norm_object=colour_norm_object,
                font_size=font_size)

    return figure_object, axes_object_matrix
# Example #5
# 0
def plot_many_2d_grids_without_coords(field_matrix,
                                      field_name_by_panel,
                                      num_panel_rows=None,
                                      figure_object=None,
                                      axes_object_matrix=None,
                                      panel_names=None,
                                      colour_map_object_by_panel=None,
                                      colour_norm_object_by_panel=None,
                                      plot_colour_bar_by_panel=None,
                                      font_size=DEFAULT_FONT_SIZE,
                                      row_major=True):
    """Plots 2-D colour map in each panel (one per field/height pair).

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    P = number of panels (field/height pairs)

    Unless `colour_map_object_by_panel` and `colour_norm_object_by_panel` are
    both given, each panel uses the default colour scheme for its radar field.

    If `figure_object` is None, a new paneled figure with `num_panel_rows` rows
    is created.  Otherwise `axes_object_matrix` must also be given, and its
    shape defines the panel layout (`num_panel_rows` is then ignored).  Grid
    cells beyond the P-th panel are turned off.

    :param field_matrix: M-by-N-by-P numpy array of radar values.
    :param field_name_by_panel: length-P list of field names.
    :param num_panel_rows: Number of rows in paneled figure (different than M,
        which is number of rows in spatial grid).  Used only when
        `figure_object` is None, in which case it must be an integer in
        [1, P].
    :param figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :param axes_object_matrix: See above.
    :param panel_names: length-P list of panel names, or None.  Each non-None
        name is used as the label of its panel's colour bar, with newlines
        replaced by "; ".  A name is therefore shown only for panels whose
        colour bar is plotted.  If None, colour bars are left unlabelled.
    :param colour_map_object_by_panel: length-P list of `matplotlib.pyplot.cm`
        objects.  If either this or `colour_norm_object_by_panel` is None, the
        default will be used for every field.
    :param colour_norm_object_by_panel: length-P list of
        `matplotlib.colors.BoundaryNorm` objects.  If either this or
        `colour_map_object_by_panel` is None, the default will be used for
        every field.
    :param plot_colour_bar_by_panel: length-P numpy array of Boolean flags.  If
        plot_colour_bar_by_panel[k] = True, horizontal colour bar will be
        plotted under [k]th panel.  If you want to plot colour bar for every
        panel, leave this as None.
    :param font_size: Font size.
    :param row_major: Boolean flag.  If True, panels will be filled along rows
        first, then down columns.  If False, down columns first, then along
        rows.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    :raises: ValueError: if `colour_map_object_by_panel` or
        `colour_norm_object_by_panel` has different length than number of
        panels.
    """

    error_checking.assert_is_boolean(row_major)
    error_checking.assert_is_numpy_array(field_matrix, num_dimensions=3)
    num_panels = field_matrix.shape[2]

    if panel_names is None:
        panel_names = [None] * num_panels
    if plot_colour_bar_by_panel is None:
        plot_colour_bar_by_panel = numpy.full(num_panels, True, dtype=bool)

    these_expected_dim = numpy.array([num_panels], dtype=int)
    error_checking.assert_is_numpy_array(numpy.array(panel_names),
                                         exact_dimensions=these_expected_dim)
    error_checking.assert_is_numpy_array(numpy.array(field_name_by_panel),
                                         exact_dimensions=these_expected_dim)

    error_checking.assert_is_boolean_numpy_array(plot_colour_bar_by_panel)
    error_checking.assert_is_numpy_array(plot_colour_bar_by_panel,
                                         exact_dimensions=these_expected_dim)

    # A colour map without a matching normalizer (or vice versa) is not
    # usable, so fall back to the defaults for both.
    if (colour_map_object_by_panel is None
            or colour_norm_object_by_panel is None):
        colour_map_object_by_panel = [None] * num_panels
        colour_norm_object_by_panel = [None] * num_panels

    error_checking.assert_is_list(colour_map_object_by_panel)
    error_checking.assert_is_list(colour_norm_object_by_panel)

    if len(colour_map_object_by_panel) != num_panels:
        error_string = (
            'Number of colour maps ({0:d}) should equal number of panels '
            '({1:d}).').format(len(colour_map_object_by_panel), num_panels)

        raise ValueError(error_string)

    if len(colour_norm_object_by_panel) != num_panels:
        error_string = (
            'Number of colour-normalizers ({0:d}) should equal number of panels'
            ' ({1:d}).').format(len(colour_norm_object_by_panel), num_panels)

        raise ValueError(error_string)

    if figure_object is None:
        error_checking.assert_is_integer(num_panel_rows)
        error_checking.assert_is_geq(num_panel_rows, 1)
        error_checking.assert_is_leq(num_panel_rows, num_panels)

        num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(num_rows=num_panel_rows,
                                                 num_columns=num_panel_columns,
                                                 shared_x_axis=False,
                                                 shared_y_axis=False,
                                                 keep_aspect_ratio=True))
    else:
        error_checking.assert_is_numpy_array(axes_object_matrix,
                                             num_dimensions=2)

        num_panel_rows = axes_object_matrix.shape[0]
        num_panel_columns = axes_object_matrix.shape[1]

    # 'C' = row-major fill order, 'F' = column-major (numpy convention).
    order_string = 'C' if row_major else 'F'
    panel_grid_shape = (num_panel_rows, num_panel_columns)

    for k in range(num_panels):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, panel_grid_shape, order=order_string)
        this_axes_object = axes_object_matrix[this_panel_row,
                                              this_panel_column]

        this_colour_map_object, this_colour_norm_object = (
            plot_2d_grid_without_coords(
                field_matrix=field_matrix[..., k],
                field_name=field_name_by_panel[k],
                axes_object=this_axes_object,
                annotation_string=None,
                font_size=font_size,
                colour_map_object=colour_map_object_by_panel[k],
                colour_norm_object=colour_norm_object_by_panel[k]))

        if not plot_colour_bar_by_panel[k]:
            continue

        this_extend_min_flag = field_name_by_panel[k] in SHEAR_VORT_DIV_NAMES

        this_colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=this_axes_object,
            data_matrix=field_matrix[..., k],
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=this_extend_min_flag,
            extend_max=True,
            fraction_of_axis_length=0.75,
            font_size=font_size)

        # Panel names default to None; calling .replace on None would raise
        # AttributeError, so only label the colour bar when a name exists.
        if panel_names[k] is not None:
            this_colour_bar_object.set_label(
                panel_names[k].replace('\n', '; '),
                fontsize=font_size, fontweight='bold')

    # Hide grid cells that have no panel assigned.
    for k in range(num_panels, num_panel_rows * num_panel_columns):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, panel_grid_shape, order=order_string)

        axes_object_matrix[this_panel_row, this_panel_column].axis('off')

    return figure_object, axes_object_matrix
def _run(input_file_name, num_predictors_to_plot, confidence_level,
         output_dir_name):
    """Plots results of permutation-based importance test.

    This is effectively the main method.  Two figures are saved, each with
    the single-pass test on the left and the multi-pass test on the right
    (shared y-axis):
    - one drawn with `plot_percent_increase=False`;
    - one drawn with `plot_percent_increase=True`.

    :param input_file_name: See documentation at top of file.
    :param num_predictors_to_plot: Same.
    :param confidence_level: Same.
    :param output_dir_name: Same.
    """

    # Non-positive values mean "plot all predictors".
    if num_predictors_to_plot <= 0:
        num_predictors_to_plot = None

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    print('Reading data from: "{0:s}"...'.format(input_file_name))
    permutation_dict = _results_to_gg_format(
        ml4rt_permutation.read_file(input_file_name)
    )

    # One entry per output figure:
    # (plot_percent_increase, x-axis label, output-file pattern).
    figure_specs = (
        (False, 'Mean squared error',
         '{0:s}/permutation_test_abs-values.jpg'),
        (True, 'MSE (fraction of original)',
         '{0:s}/permutation_test_percentage.jpg'),
    )

    for percent_flag, x_label_string, file_pattern in figure_specs:
        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=1, num_columns=2,
                shared_x_axis=False, shared_y_axis=True,
                keep_aspect_ratio=False,
                horizontal_spacing=0.1, vertical_spacing=0.05)
        )
        single_pass_axes = axes_object_matrix[0, 0]
        multipass_axes = axes_object_matrix[0, 1]

        permutation_plotting.plot_single_pass_test(
            permutation_dict=permutation_dict,
            axes_object=single_pass_axes,
            num_predictors_to_plot=num_predictors_to_plot,
            plot_percent_increase=percent_flag,
            confidence_level=confidence_level,
            bar_face_colour=BAR_FACE_COLOUR)
        single_pass_axes.set_title('Single-pass test')
        single_pass_axes.set_xlabel(x_label_string)

        permutation_plotting.plot_multipass_test(
            permutation_dict=permutation_dict,
            axes_object=multipass_axes,
            num_predictors_to_plot=num_predictors_to_plot,
            plot_percent_increase=percent_flag,
            confidence_level=confidence_level,
            bar_face_colour=BAR_FACE_COLOUR)
        multipass_axes.set_title('Multi-pass test')
        multipass_axes.set_xlabel(x_label_string)
        # y-axis is shared, so predictor names only need to appear once.
        multipass_axes.set_ylabel('')

        figure_file_name = file_pattern.format(output_dir_name)

        print('Saving figure to: "{0:s}"...'.format(figure_file_name))
        figure_object.savefig(
            figure_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight')
        pyplot.close(figure_object)
def _run(input_file_name, num_predictors_to_plot, output_dir_name):
    """Plots results of permutation test.

    This is effectively the main method.  Two figures are saved, each with
    Breiman (single-pass) results on the left and Lakshmanan (multi-pass)
    results on the right.  The first is drawn with
    `plot_percent_increase=False` (x-axis labelled "AUC"); the second is drawn
    with `plot_percent_increase=True`.

    :param input_file_name: See documentation at top of file.
    :param num_predictors_to_plot: Same.  Non-positive means plot all.
    :param output_dir_name: Same.  If empty or "None", figures are written
        next to `input_file_name` (current directory if it has no directory
        component).
    """

    if num_predictors_to_plot <= 0:
        num_predictors_to_plot = None

    if output_dir_name in ['', 'None']:
        # os.path.split returns '' for a bare file name; without the fallback
        # the output paths below would become "/permutation_*.jpg" (the
        # filesystem root).
        output_dir_name = os.path.split(input_file_name)[0] or '.'

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    print(
        'Reading permutation results from: "{0:s}"...'.format(input_file_name))
    permutation_dict = permutation.read_results(input_file_name)

    # Operate on the returned figure explicitly rather than on pyplot's
    # implicit "current figure".
    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1,
        num_columns=2,
        shared_x_axis=False,
        shared_y_axis=True,
        keep_aspect_ratio=False)

    permutation_plotting.plot_breiman_results(
        permutation_dict=permutation_dict,
        axes_object=axes_object_matrix[0, 0],
        plot_percent_increase=False,
        num_predictors_to_plot=num_predictors_to_plot)

    axes_object_matrix[0, 0].set_xlabel('AUC')
    axes_object_matrix[0, 0].set_title('Single-pass')

    permutation_plotting.plot_lakshmanan_results(
        permutation_dict=permutation_dict,
        axes_object=axes_object_matrix[0, 1],
        plot_percent_increase=False,
        num_steps_to_plot=num_predictors_to_plot)

    axes_object_matrix[0, 1].set_xlabel('AUC')
    axes_object_matrix[0, 1].set_ylabel('')
    axes_object_matrix[0, 1].set_title('Multi-pass')

    figure_object.tight_layout()
    absolute_value_file_name = '{0:s}/permutation_absolute-values.jpg'.format(
        output_dir_name)

    print('Saving figure to file: "{0:s}"...'.format(absolute_value_file_name))
    figure_object.savefig(absolute_value_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close(figure_object)

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1,
        num_columns=2,
        shared_x_axis=False,
        shared_y_axis=True,
        keep_aspect_ratio=False)

    permutation_plotting.plot_breiman_results(
        permutation_dict=permutation_dict,
        axes_object=axes_object_matrix[0, 0],
        plot_percent_increase=True,
        num_predictors_to_plot=num_predictors_to_plot)

    axes_object_matrix[0, 0].set_title('Single-pass')

    permutation_plotting.plot_lakshmanan_results(
        permutation_dict=permutation_dict,
        axes_object=axes_object_matrix[0, 1],
        plot_percent_increase=True,
        num_steps_to_plot=num_predictors_to_plot)

    axes_object_matrix[0, 1].set_ylabel('')
    axes_object_matrix[0, 1].set_title('Multi-pass')

    figure_object.tight_layout()
    percentage_file_name = '{0:s}/permutation_percentage.jpg'.format(
        output_dir_name)

    print('Saving figure to file: "{0:s}"...'.format(percentage_file_name))
    figure_object.savefig(percentage_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close(figure_object)
# Example #8
# 0
def plot_many_2d_feature_maps(
        feature_matrix, annotation_string_by_panel, num_panel_rows,
        colour_map_object, colour_norm_object=None, min_colour_value=None,
        max_colour_value=None, figure_width_inches=DEFAULT_FIG_WIDTH_INCHES,
        figure_height_inches=DEFAULT_FIG_HEIGHT_INCHES,
        font_size=DEFAULT_FONT_SIZE):
    """Plots many 2-D feature maps in the same figure (one per panel).

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    P = number of panels

    Panels are filled row by row with no spacing between them.  Any grid
    cells beyond the P-th panel are turned off.

    Side effect: sets the global matplotlib axes line width to 3 via
    `pyplot.rc`.

    :param feature_matrix: M-by-N-by-P numpy array of feature values (either
        before or after activation function -- this method doesn't care).
    :param annotation_string_by_panel: length-P list of annotations.
        annotation_string_by_panel[k] will be printed in the bottom-center of
        the [k]th panel.
    :param num_panel_rows: Number of panel rows (integer in [1, P]).
    :param colour_map_object: See doc for `plot_2d_feature_map`.
    :param colour_norm_object: Same.
    :param min_colour_value: Same.
    :param max_colour_value: Same.
    :param figure_width_inches: Figure width.
    :param figure_height_inches: Figure height.
    :param font_size: Font size for panel labels.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    pyplot.rc('axes', linewidth=3)
    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=3)

    num_panels = feature_matrix.shape[-1]
    expected_dimensions = numpy.array([num_panels])
    error_checking.assert_is_numpy_array(
        numpy.array(annotation_string_by_panel),
        exact_dimensions=expected_dimensions
    )

    error_checking.assert_is_integer(num_panel_rows)
    error_checking.assert_is_geq(num_panel_rows, 1)
    error_checking.assert_is_leq(num_panel_rows, num_panels)

    # Ceiling division: just enough columns to hold every panel.
    num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_panel_rows, num_columns=num_panel_columns,
        figure_width_inches=figure_width_inches,
        figure_height_inches=figure_height_inches,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=False)

    # Row-major walk over the panel grid: linear index -> (row, column).
    for panel_index in range(num_panel_rows * num_panel_columns):
        row_index, column_index = divmod(panel_index, num_panel_columns)
        panel_axes_object = axes_object_matrix[row_index, column_index]

        if panel_index >= num_panels:
            panel_axes_object.axis('off')
            continue

        plot_2d_feature_map(
            feature_matrix=feature_matrix[..., panel_index],
            axes_object=panel_axes_object,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            annotation_string=annotation_string_by_panel[panel_index],
            font_size=font_size
        )

    return figure_object, axes_object_matrix
# Example #9
# 0
def _run(forward_test_file_name, backwards_test_file_name, num_predictors,
         confidence_level, output_file_name):
    """Makes figure with results of all 4 permutation tests.

    The figure is a 2-by-2 grid.  Row 0 shows the forward test and row 1 the
    backward test; column 0 shows the single-pass test and column 1 the
    multi-pass test.  Panels are lettered (a)-(d) in row-major order.

    This is effectively the main method.

    :param forward_test_file_name: See documentation at top of file.
    :param backwards_test_file_name: Same.
    :param num_predictors: Same.  Any value <= 0 is replaced with None before
        being passed to the plotting methods as `num_predictors_to_plot`.
    :param confidence_level: Same.
    :param output_file_name: Same.
    """

    num_predictors_to_plot = None if num_predictors <= 0 else num_predictors

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    # Element 0 feeds the top row (forward test), element 1 the bottom row.
    permutation_dicts = []
    for this_file_name in [forward_test_file_name, backwards_test_file_name]:
        print('Reading data from: "{0:s}"...'.format(this_file_name))
        permutation_dicts.append(
            permutation_utils.read_results(this_file_name)
        )

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=2, num_columns=2, shared_x_axis=False, shared_y_axis=True,
        keep_aspect_ratio=False, horizontal_spacing=0.1, vertical_spacing=0.05
    )

    # One entry per figure column: (plotting method, x-coord of panel letter).
    column_specs = [
        (permutation_plotting.plot_single_pass_test, -0.01),
        (permutation_plotting.plot_multipass_test, 1.15)
    ]
    title_matrix = [
        ['Forward single-pass test', 'Forward multi-pass test'],
        ['Backward single-pass test', 'Backward multi-pass test']
    ]
    letter_matrix = [['(a)', '(b)'], ['(c)', '(d)']]

    for i, this_permutation_dict in enumerate(permutation_dicts):
        for j, (this_plot_method, this_letter_x_coord) in enumerate(
                column_specs):
            this_axes_object = axes_object_matrix[i, j]

            this_plot_method(
                permutation_dict=this_permutation_dict,
                axes_object=this_axes_object,
                plot_percent_increase=False, confidence_level=confidence_level,
                num_predictors_to_plot=num_predictors_to_plot
            )

            this_axes_object.set_title(title_matrix[i][j])

            if i == 0:
                # Top row: no x-axis ticks or label.
                this_axes_object.set_xticks([])
                this_axes_object.set_xlabel('')
            else:
                this_axes_object.set_xlabel('Area under ROC curve (AUC)')

            if j == 1:
                # Right column: y-axis is shared, so its label is cleared.
                this_axes_object.set_ylabel('')

            plotting_utils.label_axes(
                axes_object=this_axes_object, label_string=letter_matrix[i][j],
                x_coord_normalized=this_letter_x_coord,
                y_coord_normalized=0.925
            )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _run(include_caption, output_dir_name):
    """Makes animation to explain pooling.

    One JPEG frame is written for every (row, column) position of
    `INPUT_FEATURE_MATRIX`, and the frames are then combined into a GIF.  Each
    frame has one row of panels per channel: the input feature map on the
    left (column 0) and the output feature map on the right (column 1).

    NOTE(review): the frame and GIF file names say "upsampling" even though
    this docstring says the animation explains pooling.  The names are kept
    as-is so existing outputs/consumers are not affected.

    This is effectively the main method.

    :param include_caption: See documentation at top of file.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    input_shape = INPUT_FEATURE_MATRIX.shape
    num_input_rows = input_shape[0]
    num_input_columns = input_shape[1]
    num_channels = input_shape[2]

    frame_file_names = []

    for row_index in range(num_input_rows):
        for column_index in range(num_input_columns):
            figure_object, axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=num_channels,
                    num_columns=NUM_PANEL_COLUMNS,
                    horizontal_spacing=HORIZ_PANEL_SPACING,
                    vertical_spacing=VERTICAL_PANEL_SPACING,
                    shared_x_axis=False,
                    shared_y_axis=False,
                    keep_aspect_ratio=True))

            # Left column: input feature maps, lettered (a), (b), ...
            for channel_index in range(num_channels):
                input_axes_object = axes_object_matrix[channel_index, 0]

                _plot_feature_map(
                    feature_matrix_2d=INPUT_FEATURE_MATRIX[..., channel_index],
                    pooled_row=row_index,
                    pooled_column=column_index,
                    pooled=True,
                    axes_object=input_axes_object)

                plotting_utils.label_axes(
                    axes_object=input_axes_object,
                    label_string='({0:s})'.format(
                        chr(ord('a') + channel_index)),
                    font_size=PANEL_LETTER_FONT_SIZE,
                    y_coord_normalized=0.85,
                    x_coord_normalized=-0.02)

            # Right column: output feature maps.  Letters continue after the
            # last input panel, and each row gets connecting lines.
            for channel_index in range(num_channels):
                output_axes_object = axes_object_matrix[channel_index, 1]

                _plot_feature_map(
                    feature_matrix_2d=OUTPUT_FEATURE_MATRIX[..., channel_index],
                    pooled_row=row_index,
                    pooled_column=column_index,
                    pooled=False,
                    axes_object=output_axes_object)

                plotting_utils.label_axes(
                    axes_object=output_axes_object,
                    label_string='({0:s})'.format(
                        chr(ord('a') + num_channels + channel_index)),
                    font_size=PANEL_LETTER_FONT_SIZE,
                    y_coord_normalized=0.85,
                    x_coord_normalized=-0.02)

                _plot_interpanel_lines(
                    pooled_row=row_index,
                    pooled_column=column_index,
                    input_fm_axes_object=axes_object_matrix[channel_index, 0],
                    output_fm_axes_object=output_axes_object)

            if include_caption:
                figure_object.text(
                    0.5, CAPTION_Y_COORD, FIGURE_CAPTION,
                    fontsize=DEFAULT_FONT_SIZE, color='k',
                    horizontalalignment='center', verticalalignment='top')

            frame_file_name = (
                '{0:s}/upsampling_animation_row{1:d}_column{2:d}.jpg'
            ).format(output_dir_name, row_index, column_index)
            frame_file_names.append(frame_file_name)

            print('Saving figure to: "{0:s}"...'.format(frame_file_name))
            figure_object.savefig(
                frame_file_name, dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight')
            pyplot.close(figure_object)

    animation_file_name = '{0:s}/upsampling_animation.gif'.format(
        output_dir_name)
    print('Creating animation: "{0:s}"...'.format(animation_file_name))

    imagemagick_utils.create_gif(
        input_file_names=frame_file_names,
        output_file_name=animation_file_name,
        num_seconds_per_frame=0.5,
        resize_factor=0.5)
示例#11
0
def _plot_one_example(
        input_feature_matrix, feature_matrix_after_conv,
        feature_matrix_after_activn, feature_matrix_after_bn,
        feature_matrix_after_pooling, output_file_name):
    """Plots entire figure for one example (storm object).

    Column 0 shows only the first channel of the input, with a colour bar
    under the column.  Columns 1-4 show every output channel (one per row)
    after convolution, activation, batch norm, and pooling respectively.
    Panels are lettered in column-major order, starting with "(a)" for the
    input panel.

    Each feature matrix has channels on its last axis; every slice
    `matrix[..., k]` is passed to `_plot_one_feature_map` as a 2-D map.

    :param input_feature_matrix: numpy array with input features.  Only
        channel 0 is plotted.
    :param feature_matrix_after_conv: numpy array with features after
        convolution.  Its last-axis length sets the number of panel rows.
    :param feature_matrix_after_activn: numpy array with features after
        activation.
    :param feature_matrix_after_bn: numpy array with features after batch
        normalization.
    :param feature_matrix_after_pooling: numpy array with features after
        pooling.
    :param output_file_name: Path to output file.  Figure will be saved here.
    """

    num_output_channels = feature_matrix_after_conv.shape[-1]

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_output_channels, num_columns=NUM_PANEL_COLUMNS,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True)

    def _add_letter(axes_object, panel_index):
        """Labels one panel "(a)", "(b)", ... according to `panel_index`."""

        plotting_utils.label_axes(
            axes_object=axes_object,
            label_string='({0:s})'.format(chr(ord('a') + panel_index)),
            font_size=TITLE_FONT_SIZE,
            x_coord_normalized=0.125, y_coord_normalized=1.025
        )

    # Input column: only the top panel is drawn; the ones below are hidden.
    input_max_colour = numpy.percentile(
        numpy.absolute(input_feature_matrix), MAX_COLOUR_PERCENTILE
    )

    axes_object_matrix[0, 0].set_title('Input', fontsize=TITLE_FONT_SIZE)

    _plot_one_feature_map(
        feature_matrix_2d=input_feature_matrix[..., 0],
        max_colour_value=input_max_colour, plot_colour_bar=False,
        axes_object=axes_object_matrix[0, 0]
    )

    for row_index in range(1, num_output_channels):
        axes_object_matrix[row_index, 0].axis('off')

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix[num_output_channels - 1, 0],
        data_matrix=input_feature_matrix[..., 0],
        colour_map_object=COLOUR_MAP_OBJECT, min_value=-1 * input_max_colour,
        max_value=input_max_colour, orientation_string='horizontal',
        padding=0.015, fraction_of_axis_length=0.9,
        extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE)

    # Show colour-bar tick labels with one decimal place.
    tick_values = colour_bar_object.ax.get_xticks()
    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(
        ['{0:.1f}'.format(v) for v in tick_values]
    )

    _add_letter(axes_object_matrix[0, 0], 0)

    # One entry per remaining column: (title, matrix to plot, matrix from which
    # to derive the colour limit).  A None colour source means "reuse the
    # previous column's limit" (convolution and activation share one scale).
    stage_specs = [
        ('  After convolution', feature_matrix_after_conv,
         numpy.stack(
             (feature_matrix_after_conv, feature_matrix_after_activn), axis=0
         )),
        (' After activation', feature_matrix_after_activn, None),
        ('  After batch norm', feature_matrix_after_bn,
         feature_matrix_after_bn),
        ('After pooling', feature_matrix_after_pooling,
         feature_matrix_after_pooling)
    ]

    panel_index = 0
    stage_max_colour = None

    for column_index, (this_title, this_matrix, colour_source) in enumerate(
            stage_specs, start=1):
        if colour_source is not None:
            stage_max_colour = numpy.percentile(
                numpy.absolute(colour_source), MAX_COLOUR_PERCENTILE
            )

        axes_object_matrix[0, column_index].set_title(
            this_title, fontsize=TITLE_FONT_SIZE)

        for row_index in range(num_output_channels):
            this_axes_object = axes_object_matrix[row_index, column_index]

            _plot_one_feature_map(
                feature_matrix_2d=this_matrix[..., row_index],
                max_colour_value=stage_max_colour,
                plot_colour_bar=row_index == num_output_channels - 1,
                axes_object=this_axes_object
            )

            panel_index += 1
            _add_letter(this_axes_object, panel_index)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
示例#12
0
def _plot_2d_radar_scan(list_of_predictor_matrices,
                        model_metadata_dict,
                        allow_whitespace,
                        title_string=None):
    """Plots 2-D radar scan for one example.

    J = number of panel rows in image
    K = number of panel columns in image

    If the model metadata contains layer operations, only the operations at
    `LAYER_OP_INDICES_TO_KEEP` are plotted, together with the matching
    channels (last axis) of the first predictor matrix.
    `list_of_predictor_matrices` itself is not modified.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-1 list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-1 list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    list_of_layer_operation_dicts = model_metadata_dict[
        cnn.LAYER_OPERATIONS_KEY]

    # Local handle on the radar matrix.  Channel subsetting below rebinds this
    # name instead of writing back into the caller's list, so repeated calls
    # with the same list always see the original channels.
    radar_matrix = list_of_predictor_matrices[0]

    if list_of_layer_operation_dicts is None:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY]

        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY])
    else:
        list_of_layer_operation_dicts = [
            list_of_layer_operation_dicts[k] for k in LAYER_OP_INDICES_TO_KEEP
        ]
        radar_matrix = radar_matrix[..., LAYER_OP_INDICES_TO_KEEP]

        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                list_of_layer_operation_dicts=list_of_layer_operation_dicts))

    num_panels = len(field_name_by_panel)
    plot_cbar_by_panel = numpy.full(num_panels, True, dtype=bool)

    # Near-square layout: floor(sqrt(N)) rows and enough columns to fit all.
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_panels)))
    num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

    if allow_whitespace:
        # None is passed through to `plot_many_2d_grids_without_coords`,
        # whose returned figure/axes are used below.
        figure_object = None
        axes_object_matrix = None
    else:
        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(num_rows=num_panel_rows,
                                                 num_columns=num_panel_columns,
                                                 horizontal_spacing=0.,
                                                 vertical_spacing=0.,
                                                 shared_x_axis=False,
                                                 shared_y_axis=False,
                                                 keep_aspect_ratio=True))

    figure_object, axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(radar_matrix, axis=0),
            field_name_by_panel=field_name_by_panel,
            panel_names=panel_names,
            num_panel_rows=num_panel_rows,
            figure_object=figure_object,
            axes_object_matrix=axes_object_matrix,
            plot_colour_bar_by_panel=plot_cbar_by_panel,
            font_size=FONT_SIZE_WITH_COLOUR_BARS,
            row_major=False))

    if allow_whitespace and title_string is not None:
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return [figure_object], [axes_object_matrix]
示例#13
0
def _plot_2d3d_radar_scan(list_of_predictor_matrices,
                          model_metadata_dict,
                          allow_whitespace,
                          title_string=None):
    """Plots 3-D reflectivity and 2-D azimuthal shear for one example.

    Reflectivity (channel 0 of the first predictor matrix) is drawn as one
    panel per height; azimuthal shear (second predictor matrix) is drawn as
    one row with one panel per field.  When whitespace is allowed, each
    figure gets a shared horizontal colour bar and, if given, a suptitle.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-2 list of figure handles (instances of
        `matplotlib.figure.Figure`).  The first is for reflectivity; the second
        is for azimuthal shear.
    :return: axes_object_matrices: length-2 list (the first is for reflectivity;
        the second is for azimuthal shear).  Each element is a 2-D numpy
        array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    az_shear_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    refl_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    num_az_shear_fields = len(az_shear_field_names)
    num_refl_heights = len(refl_heights_m_agl)

    def _new_figure(num_rows, num_columns):
        """Returns (figure, axes matrix), or (None, None) if whitespace OK."""

        if allow_whitespace:
            return None, None

        return plotting_utils.create_paneled_figure(
            num_rows=num_rows,
            num_columns=num_columns,
            horizontal_spacing=0.,
            vertical_spacing=0.,
            shared_x_axis=False,
            shared_y_axis=False,
            keep_aspect_ratio=True)

    def _add_colour_bar(axes_object_matrix, data_matrix, field_name):
        """Adds horizontal colour bar with the default scheme for a field."""

        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(field_name))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=data_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

    # Reflectivity: near-square grid with one panel per height.
    num_refl_rows = int(numpy.floor(numpy.sqrt(num_refl_heights)))
    num_refl_columns = int(numpy.ceil(float(num_refl_heights) / num_refl_rows))

    refl_figure_object, refl_axes_object_matrix = _new_figure(
        num_refl_rows, num_refl_columns)

    refl_figure_object, refl_axes_object_matrix = (
        radar_plotting.plot_3d_grid_without_coords(
            field_matrix=numpy.flip(
                list_of_predictor_matrices[0][..., 0], axis=0),
            field_name=radar_utils.REFL_NAME,
            grid_point_heights_metres=refl_heights_m_agl,
            ground_relative=True,
            num_panel_rows=num_refl_rows,
            figure_object=refl_figure_object,
            axes_object_matrix=refl_axes_object_matrix,
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        _add_colour_bar(refl_axes_object_matrix, list_of_predictor_matrices[0],
                        radar_utils.REFL_NAME)

        if title_string is not None:
            pyplot.suptitle(
                '{0:s}; {1:s}'.format(title_string, radar_utils.REFL_NAME),
                fontsize=TITLE_FONT_SIZE)

    # Azimuthal shear: a single row with one panel per field, no per-panel
    # colour bars.
    shear_figure_object, shear_axes_object_matrix = _new_figure(
        1, num_az_shear_fields)

    shear_figure_object, shear_axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(list_of_predictor_matrices[1], axis=0),
            field_name_by_panel=az_shear_field_names,
            panel_names=az_shear_field_names,
            num_panel_rows=1,
            figure_object=shear_figure_object,
            axes_object_matrix=shear_axes_object_matrix,
            plot_colour_bar_by_panel=numpy.full(
                num_az_shear_fields, False, dtype=bool),
            font_size=FONT_SIZE_SANS_COLOUR_BARS))

    if allow_whitespace:
        _add_colour_bar(shear_axes_object_matrix,
                        list_of_predictor_matrices[1],
                        radar_utils.LOW_LEVEL_SHEAR_NAME)

        if title_string is not None:
            pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return (
        [refl_figure_object, shear_figure_object],
        [refl_axes_object_matrix, shear_axes_object_matrix]
    )
示例#14
0
def _plot_3d_radar_scan(list_of_predictor_matrices,
                        model_metadata_dict,
                        allow_whitespace,
                        title_string=None):
    """Plots 3-D radar scan for one example.

    One figure is produced per radar field (last axis of the first predictor
    matrix), with one panel per radar height.

    J = number of panel rows in image
    K = number of panel columns in image
    F = number of radar fields

    :param list_of_predictor_matrices: List created by
        `testing_io.read_specific_examples`, except that the first axis (example
        dimension) is removed.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param allow_whitespace: See documentation at top of file.
    :param title_string: Title (may be None).

    :return: figure_objects: length-F list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-F list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    radar_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    # Near-square grid: floor(sqrt(H)) rows, enough columns to fit all heights.
    num_radar_heights = len(radar_heights_m_agl)
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_radar_heights)))
    num_panel_columns = int(
        numpy.ceil(float(num_radar_heights) / num_panel_rows))

    radar_matrix = list_of_predictor_matrices[0]
    figure_objects = []
    axes_object_matrices = []

    for field_index, field_name in enumerate(radar_field_names):
        field_matrix = numpy.flip(radar_matrix[..., field_index], axis=0)

        if allow_whitespace:
            this_figure_object = None
            this_axes_object_matrix = None
        else:
            this_figure_object, this_axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=num_panel_rows,
                    num_columns=num_panel_columns,
                    horizontal_spacing=0.,
                    vertical_spacing=0.,
                    shared_x_axis=False,
                    shared_y_axis=False,
                    keep_aspect_ratio=True))

        this_figure_object, this_axes_object_matrix = (
            radar_plotting.plot_3d_grid_without_coords(
                field_matrix=field_matrix,
                field_name=field_name,
                grid_point_heights_metres=radar_heights_m_agl,
                ground_relative=True,
                num_panel_rows=num_panel_rows,
                figure_object=this_figure_object,
                axes_object_matrix=this_axes_object_matrix,
                font_size=FONT_SIZE_SANS_COLOUR_BARS))

        figure_objects.append(this_figure_object)
        axes_object_matrices.append(this_axes_object_matrix)

        if not allow_whitespace:
            continue

        # Whitespace allowed: add one shared colour bar and optional title.
        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(field_name))

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=this_axes_object_matrix,
            data_matrix=field_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal',
            extend_min=True,
            extend_max=True)

        if title_string is not None:
            pyplot.suptitle(
                '{0:s}; {1:s}'.format(title_string, field_name),
                fontsize=TITLE_FONT_SIZE)

    return figure_objects, axes_object_matrices
示例#15
0
def _plot_one_score(score_matrix, colour_map_object, min_colour_value,
                    max_colour_value, colour_bar_label, is_score_bias,
                    best_model_index, output_file_name):
    """Plots one score.

    The figure has one panel per (num dense layers, data-augmentation flag)
    pair, stacked in a single column.  Each panel is a dropout-rate-by-L2-weight
    grid.  The best model is marked with `BEST_MODEL_MARKER_TYPE`, and models
    whose score is NaN are marked with `CORRUPT_MODEL_MARKER_TYPE`.

    :param score_matrix: 4-D numpy array of scores, where the first axis
        represents dropout rate; second represents L2 weight; third represents
        num dense layers; and fourth is data augmentation (yes or no).
    :param colour_map_object: See documentation at top of file.  Ignored (and
        replaced) when `is_score_bias` is True.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param colour_bar_label: Label string for colour bar.
    :param is_score_bias: Boolean flag.  If True, score to be plotted is
        frequency bias, which changes settings for colour scheme.
    :param best_model_index: Linear index (into the flattened `score_matrix`)
        of best model.
    :param output_file_name: Path to output file (figure will be saved here).
    """

    if is_score_bias:
        colour_map_object, colour_norm_object = _get_bias_colour_scheme(
            max_value=max_colour_value)
    else:
        colour_norm_object = None

    num_dense_layer_counts = len(DENSE_LAYER_COUNTS)
    num_data_aug_flags = len(DATA_AUGMENTATION_FLAGS)

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_dense_layer_counts * num_data_aug_flags,
        num_columns=1,
        horizontal_spacing=0.15,
        vertical_spacing=0.15,
        shared_x_axis=False,
        shared_y_axis=False,
        keep_aspect_ratio=True)

    # Panels are created as a single column; reshape (row-major) so that
    # axes_object_matrix[k, m] goes with DENSE_LAYER_COUNTS[k] and
    # DATA_AUGMENTATION_FLAGS[m].  The bottom panel is [-1, -1].
    axes_object_matrix = numpy.reshape(
        axes_object_matrix, (num_dense_layer_counts, num_data_aug_flags))

    x_axis_label = r'L$_2$ weight (log$_{10}$)'
    y_axis_label = 'Dropout rate'
    x_tick_labels = ['{0:.1f}'.format(w) for w in numpy.log10(L2_WEIGHTS)]
    y_tick_labels = ['{0:.3f}'.format(d) for d in DROPOUT_RATES]

    for k in range(num_dense_layer_counts):
        for m in range(num_data_aug_flags):
            this_axes_object = axes_object_matrix[k, m]

            model_eval.plot_hyperparam_grid(
                score_matrix=score_matrix[..., k, m],
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value,
                colour_map_object=colour_map_object,
                colour_norm_object=colour_norm_object,
                axes_object=this_axes_object)

            this_axes_object.set_xticklabels(
                x_tick_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
            this_axes_object.set_yticklabels(
                y_tick_labels, fontsize=TICK_LABEL_FONT_SIZE)
            this_axes_object.set_ylabel(y_axis_label,
                                        fontsize=TICK_LABEL_FONT_SIZE)

            if k == num_dense_layer_counts - 1 and m == num_data_aug_flags - 1:
                this_axes_object.set_xlabel(x_axis_label)
            else:
                # Only the bottom panel keeps x-axis ticks.  Pass the tick list
                # alone: the second positional argument of `set_xticks` is
                # `minor` in older Matplotlib but `labels` in 3.5+.
                this_axes_object.set_xticks([])

            # "1 dense layer" but "0 dense layers" / "2 dense layers".
            this_title_string = '{0:d} dense layer{1:s}, DA {2:s}'.format(
                DENSE_LAYER_COUNTS[k],
                '' if DENSE_LAYER_COUNTS[k] == 1 else 's',
                'on' if DATA_AUGMENTATION_FLAGS[m] else 'off')

            this_axes_object.set_title(this_title_string)

    # Mark the best model.  Within a panel, x is the L2-weight index (j) and
    # y is the dropout-rate index (i).
    i, j, k, m = numpy.unravel_index(best_model_index, score_matrix.shape)

    axes_object_matrix[k, m].plot(j,
                                  i,
                                  linestyle='None',
                                  marker=BEST_MODEL_MARKER_TYPE,
                                  markersize=BEST_MODEL_MARKER_SIZE,
                                  markerfacecolor=MARKER_COLOUR,
                                  markeredgecolor=MARKER_COLOUR,
                                  markeredgewidth=BEST_MODEL_MARKER_WIDTH)

    # Mark every model whose score is NaN.
    for i, j, k, m in numpy.argwhere(numpy.isnan(score_matrix)):
        axes_object_matrix[k, m].plot(
            j,
            i,
            linestyle='None',
            marker=CORRUPT_MODEL_MARKER_TYPE,
            markersize=CORRUPT_MODEL_MARKER_SIZE,
            markerfacecolor=MARKER_COLOUR,
            markeredgecolor=MARKER_COLOUR,
            markeredgewidth=CORRUPT_MODEL_MARKER_WIDTH)

    if is_score_bias:
        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='vertical',
            extend_min=False,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)

        # Show bias colour-bar ticks with one decimal place.
        tick_values = colour_bar_object.get_ticks()
        tick_strings = ['{0:.1f}'.format(v) for v in tick_values]

        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(tick_strings)
    else:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            min_value=min_colour_value,
            max_value=max_colour_value,
            orientation_string='vertical',
            extend_min=True,
            extend_max=True,
            font_size=DEFAULT_FONT_SIZE)

    colour_bar_object.set_label(colour_bar_label)
    print('Saving figure to: "{0:s}"...'.format(output_file_name))

    figure_object.savefig(output_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)
示例#16
0
def _plot_perm_test_one_panel(
        permutation_file_name, use_multipass_test, axes_object, label_string,
        title_string, x_label_string):
    """Reads permutation-test results from one file and plots them in one panel.

    :param permutation_file_name: Path to input file (read by
        `ml4rt_permutation.read_file`).
    :param use_multipass_test: Boolean flag.  If True, plots with
        `permutation_plotting.plot_multipass_test`; if False, plots with
        `permutation_plotting.plot_single_pass_test`.
    :param axes_object: Axes on which to plot (one element of the axes matrix
        returned by `plotting_utils.create_paneled_figure`).
    :param label_string: Panel label (e.g., "(a)").
    :param title_string: Panel title.
    :param x_label_string: x-axis label.
    """

    print('Reading data from: "{0:s}"...'.format(permutation_file_name))
    permutation_dict = ml4rt_permutation.read_file(permutation_file_name)
    permutation_dict = _results_to_gg_format(permutation_dict)

    if use_multipass_test:
        plotting_function = permutation_plotting.plot_multipass_test
    else:
        plotting_function = permutation_plotting.plot_single_pass_test

    plotting_function(
        permutation_dict=permutation_dict, axes_object=axes_object,
        plot_percent_increase=False, confidence_level=CONFIDENCE_LEVEL)

    plotting_utils.label_axes(
        axes_object=axes_object, label_string=label_string, font_size=30,
        x_coord_normalized=0.1, y_coord_normalized=1.01)
    axes_object.set_title(title_string)
    axes_object.set_xlabel(x_label_string)

    # Blank out the y-axis label (presumably one is set by the plotting
    # function above -- confirm against `permutation_plotting`).
    axes_object.set_ylabel('')


def _run(exp1_permutation_dir_name, exp2_permutation_dir_name,
         use_forward_test, use_multipass_test, output_file_name):
    """Creates figure showing permutation-test results for both models.

    This is effectively the main method.  The figure has 2 x 2 panels:
    row 0 = experiment 1, row 1 = experiment 2; column 0 = heating-rate-only
    results, column 1 = flux-only results.

    :param exp1_permutation_dir_name: See documentation at top of file.
    :param exp2_permutation_dir_name: Same.
    :param use_forward_test: Same.
    :param use_multipass_test: Same.
    :param output_file_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    test_type_string = 'forward' if use_forward_test else 'backwards'

    # Heating-rate cost is dual-weighted MSE, whose units are cubic in the
    # heating-rate units.  Flux cost is plain MSE, so its units are the
    # square of flux units (W m^-2), i.e. W^2 m^-4.
    heating_x_label_string = r'Dual-weighted MSE (K$^3$ day$^{-3}$)'
    flux_x_label_string = r'MSE (W$^2$ m$^{-4}$)'

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=2,
        num_columns=2,
        shared_x_axis=False,
        shared_y_axis=True,
        keep_aspect_ratio=False,
        horizontal_spacing=0.25,
        vertical_spacing=0.25)

    # Each entry: (input directory, file-name suffix, panel row, panel column,
    # panel label, panel title, x-axis label).
    panel_specs = [
        (exp1_permutation_dir_name, 'hr-only', 0, 0, '(a)',
         'Exp 1, heating rates only', heating_x_label_string),
        (exp1_permutation_dir_name, 'fluxes-only', 0, 1, '(b)',
         'Exp 1, fluxes only', flux_x_label_string),
        (exp2_permutation_dir_name, 'hr-only', 1, 0, '(c)',
         'Exp 2, heating rates only', heating_x_label_string),
        (exp2_permutation_dir_name, 'fluxes-only', 1, 1, '(d)',
         'Exp 2, fluxes only', flux_x_label_string)
    ]

    for (directory_name, file_suffix, row_index, column_index, label_string,
         title_string, x_label_string) in panel_specs:
        permutation_file_name = '{0:s}/{1:s}_perm_test_{2:s}.nc'.format(
            directory_name, test_type_string, file_suffix)

        _plot_perm_test_one_panel(
            permutation_file_name=permutation_file_name,
            use_multipass_test=use_multipass_test,
            axes_object=axes_object_matrix[row_index, column_index],
            label_string=label_string,
            title_string=title_string,
            x_label_string=x_label_string)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name,
                          dpi=FIGURE_RESOLUTION_DPI,
                          pad_inches=0,
                          bbox_inches='tight')
    pyplot.close(figure_object)