def _run(output_file_name):
    """Draws each height level of the Laplacian kernel and saves the figure.

    This is effectively the main method.

    :param output_file_name: See documentation at top of file.
    """

    panel_titles = ('Bottom height', 'Middle height', 'Top height')
    num_heights = KERNEL_MATRIX_3D.shape[-1]

    # One row of panels, one panel per height level of the kernel.
    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1, num_columns=num_heights,
        horizontal_spacing=0.1, vertical_spacing=0.1,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
    )

    for height_index in range(num_heights):
        _plot_kernel_one_height(
            kernel_matrix_2d=KERNEL_MATRIX_3D[..., height_index],
            axes_object=axes_object_matrix[0, height_index]
        )

    # Titles are fixed, so the kernel is expected to have (at least) three
    # height levels.
    for panel_index, this_title in enumerate(panel_titles):
        axes_object_matrix[0, panel_index].set_title(this_title)

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
def plot_many_1d_feature_maps(
        feature_matrix, colour_map_object, colour_norm_object=None,
        min_colour_value=None, max_colour_value=None,
        figure_width_inches=DEFAULT_FIG_WIDTH_INCHES,
        figure_height_inches=DEFAULT_FIG_HEIGHT_INCHES):
    """Draws several 1-D feature maps side by side (one panel per channel).

    N = number of points in spatial grid
    C = number of channels

    :param feature_matrix: N-by-C numpy array of feature values.
    :param colour_map_object: See doc for `plot_many_2d_feature_maps`.
    :param colour_norm_object: Same.
    :param min_colour_value: Same.
    :param max_colour_value: Same.
    :param figure_width_inches: Same.
    :param figure_height_inches: Same.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    # NOTE: this changes the global matplotlib rc setting for axes line width.
    pyplot.rc('axes', linewidth=1)
    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=2)

    num_channels = feature_matrix.shape[1]

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=1, num_columns=num_channels,
        figure_width_inches=figure_width_inches,
        figure_height_inches=figure_height_inches,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=False
    )

    for channel_index in range(num_channels):
        # Keep the channel as an N-by-1 matrix so that it can be drawn with
        # the 2-D plotting method.
        column_matrix = feature_matrix[:, channel_index:(channel_index + 1)]

        plot_2d_feature_map(
            feature_matrix=column_matrix,
            axes_object=axes_object_matrix[0, channel_index],
            font_size=30,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            annotation_string=''
        )

    return figure_object, axes_object_matrix
def _run(include_caption, output_dir_name):
    """Creates animated GIF that illustrates multivariate convolution.

    One frame (JPEG file) is saved for each position of the kernel, i.e., for
    each grid cell in the input feature map.

    This is effectively the main method.

    :param include_caption: See documentation at top of file.
    :param output_dir_name: Same.
    """

    def _add_panel_letter(axes_object, letter, x_coord_normalized):
        """Writes panel letter (e.g., "(a)") above one panel."""

        plotting_utils.label_axes(
            axes_object=axes_object,
            label_string='({0:s})'.format(letter),
            font_size=PANEL_LETTER_FONT_SIZE,
            y_coord_normalized=1.04,
            x_coord_normalized=x_coord_normalized
        )

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    conv_result_matrix = standalone_utils.do_2d_convolution(
        feature_matrix=INPUT_FEATURE_MATRIX, kernel_matrix=KERNEL_MATRIX,
        pad_edges=True, stride_length_px=1
    )
    # Keep only the first example and first output channel.
    conv_result_matrix = conv_result_matrix[0, ..., 0]

    input_matrix_2d = INPUT_FEATURE_MATRIX[..., 0]
    kernel_matrix_2d = KERNEL_MATRIX[..., 0, 0]

    num_grid_rows = INPUT_FEATURE_MATRIX.shape[0]
    num_grid_columns = INPUT_FEATURE_MATRIX.shape[1]

    # Size of kernel relative to size of feature map (used to shrink the
    # kernel panel accordingly).
    kernel_width_ratio = float(KERNEL_MATRIX.shape[1]) / num_grid_columns
    kernel_height_ratio = float(KERNEL_MATRIX.shape[0]) / num_grid_rows

    frame_file_names = []

    for row in range(num_grid_rows):
        for column in range(num_grid_columns):
            figure_object, axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=NUM_PANEL_ROWS, num_columns=NUM_PANEL_COLUMNS,
                    horizontal_spacing=0.2, vertical_spacing=0.,
                    shared_x_axis=False, shared_y_axis=False,
                    keep_aspect_ratio=True
                )
            )

            input_axes_object = axes_object_matrix[0, 0]
            kernel_axes_object = axes_object_matrix[0, 1]
            output_axes_object = axes_object_matrix[0, 2]

            # Left panel: input feature map.
            _plot_feature_map(
                feature_matrix_2d=input_matrix_2d,
                kernel_row=row, kernel_column=column,
                is_output_map=False, axes_object=input_axes_object
            )
            _add_panel_letter(input_axes_object, 'a', 0.1)

            # Right panel: output feature map.
            _plot_feature_map(
                feature_matrix_2d=conv_result_matrix,
                kernel_row=row, kernel_column=column,
                is_output_map=True, axes_object=output_axes_object
            )

            # Middle panel: shrink and shift it so that the kernel is drawn
            # at the same scale as the feature maps.
            kernel_bbox_object = kernel_axes_object.get_position()
            panel_width = kernel_width_ratio * (
                kernel_bbox_object.x1 - kernel_bbox_object.x0
            )
            panel_height = kernel_height_ratio * (
                kernel_bbox_object.y1 - kernel_bbox_object.y0
            )

            kernel_bbox_object.x0 += 0.5 * panel_width
            kernel_bbox_object.y0 = (
                input_axes_object.get_position().y0 + 0.1
            )
            kernel_bbox_object.x1 = kernel_bbox_object.x0 + panel_width
            kernel_bbox_object.y1 = kernel_bbox_object.y0 + panel_height
            kernel_axes_object.set_position(kernel_bbox_object)

            _plot_kernel(
                kernel_matrix_2d=kernel_matrix_2d,
                feature_matrix_2d=input_matrix_2d,
                feature_row_at_center=row,
                feature_column_at_center=column,
                axes_object=kernel_axes_object
            )
            _add_panel_letter(kernel_axes_object, 'b', 0.2)

            _plot_feature_to_kernel_lines(
                kernel_matrix_2d=kernel_matrix_2d,
                feature_matrix_2d=input_matrix_2d,
                feature_row_at_center=row,
                feature_column_at_center=column,
                kernel_axes_object=kernel_axes_object,
                feature_axes_object=input_axes_object
            )
            _add_panel_letter(output_axes_object, 'c', 0.1)

            if include_caption:
                figure_object.text(
                    0.5, 0.35, FIGURE_CAPTION, fontsize=DEFAULT_FONT_SIZE,
                    color='k', horizontalalignment='center',
                    verticalalignment='top'
                )

            this_file_name = (
                '{0:s}/conv_animation_row{1:d}_column{2:d}.jpg'.format(
                    output_dir_name, row, column
                )
            )
            frame_file_names.append(this_file_name)

            print('Saving figure to: "{0:s}"...'.format(this_file_name))
            figure_object.savefig(
                this_file_name, dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight'
            )
            pyplot.close(figure_object)

    animation_file_name = '{0:s}/conv_animation.gif'.format(output_dir_name)
    print('Creating animation: "{0:s}"...'.format(animation_file_name))

    imagemagick_utils.create_gif(
        input_file_names=frame_file_names,
        output_file_name=animation_file_name,
        num_seconds_per_frame=0.5, resize_factor=0.5
    )
def plot_3d_grid_without_coords(
        field_matrix, field_name, grid_point_heights_metres, ground_relative,
        num_panel_rows=None, figure_object=None, axes_object_matrix=None,
        font_size=DEFAULT_FONT_SIZE, colour_map_object=None,
        colour_norm_object=None):
    """Plots 3-D grid as a set of 2-D colour maps (one panel per height).

    M = number of grid rows
    N = number of grid columns
    H = number of grid heights

    To use the default colour scheme for the given radar field, leave
    `colour_map_object` and `colour_norm_object` empty.

    If `figure_object is None`, a new paneled figure is created and
    `num_panel_rows` is required.  Otherwise, `axes_object_matrix` is required
    and the panel layout is taken from its shape.

    :param field_matrix: M-by-N-by-H numpy array with values of radar field.
    :param field_name: Name of radar field (must be accepted by
        `radar_utils.check_field_name`).
    :param grid_point_heights_metres: length-H numpy array of non-negative
        heights (will be rounded to the nearest integer).
    :param ground_relative: Boolean flag.  If True, heights in
        `grid_point_heights_metres` are labeled as ground-relative ("AGL").
        If False, as sea-level-relative ("ASL").
    :param num_panel_rows: Number of rows in paneled figure (different than M,
        the number of grid rows).
    :param figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :param axes_object_matrix: Same.
    :param font_size: Font size, passed to `plot_2d_grid_without_coords`.
    :param colour_map_object: See doc for `plot_latlng_grid`.
    :param colour_norm_object: Same.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    error_checking.assert_is_numpy_array(field_matrix, num_dimensions=3)
    error_checking.assert_is_geq_numpy_array(grid_point_heights_metres, 0)
    grid_point_heights_metres = (
        numpy.round(grid_point_heights_metres).astype(int)
    )

    num_heights = field_matrix.shape[2]
    error_checking.assert_is_numpy_array(
        grid_point_heights_metres,
        exact_dimensions=numpy.array([num_heights], dtype=int)
    )
    error_checking.assert_is_boolean(ground_relative)

    if figure_object is not None:
        # Re-use existing figure; layout comes from the axes matrix.
        error_checking.assert_is_numpy_array(
            axes_object_matrix, num_dimensions=2
        )
        num_panel_rows, num_panel_columns = axes_object_matrix.shape
    else:
        error_checking.assert_is_integer(num_panel_rows)
        error_checking.assert_is_geq(num_panel_rows, 1)
        error_checking.assert_is_leq(num_panel_rows, num_heights)

        num_panel_columns = int(numpy.ceil(
            float(num_heights) / num_panel_rows
        ))

        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=num_panel_rows, num_columns=num_panel_columns,
                shared_x_axis=False, shared_y_axis=False,
                keep_aspect_ratio=True
            )
        )

    height_suffix = ' AGL' if ground_relative else ' ASL'

    # Panels are filled in row-major order.
    for panel_index in range(num_panel_rows * num_panel_columns):
        panel_row, panel_column = divmod(panel_index, num_panel_columns)
        this_axes_object = axes_object_matrix[panel_row, panel_column]

        # Panels left over after the last height are hidden.
        if panel_index >= num_heights:
            this_axes_object.axis('off')
            continue

        height_string = '{0:.1f} km'.format(
            grid_point_heights_metres[panel_index] * METRES_TO_KM
        )

        plot_2d_grid_without_coords(
            field_matrix=field_matrix[..., panel_index],
            field_name=field_name,
            axes_object=this_axes_object,
            annotation_string=height_string + height_suffix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            font_size=font_size
        )

    return figure_object, axes_object_matrix
def plot_many_2d_grids_without_coords(
        field_matrix, field_name_by_panel, num_panel_rows=None,
        figure_object=None, axes_object_matrix=None, panel_names=None,
        colour_map_object_by_panel=None, colour_norm_object_by_panel=None,
        plot_colour_bar_by_panel=None, font_size=DEFAULT_FONT_SIZE,
        row_major=True):
    """Plots 2-D colour map in each panel (one per field/height pair).

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    P = number of panels (field/height pairs)

    If `figure_object is None`, a new paneled figure is created and
    `num_panel_rows` is required.  Otherwise, `axes_object_matrix` is required
    and the panel layout is taken from its shape.

    :param field_matrix: M-by-N-by-P numpy array of radar values.
    :param field_name_by_panel: length-P list of field names.
    :param num_panel_rows: Number of rows in paneled figure (different than M,
        which is number of rows in spatial grid).
    :param figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :param axes_object_matrix: See above.
    :param panel_names: length-P list of panel names.  panel_names[k] is used
        as the label of the colour bar under the [k]th panel (with newlines
        replaced by "; ").  Panels without a colour bar, or whose name is
        None, get no label.  If you do not want panel names, make this None.
    :param colour_map_object_by_panel: length-P list of `matplotlib.pyplot.cm`
        objects.  If this is None, the default will be used for each field.
    :param colour_norm_object_by_panel: length-P list of
        `matplotlib.colors.BoundaryNorm` objects.  If this is None, the default
        will be used for each field.  NOTE: if either this or
        `colour_map_object_by_panel` is None, *both* are ignored and defaults
        are used.
    :param plot_colour_bar_by_panel: length-P numpy array of Boolean flags.  If
        plot_colour_bar_by_panel[k] = True, horizontal colour bar will be
        plotted under [k]th panel.  If you want to plot colour bar for every
        panel, leave this as None.
    :param font_size: Font size.
    :param row_major: Boolean flag.  If True, panels will be filled along rows
        first, then down columns.  If False, down columns first, then along
        rows.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    :raises: ValueError: if `colour_map_object_by_panel` or
        `colour_norm_object_by_panel` has different length than number of
        panels.
    """

    error_checking.assert_is_boolean(row_major)
    error_checking.assert_is_numpy_array(field_matrix, num_dimensions=3)
    num_panels = field_matrix.shape[2]

    if panel_names is None:
        panel_names = [None] * num_panels
    if plot_colour_bar_by_panel is None:
        plot_colour_bar_by_panel = numpy.full(num_panels, True, dtype=bool)

    these_expected_dim = numpy.array([num_panels], dtype=int)

    error_checking.assert_is_numpy_array(
        numpy.array(panel_names), exact_dimensions=these_expected_dim
    )
    error_checking.assert_is_numpy_array(
        numpy.array(field_name_by_panel), exact_dimensions=these_expected_dim
    )
    error_checking.assert_is_boolean_numpy_array(plot_colour_bar_by_panel)
    error_checking.assert_is_numpy_array(
        plot_colour_bar_by_panel, exact_dimensions=these_expected_dim
    )

    # A colour map is useless without its normalizer (and vice-versa), so if
    # either is missing, fall back to defaults for both.
    if (colour_map_object_by_panel is None
            or colour_norm_object_by_panel is None):
        colour_map_object_by_panel = [None] * num_panels
        colour_norm_object_by_panel = [None] * num_panels

    error_checking.assert_is_list(colour_map_object_by_panel)
    error_checking.assert_is_list(colour_norm_object_by_panel)

    if len(colour_map_object_by_panel) != num_panels:
        error_string = (
            'Number of colour maps ({0:d}) should equal number of panels '
            '({1:d}).'
        ).format(len(colour_map_object_by_panel), num_panels)

        raise ValueError(error_string)

    if len(colour_norm_object_by_panel) != num_panels:
        error_string = (
            'Number of colour-normalizers ({0:d}) should equal number of panels'
            ' ({1:d}).'
        ).format(len(colour_norm_object_by_panel), num_panels)

        raise ValueError(error_string)

    if figure_object is None:
        error_checking.assert_is_integer(num_panel_rows)
        error_checking.assert_is_geq(num_panel_rows, 1)
        error_checking.assert_is_leq(num_panel_rows, num_panels)

        num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=num_panel_rows, num_columns=num_panel_columns,
                shared_x_axis=False, shared_y_axis=False,
                keep_aspect_ratio=True
            )
        )
    else:
        error_checking.assert_is_numpy_array(
            axes_object_matrix, num_dimensions=2
        )

        num_panel_rows = axes_object_matrix.shape[0]
        num_panel_columns = axes_object_matrix.shape[1]

    # 'C' = fill along rows first; 'F' = fill down columns first.
    order_string = 'C' if row_major else 'F'
    panel_grid_shape = (num_panel_rows, num_panel_columns)

    for k in range(num_panels):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, panel_grid_shape, order=order_string
        )
        this_axes_object = axes_object_matrix[
            this_panel_row, this_panel_column
        ]

        # The panel name is not drawn inside the panel; it is used only as
        # the colour-bar label below.
        this_colour_map_object, this_colour_norm_object = (
            plot_2d_grid_without_coords(
                field_matrix=field_matrix[..., k],
                field_name=field_name_by_panel[k],
                axes_object=this_axes_object,
                annotation_string=None, font_size=font_size,
                colour_map_object=colour_map_object_by_panel[k],
                colour_norm_object=colour_norm_object_by_panel[k]
            )
        )

        if not plot_colour_bar_by_panel[k]:
            continue

        this_extend_min_flag = field_name_by_panel[k] in SHEAR_VORT_DIV_NAMES

        this_colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=this_axes_object,
            data_matrix=field_matrix[..., k],
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=this_extend_min_flag, extend_max=True,
            fraction_of_axis_length=0.75, font_size=font_size
        )

        # Panel names are optional (None by default), so label the colour bar
        # only when a name was actually given.  Without this guard, calling
        # the method with default arguments crashes on `None.replace`.
        if panel_names[k] is None:
            continue

        this_colour_bar_object.set_label(
            panel_names[k].replace('\n', '; '),
            fontsize=font_size, fontweight='bold'
        )

    # Hide panels left over after the last field/height pair.
    for k in range(num_panels, num_panel_rows * num_panel_columns):
        this_panel_row, this_panel_column = numpy.unravel_index(
            k, panel_grid_shape, order=order_string
        )
        axes_object_matrix[this_panel_row, this_panel_column].axis('off')

    return figure_object, axes_object_matrix
def _run(input_file_name, num_predictors_to_plot, confidence_level,
         output_dir_name):
    """Plots results of permutation-based importance test.

    Two figures are saved, each with a single-pass panel and a multi-pass
    panel: one drawn with `plot_percent_increase=False` and one with
    `plot_percent_increase=True`.

    This is effectively the main method.

    :param input_file_name: See documentation at top of file.
    :param num_predictors_to_plot: Same.  Non-positive values are converted to
        None before being passed to the plotting methods.
    :param confidence_level: Same.
    :param output_dir_name: Same.
    """

    if num_predictors_to_plot <= 0:
        num_predictors_to_plot = None

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    print('Reading data from: "{0:s}"...'.format(input_file_name))
    permutation_dict = ml4rt_permutation.read_file(input_file_name)
    permutation_dict = _results_to_gg_format(permutation_dict)

    # Each tuple = (plot_percent_increase, x-axis label, file-name template).
    figure_specs = (
        (False, 'Mean squared error',
         '{0:s}/permutation_test_abs-values.jpg'),
        (True, 'MSE (fraction of original)',
         '{0:s}/permutation_test_percentage.jpg')
    )

    for use_percent, x_label_string, file_name_template in figure_specs:
        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=1, num_columns=2, shared_x_axis=False,
                shared_y_axis=True, keep_aspect_ratio=False,
                horizontal_spacing=0.1, vertical_spacing=0.05
            )
        )

        single_pass_axes_object = axes_object_matrix[0, 0]
        multipass_axes_object = axes_object_matrix[0, 1]

        permutation_plotting.plot_single_pass_test(
            permutation_dict=permutation_dict,
            axes_object=single_pass_axes_object,
            num_predictors_to_plot=num_predictors_to_plot,
            plot_percent_increase=use_percent,
            confidence_level=confidence_level,
            bar_face_colour=BAR_FACE_COLOUR
        )
        single_pass_axes_object.set_title('Single-pass test')
        single_pass_axes_object.set_xlabel(x_label_string)

        permutation_plotting.plot_multipass_test(
            permutation_dict=permutation_dict,
            axes_object=multipass_axes_object,
            num_predictors_to_plot=num_predictors_to_plot,
            plot_percent_increase=use_percent,
            confidence_level=confidence_level,
            bar_face_colour=BAR_FACE_COLOUR
        )
        multipass_axes_object.set_title('Multi-pass test')
        multipass_axes_object.set_xlabel(x_label_string)
        # The panels are created with a shared y-axis, so the second label is
        # blanked out.
        multipass_axes_object.set_ylabel('')

        figure_file_name = file_name_template.format(output_dir_name)
        print('Saving figure to: "{0:s}"...'.format(figure_file_name))

        figure_object.savefig(
            figure_file_name, dpi=FIGURE_RESOLUTION_DPI,
            pad_inches=0, bbox_inches='tight'
        )
        pyplot.close(figure_object)
def _run(input_file_name, num_predictors_to_plot, output_dir_name):
    """Plots results of permutation test.

    Two figures are saved, each with a single-pass panel and a multi-pass
    panel: one drawn with `plot_percent_increase=False` and one with
    `plot_percent_increase=True`.

    This is effectively the main method.

    :param input_file_name: See documentation at top of file.
    :param num_predictors_to_plot: Same.  Non-positive values are converted to
        None before being passed to the plotting methods.
    :param output_dir_name: Same.  If this is '' or 'None', figures are saved
        in the same directory as the input file.
    """

    if num_predictors_to_plot <= 0:
        num_predictors_to_plot = None

    if output_dir_name in ['', 'None']:
        # `os.path.split` returns '' as the directory for a bare file name.
        # Fall back to '.', because otherwise the paths built below would
        # start with '/' and point at the filesystem root.
        output_dir_name = os.path.split(input_file_name)[0] or '.'

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    print(
        'Reading permutation results from: "{0:s}"...'.format(input_file_name)
    )
    permutation_dict = permutation.read_results(input_file_name)

    # Each tuple = (plot_percent_increase, x-axis label or None to keep the
    # label set by the plotting method, file-name template).
    figure_specs = (
        (False, 'AUC', '{0:s}/permutation_absolute-values.jpg'),
        (True, None, '{0:s}/permutation_percentage.jpg')
    )

    for use_percent, x_label_string, file_name_template in figure_specs:
        _, axes_object_matrix = plotting_utils.create_paneled_figure(
            num_rows=1, num_columns=2, shared_x_axis=False,
            shared_y_axis=True, keep_aspect_ratio=False
        )

        single_pass_axes_object = axes_object_matrix[0, 0]
        multipass_axes_object = axes_object_matrix[0, 1]

        permutation_plotting.plot_breiman_results(
            permutation_dict=permutation_dict,
            axes_object=single_pass_axes_object,
            plot_percent_increase=use_percent,
            num_predictors_to_plot=num_predictors_to_plot
        )
        if x_label_string is not None:
            single_pass_axes_object.set_xlabel(x_label_string)
        single_pass_axes_object.set_title('Single-pass')

        permutation_plotting.plot_lakshmanan_results(
            permutation_dict=permutation_dict,
            axes_object=multipass_axes_object,
            plot_percent_increase=use_percent,
            num_steps_to_plot=num_predictors_to_plot
        )
        if x_label_string is not None:
            multipass_axes_object.set_xlabel(x_label_string)
        # The panels are created with a shared y-axis, so the second label is
        # blanked out.
        multipass_axes_object.set_ylabel('')
        multipass_axes_object.set_title('Multi-pass')

        # The `pyplot` calls below act on the current figure -- presumably
        # the one just created by `create_paneled_figure`.
        pyplot.tight_layout()

        figure_file_name = file_name_template.format(output_dir_name)
        print('Saving figure to file: "{0:s}"...'.format(figure_file_name))

        pyplot.savefig(figure_file_name, dpi=FIGURE_RESOLUTION_DPI)
        pyplot.close()
def plot_many_2d_feature_maps(
        feature_matrix, annotation_string_by_panel, num_panel_rows,
        colour_map_object, colour_norm_object=None, min_colour_value=None,
        max_colour_value=None, figure_width_inches=DEFAULT_FIG_WIDTH_INCHES,
        figure_height_inches=DEFAULT_FIG_HEIGHT_INCHES,
        font_size=DEFAULT_FONT_SIZE):
    """Draws several 2-D feature maps in one paneled figure.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    P = number of panels

    :param feature_matrix: M-by-N-by-P numpy array of feature values (either
        before or after activation function -- this method doesn't care).
    :param annotation_string_by_panel: length-P list of annotations.
        annotation_string_by_panel[k] is passed to `plot_2d_feature_map` as
        the annotation for the [k]th panel.
    :param num_panel_rows: Number of panel rows (must be in range 1...P).
    :param colour_map_object: See doc for `plot_2d_feature_map`.
    :param colour_norm_object: Same.
    :param min_colour_value: Same.
    :param max_colour_value: Same.
    :param figure_width_inches: Figure width.
    :param figure_height_inches: Figure height.
    :param font_size: Font size, passed to `plot_2d_feature_map`.
    :return: figure_object: See doc for `plotting_utils.create_paneled_figure`.
    :return: axes_object_matrix: Same.
    """

    # NOTE: this changes the global matplotlib rc setting for axes line width.
    pyplot.rc('axes', linewidth=3)

    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=3)
    num_panels = feature_matrix.shape[-1]

    error_checking.assert_is_numpy_array(
        numpy.array(annotation_string_by_panel),
        exact_dimensions=numpy.array([num_panels])
    )

    error_checking.assert_is_integer(num_panel_rows)
    error_checking.assert_is_geq(num_panel_rows, 1)
    error_checking.assert_is_leq(num_panel_rows, num_panels)

    num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_panel_rows, num_columns=num_panel_columns,
        figure_width_inches=figure_width_inches,
        figure_height_inches=figure_height_inches,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=False
    )

    # Panels are filled in row-major order.
    for panel_index in range(num_panel_rows * num_panel_columns):
        panel_row, panel_column = divmod(panel_index, num_panel_columns)
        this_axes_object = axes_object_matrix[panel_row, panel_column]

        # Panels left over after the last feature map are hidden.
        if panel_index >= num_panels:
            this_axes_object.axis('off')
            continue

        plot_2d_feature_map(
            feature_matrix=feature_matrix[..., panel_index],
            axes_object=this_axes_object,
            font_size=font_size,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value,
            annotation_string=annotation_string_by_panel[panel_index]
        )

    return figure_object, axes_object_matrix
def _run(forward_test_file_name, backwards_test_file_name, num_predictors,
         confidence_level, output_file_name):
    """Makes 2-by-2 figure with results of all 4 permutation tests.

    Top row = forward tests; bottom row = backwards tests.
    Left column = single-pass; right column = multi-pass.

    This is effectively the main method.

    :param forward_test_file_name: See documentation at top of file.
    :param backwards_test_file_name: Same.
    :param num_predictors: Same.  Non-positive values are converted to None
        before being passed to the plotting methods.
    :param confidence_level: Same.
    :param output_file_name: Same.
    """

    if num_predictors <= 0:
        num_predictors = None

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    print('Reading data from: "{0:s}"...'.format(forward_test_file_name))
    forward_test_dict = permutation_utils.read_results(forward_test_file_name)

    print('Reading data from: "{0:s}"...'.format(backwards_test_file_name))
    backwards_test_dict = permutation_utils.read_results(
        backwards_test_file_name
    )

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=2, num_columns=2, shared_x_axis=False, shared_y_axis=True,
        keep_aspect_ratio=False, horizontal_spacing=0.1,
        vertical_spacing=0.05
    )

    # Each tuple = (panel row, panel column, results, plotting method, title,
    # panel label).
    panel_specs = (
        (0, 0, forward_test_dict,
         permutation_plotting.plot_single_pass_test,
         'Forward single-pass test', '(a)'),
        (0, 1, forward_test_dict,
         permutation_plotting.plot_multipass_test,
         'Forward multi-pass test', '(b)'),
        (1, 0, backwards_test_dict,
         permutation_plotting.plot_single_pass_test,
         'Backward single-pass test', '(c)'),
        (1, 1, backwards_test_dict,
         permutation_plotting.plot_multipass_test,
         'Backward multi-pass test', '(d)')
    )

    for (panel_row, panel_column, this_test_dict, plotting_function,
         title_string, label_string) in panel_specs:

        this_axes_object = axes_object_matrix[panel_row, panel_column]

        plotting_function(
            permutation_dict=this_test_dict, axes_object=this_axes_object,
            plot_percent_increase=False, confidence_level=confidence_level,
            num_predictors_to_plot=num_predictors
        )
        this_axes_object.set_title(title_string)

        if panel_row == 0:
            # Top row: x-axis ticks and label are removed (only the bottom
            # row carries them).
            this_axes_object.set_xticks([])
            this_axes_object.set_xlabel('')
        else:
            this_axes_object.set_xlabel('Area under ROC curve (AUC)')

        if panel_column == 1:
            # The panels are created with a shared y-axis, so the label in
            # the right column is blanked out.
            this_axes_object.set_ylabel('')

        # Panel labels go just left of the left column and just right of the
        # right column.
        plotting_utils.label_axes(
            axes_object=this_axes_object, label_string=label_string,
            x_coord_normalized=-0.01 if panel_column == 0 else 1.15,
            y_coord_normalized=0.925
        )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _run(include_caption, output_dir_name):
    """Creates animated GIF with one frame per grid cell of the input map.

    Each frame has one row of panels per channel: input feature map in the
    left column, output feature map in the right column, with lines drawn
    between the two.  Output files are named "upsampling_animation*".

    This is effectively the main method.

    :param include_caption: See documentation at top of file.
    :param output_dir_name: Same.
    """

    def _add_panel_letter(axes_object, letter):
        """Writes panel letter (e.g., "(a)") next to one panel."""

        plotting_utils.label_axes(
            axes_object=axes_object,
            label_string='({0:s})'.format(letter),
            font_size=PANEL_LETTER_FONT_SIZE,
            y_coord_normalized=0.85, x_coord_normalized=-0.02
        )

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name
    )

    num_input_rows = INPUT_FEATURE_MATRIX.shape[0]
    num_input_columns = INPUT_FEATURE_MATRIX.shape[1]
    num_channels = INPUT_FEATURE_MATRIX.shape[2]

    first_letter_code = ord('a')
    frame_file_names = []

    for row in range(num_input_rows):
        for column in range(num_input_columns):
            figure_object, axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=num_channels, num_columns=NUM_PANEL_COLUMNS,
                    horizontal_spacing=HORIZ_PANEL_SPACING,
                    vertical_spacing=VERTICAL_PANEL_SPACING,
                    shared_x_axis=False, shared_y_axis=False,
                    keep_aspect_ratio=True
                )
            )

            # Left column: input feature maps, lettered "a", "b", ...
            for channel_index in range(num_channels):
                input_axes_object = axes_object_matrix[channel_index, 0]

                _plot_feature_map(
                    feature_matrix_2d=INPUT_FEATURE_MATRIX[..., channel_index],
                    pooled_row=row, pooled_column=column, pooled=True,
                    axes_object=input_axes_object
                )
                _add_panel_letter(
                    input_axes_object,
                    chr(first_letter_code + channel_index)
                )

            # Right column: output feature maps, with letters continuing
            # where the left column stopped.
            for channel_index in range(num_channels):
                output_axes_object = axes_object_matrix[channel_index, 1]

                _plot_feature_map(
                    feature_matrix_2d=(
                        OUTPUT_FEATURE_MATRIX[..., channel_index]
                    ),
                    pooled_row=row, pooled_column=column, pooled=False,
                    axes_object=output_axes_object
                )
                _add_panel_letter(
                    output_axes_object,
                    chr(first_letter_code + num_channels + channel_index)
                )

                _plot_interpanel_lines(
                    pooled_row=row, pooled_column=column,
                    input_fm_axes_object=axes_object_matrix[channel_index, 0],
                    output_fm_axes_object=output_axes_object
                )

            if include_caption:
                figure_object.text(
                    0.5, CAPTION_Y_COORD, FIGURE_CAPTION,
                    fontsize=DEFAULT_FONT_SIZE, color='k',
                    horizontalalignment='center', verticalalignment='top'
                )

            this_file_name = (
                '{0:s}/upsampling_animation_row{1:d}_column{2:d}.jpg'.format(
                    output_dir_name, row, column
                )
            )
            frame_file_names.append(this_file_name)

            print('Saving figure to: "{0:s}"...'.format(this_file_name))
            figure_object.savefig(
                this_file_name, dpi=FIGURE_RESOLUTION_DPI,
                pad_inches=0, bbox_inches='tight'
            )
            pyplot.close(figure_object)

    animation_file_name = '{0:s}/upsampling_animation.gif'.format(
        output_dir_name
    )
    print('Creating animation: "{0:s}"...'.format(animation_file_name))

    imagemagick_utils.create_gif(
        input_file_names=frame_file_names,
        output_file_name=animation_file_name,
        num_seconds_per_frame=0.5, resize_factor=0.5
    )
def _plot_one_example(
        input_feature_matrix, feature_matrix_after_conv,
        feature_matrix_after_activn, feature_matrix_after_bn,
        feature_matrix_after_pooling, output_file_name):
    """Plots entire figure for one example (storm object).

    The figure has one panel row per output channel and one panel column per
    processing stage: input (column 0), convolution (1), activation (2), batch
    normalization (3) and pooling (4).  Only channel 0 of the input is drawn;
    the remaining panels in column 0 are switched off.  Convolution and
    activation share one colour scale; every other stage gets its own.

    NOTE(review): each matrix below is indexed as `matrix[..., k]` and every
    slice is passed on as a 2-D feature map, so the arrays presumably carry a
    trailing channel axis (rows x columns x channels) -- confirm with caller.

    :param input_feature_matrix: numpy array with input features.
    :param feature_matrix_after_conv: numpy array with features after
        convolution.  The length of its last axis sets the number of panel
        rows.
    :param feature_matrix_after_activn: numpy array with features after
        activation.
    :param feature_matrix_after_bn: numpy array with features after batch
        normalization.
    :param feature_matrix_after_pooling: numpy array with features after
        pooling.
    :param output_file_name: Path to output file.  Figure will be saved here.
    """

    num_output_channels = feature_matrix_after_conv.shape[-1]
    last_row_index = num_output_channels - 1

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_output_channels, num_columns=NUM_PANEL_COLUMNS,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
    )

    # Keyword arguments shared by every panel-letter annotation.
    label_kwargs = {
        'font_size': TITLE_FONT_SIZE,
        'x_coord_normalized': 0.125,
        'y_coord_normalized': 1.025
    }

    # ---- Column 0: input ----
    max_colour_value = numpy.percentile(
        numpy.absolute(input_feature_matrix), MAX_COLOUR_PERCENTILE
    )
    axes_object_matrix[0, 0].set_title('Input', fontsize=TITLE_FONT_SIZE)

    for k in range(num_output_channels):
        if k > 0:
            # Only the first input channel is shown; blank the other panels.
            axes_object_matrix[k, 0].axis('off')
            continue

        _plot_one_feature_map(
            feature_matrix_2d=input_feature_matrix[..., k],
            max_colour_value=max_colour_value, plot_colour_bar=False,
            axes_object=axes_object_matrix[k, 0]
        )

    # The input colour bar hangs off the bottom panel of column 0.
    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix[last_row_index, 0],
        data_matrix=input_feature_matrix[..., 0],
        colour_map_object=COLOUR_MAP_OBJECT,
        min_value=-1 * max_colour_value, max_value=max_colour_value,
        orientation_string='horizontal', padding=0.015,
        fraction_of_axis_length=0.9,
        extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE
    )

    tick_values = colour_bar_object.ax.get_xticks()
    tick_label_strings = ['{0:.1f}'.format(x) for x in tick_values]
    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_label_strings)

    # Panel letters start at 'a' and advance by one for every later panel.
    letter_code = ord('a')
    plotting_utils.label_axes(
        axes_object=axes_object_matrix[0, 0],
        label_string='({0:s})'.format(chr(letter_code)),
        **label_kwargs
    )

    # ---- Columns 1-4: one column per processing stage ----
    # Convolution and activation use a colour limit computed from both
    # matrices together, so that the two columns are directly comparable.
    conv_and_activn_matrix = numpy.stack(
        (feature_matrix_after_conv, feature_matrix_after_activn), axis=0
    )

    # Each entry: (column index, title, matrix to plot, matrix that sets the
    # colour limit).  Title strings (including leading spaces) are verbatim.
    stage_specs = [
        (1, ' After convolution', feature_matrix_after_conv,
         conv_and_activn_matrix),
        (2, ' After activation', feature_matrix_after_activn,
         conv_and_activn_matrix),
        (3, ' After batch norm', feature_matrix_after_bn,
         feature_matrix_after_bn),
        (4, 'After pooling', feature_matrix_after_pooling,
         feature_matrix_after_pooling)
    ]

    for column_index, title_string, stage_matrix, limit_matrix in stage_specs:
        max_colour_value = numpy.percentile(
            numpy.absolute(limit_matrix), MAX_COLOUR_PERCENTILE
        )
        axes_object_matrix[0, column_index].set_title(
            title_string, fontsize=TITLE_FONT_SIZE)

        for k in range(num_output_channels):
            this_axes_object = axes_object_matrix[k, column_index]

            # Only the bottom panel of each column carries a colour bar.
            _plot_one_feature_map(
                feature_matrix_2d=stage_matrix[..., k],
                max_colour_value=max_colour_value,
                plot_colour_bar=(k == last_row_index),
                axes_object=this_axes_object
            )

            letter_code += 1
            plotting_utils.label_axes(
                axes_object=this_axes_object,
                label_string='({0:s})'.format(chr(letter_code)),
                **label_kwargs
            )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_2d_radar_scan(list_of_predictor_matrices, model_metadata_dict,
                        allow_whitespace, title_string=None):
    """Plots 2-D radar scan for one example.

    J = number of panel rows in image
    K = number of panel columns in image

    Side effect: when the model metadata contains layer operations, element 0
    of `list_of_predictor_matrices` is replaced in place by the subset of
    channels listed in `LAYER_OP_INDICES_TO_KEEP`.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-1 list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-1 list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]

    if layer_operation_dicts is not None:
        # Keep only the selected layer operations and the matching channels.
        layer_operation_dicts = [
            layer_operation_dicts[k] for k in LAYER_OP_INDICES_TO_KEEP
        ]
        list_of_predictor_matrices[0] = list_of_predictor_matrices[0][
            ..., LAYER_OP_INDICES_TO_KEEP
        ]

        field_name_by_panel, panel_names = (
            radar_plotting.layer_ops_to_field_and_panel_names(
                list_of_layer_operation_dicts=layer_operation_dicts
            )
        )
    else:
        field_name_by_panel = training_option_dict[
            trainval_io.RADAR_FIELDS_KEY
        ]
        panel_names = radar_plotting.radar_fields_and_heights_to_panel_names(
            field_names=field_name_by_panel,
            heights_m_agl=training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]
        )

    num_panels = len(field_name_by_panel)

    # Every panel gets its own colour bar, with or without whitespace.
    plot_cbar_by_panel = numpy.full(num_panels, True, dtype=bool)

    # Near-square panel layout: floor(sqrt(n)) rows, enough columns to fit.
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_panels)))
    num_panel_columns = int(numpy.ceil(float(num_panels) / num_panel_rows))

    # With whitespace allowed, no figure is pre-built here; None is passed
    # through to the plotting routine, whose return value is used instead.
    figure_object = None
    axes_object_matrix = None

    if not allow_whitespace:
        figure_object, axes_object_matrix = (
            plotting_utils.create_paneled_figure(
                num_rows=num_panel_rows, num_columns=num_panel_columns,
                horizontal_spacing=0., vertical_spacing=0.,
                shared_x_axis=False, shared_y_axis=False,
                keep_aspect_ratio=True
            )
        )

    figure_object, axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(list_of_predictor_matrices[0], axis=0),
            field_name_by_panel=field_name_by_panel,
            panel_names=panel_names, num_panel_rows=num_panel_rows,
            figure_object=figure_object,
            axes_object_matrix=axes_object_matrix,
            plot_colour_bar_by_panel=plot_cbar_by_panel,
            font_size=FONT_SIZE_WITH_COLOUR_BARS, row_major=False
        )
    )

    if title_string is not None and allow_whitespace:
        pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return [figure_object], [axes_object_matrix]
def _plot_2d3d_radar_scan(list_of_predictor_matrices, model_metadata_dict,
                          allow_whitespace, title_string=None):
    """Plots 3-D reflectivity and 2-D azimuthal shear for one example.

    Reflectivity is read from element 0 of `list_of_predictor_matrices`
    (channel 0 of its last axis) and azimuthal shear from element 1.  Colour
    bars and titles are added only when `allow_whitespace` is True.

    :param list_of_predictor_matrices: See doc for `_plot_3d_radar_scan`.
    :param model_metadata_dict: Same.
    :param allow_whitespace: Same.
    :param title_string: Same.
    :return: figure_objects: length-2 list of figure handles (instances of
        `matplotlib.figure.Figure`).  The first is for reflectivity; the second
        is for azimuthal shear.
    :return: axes_object_matrices: length-2 list (the first is for
        reflectivity; the second is for azimuthal shear).  Each element is a
        2-D numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    def _create_figure_if_needed(num_rows, num_columns):
        """Returns (None, None) if whitespace is allowed, else a new figure."""
        if allow_whitespace:
            return None, None

        return plotting_utils.create_paneled_figure(
            num_rows=num_rows, num_columns=num_columns,
            horizontal_spacing=0., vertical_spacing=0.,
            shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
        )

    def _add_colour_bar(axes_object_matrix, data_matrix, field_name):
        """Adds horizontal colour bar with default scheme for one field."""
        colour_map_object, colour_norm_object = (
            radar_plotting.get_default_colour_scheme(field_name)
        )

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=data_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='horizontal',
            extend_min=True, extend_max=True
        )

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    az_shear_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    refl_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    num_az_shear_fields = len(az_shear_field_names)
    num_refl_heights = len(refl_heights_m_agl)

    # Near-square layout for the reflectivity panels (one panel per height).
    num_refl_panel_rows = int(numpy.floor(numpy.sqrt(num_refl_heights)))
    num_refl_panel_columns = int(
        numpy.ceil(float(num_refl_heights) / num_refl_panel_rows)
    )

    # ---- Reflectivity figure ----
    refl_figure_object, refl_axes_object_matrix = _create_figure_if_needed(
        num_rows=num_refl_panel_rows, num_columns=num_refl_panel_columns
    )

    refl_figure_object, refl_axes_object_matrix = (
        radar_plotting.plot_3d_grid_without_coords(
            field_matrix=numpy.flip(
                list_of_predictor_matrices[0][..., 0], axis=0
            ),
            field_name=radar_utils.REFL_NAME,
            grid_point_heights_metres=refl_heights_m_agl,
            ground_relative=True, num_panel_rows=num_refl_panel_rows,
            figure_object=refl_figure_object,
            axes_object_matrix=refl_axes_object_matrix,
            font_size=FONT_SIZE_SANS_COLOUR_BARS
        )
    )

    if allow_whitespace:
        _add_colour_bar(
            axes_object_matrix=refl_axes_object_matrix,
            data_matrix=list_of_predictor_matrices[0],
            field_name=radar_utils.REFL_NAME
        )

        if title_string is not None:
            # `pyplot.suptitle` targets pyplot's current figure, presumably
            # the reflectivity figure that was just drawn.
            pyplot.suptitle(
                '{0:s}; {1:s}'.format(title_string, radar_utils.REFL_NAME),
                fontsize=TITLE_FONT_SIZE
            )

    # ---- Azimuthal-shear figure (single row, one panel per field) ----
    shear_figure_object, shear_axes_object_matrix = _create_figure_if_needed(
        num_rows=1, num_columns=num_az_shear_fields
    )

    shear_figure_object, shear_axes_object_matrix = (
        radar_plotting.plot_many_2d_grids_without_coords(
            field_matrix=numpy.flip(list_of_predictor_matrices[1], axis=0),
            field_name_by_panel=az_shear_field_names,
            panel_names=az_shear_field_names, num_panel_rows=1,
            figure_object=shear_figure_object,
            axes_object_matrix=shear_axes_object_matrix,
            plot_colour_bar_by_panel=numpy.full(
                num_az_shear_fields, False, dtype=bool
            ),
            font_size=FONT_SIZE_SANS_COLOUR_BARS
        )
    )

    if allow_whitespace:
        _add_colour_bar(
            axes_object_matrix=shear_axes_object_matrix,
            data_matrix=list_of_predictor_matrices[1],
            field_name=radar_utils.LOW_LEVEL_SHEAR_NAME
        )

        if title_string is not None:
            pyplot.suptitle(title_string, fontsize=TITLE_FONT_SIZE)

    return (
        [refl_figure_object, shear_figure_object],
        [refl_axes_object_matrix, shear_axes_object_matrix]
    )
def _plot_3d_radar_scan(list_of_predictor_matrices, model_metadata_dict,
                        allow_whitespace, title_string=None):
    """Plots 3-D radar scan for one example.

    One figure is produced per radar field; within each figure there is one
    panel per height.

    J = number of panel rows in image
    K = number of panel columns in image
    F = number of radar fields

    :param list_of_predictor_matrices: List created by
        `testing_io.read_specific_examples`, except that the first axis
        (example dimension) is removed.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param allow_whitespace: See documentation at top of file.
    :param title_string: Title (may be None).
    :return: figure_objects: length-F list of figure handles (instances of
        `matplotlib.figure.Figure`).
    :return: axes_object_matrices: length-F list.  Each element is a J-by-K
        numpy array of axes handles (instances of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]
    radar_heights_m_agl = training_option_dict[trainval_io.RADAR_HEIGHTS_KEY]

    num_radar_fields = len(radar_field_names)
    num_radar_heights = len(radar_heights_m_agl)

    # Near-square panel layout: floor(sqrt(n)) rows, enough columns to fit.
    num_panel_rows = int(numpy.floor(numpy.sqrt(num_radar_heights)))
    num_panel_columns = int(
        numpy.ceil(float(num_radar_heights) / num_panel_rows)
    )

    radar_matrix = list_of_predictor_matrices[0]
    figure_objects = []
    axes_object_matrices = []

    for j in range(num_radar_fields):
        this_field_name = radar_field_names[j]
        this_radar_matrix = numpy.flip(radar_matrix[..., j], axis=0)

        # Without whitespace the figure is pre-built here; otherwise None is
        # handed to the plotting routine and its return value is used.
        this_figure_object = None
        this_axes_object_matrix = None

        if not allow_whitespace:
            this_figure_object, this_axes_object_matrix = (
                plotting_utils.create_paneled_figure(
                    num_rows=num_panel_rows, num_columns=num_panel_columns,
                    horizontal_spacing=0., vertical_spacing=0.,
                    shared_x_axis=False, shared_y_axis=False,
                    keep_aspect_ratio=True
                )
            )

        this_figure_object, this_axes_object_matrix = (
            radar_plotting.plot_3d_grid_without_coords(
                field_matrix=this_radar_matrix,
                field_name=this_field_name,
                grid_point_heights_metres=radar_heights_m_agl,
                ground_relative=True, num_panel_rows=num_panel_rows,
                figure_object=this_figure_object,
                axes_object_matrix=this_axes_object_matrix,
                font_size=FONT_SIZE_SANS_COLOUR_BARS
            )
        )

        figure_objects.append(this_figure_object)
        axes_object_matrices.append(this_axes_object_matrix)

        if not allow_whitespace:
            continue

        # Colour bar and title are drawn only when whitespace is allowed.
        this_colour_map_object, this_colour_norm_object = (
            radar_plotting.get_default_colour_scheme(this_field_name)
        )

        plotting_utils.plot_colour_bar(
            axes_object_or_matrix=this_axes_object_matrix,
            data_matrix=this_radar_matrix,
            colour_map_object=this_colour_map_object,
            colour_norm_object=this_colour_norm_object,
            orientation_string='horizontal',
            extend_min=True, extend_max=True
        )

        if title_string is not None:
            pyplot.suptitle(
                '{0:s}; {1:s}'.format(title_string, this_field_name),
                fontsize=TITLE_FONT_SIZE
            )

    return figure_objects, axes_object_matrices
def _plot_one_score(score_matrix, colour_map_object, min_colour_value,
                    max_colour_value, colour_bar_label, is_score_bias,
                    best_model_index, output_file_name):
    """Plots one score.

    One panel is drawn per (dense-layer count, data-augmentation flag) pair.
    Within a panel, the x-axis is L2 weight and the y-axis is dropout rate.
    The best model and all models with NaN score are marked on the grid.

    :param score_matrix: 4-D numpy array of scores, where the first axis
        represents dropout rate; second represents L2 weight; third represents
        num dense layers; and fourth is data augmentation (yes or no).
    :param colour_map_object: See documentation at top of file.  Replaced by
        the bias colour scheme if `is_score_bias` is True.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param colour_bar_label: Label string for colour bar.
    :param is_score_bias: Boolean flag.  If True, score to be plotted is
        frequency bias, which changes settings for colour scheme.
    :param best_model_index: Linear index of best model.
    :param output_file_name: Path to output file (figure will be saved here).
    """

    def _draw_marker(axes_object, x_coord, y_coord, marker_type, marker_size,
                     marker_width):
        """Draws a single marker at one grid cell."""
        axes_object.plot(
            x_coord, y_coord, linestyle='None', marker=marker_type,
            markersize=marker_size, markerfacecolor=MARKER_COLOUR,
            markeredgecolor=MARKER_COLOUR, markeredgewidth=marker_width
        )

    colour_norm_object = None

    if is_score_bias:
        colour_map_object, colour_norm_object = _get_bias_colour_scheme(
            max_value=max_colour_value)

    num_dense_layer_counts = len(DENSE_LAYER_COUNTS)
    num_data_aug_flags = len(DATA_AUGMENTATION_FLAGS)

    # Panels are created as one column, then the axes handles are reshaped so
    # that they can be indexed by [dense-layer index, data-aug index].
    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_dense_layer_counts * num_data_aug_flags, num_columns=1,
        horizontal_spacing=0.15, vertical_spacing=0.15,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
    )
    axes_object_matrix = numpy.reshape(
        axes_object_matrix, (num_dense_layer_counts, num_data_aug_flags)
    )

    x_axis_label = r'L$_2$ weight (log$_{10}$)'
    y_axis_label = 'Dropout rate'
    x_tick_labels = ['{0:.1f}'.format(w) for w in numpy.log10(L2_WEIGHTS)]
    y_tick_labels = ['{0:.3f}'.format(d) for d in DROPOUT_RATES]

    best_model_index_tuple = numpy.unravel_index(
        best_model_index, score_matrix.shape)

    for k in range(num_dense_layer_counts):
        this_layer_count = DENSE_LAYER_COUNTS[k]

        for m in range(num_data_aug_flags):
            this_axes_object = axes_object_matrix[k, m]

            model_eval.plot_hyperparam_grid(
                score_matrix=score_matrix[..., k, m],
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value,
                colour_map_object=colour_map_object,
                colour_norm_object=colour_norm_object,
                axes_object=this_axes_object
            )

            this_axes_object.set_xticklabels(
                x_tick_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.)
            this_axes_object.set_yticklabels(
                y_tick_labels, fontsize=TICK_LABEL_FONT_SIZE)
            this_axes_object.set_ylabel(
                y_axis_label, fontsize=TICK_LABEL_FONT_SIZE)

            # Only the very last panel keeps its x-ticks and x-axis label.
            is_last_panel = (
                k == num_dense_layer_counts - 1 and
                m == num_data_aug_flags - 1
            )

            if is_last_panel:
                this_axes_object.set_xlabel(x_axis_label)
            else:
                this_axes_object.set_xticks([], [])

            this_axes_object.set_title(
                '{0:d} dense layer{1:s}, DA {2:s}'.format(
                    this_layer_count,
                    's' if this_layer_count > 1 else '',
                    'on' if DATA_AUGMENTATION_FLAGS[m] else 'off'
                )
            )

    # Mark the best model.  Axis 0 (dropout) is plotted on y and axis 1 (L2
    # weight) on x, hence the (column, row) order of the coordinates.
    best_row = best_model_index_tuple[0]
    best_column = best_model_index_tuple[1]
    best_k = best_model_index_tuple[2]
    best_m = best_model_index_tuple[3]

    _draw_marker(
        axes_object=axes_object_matrix[best_k, best_m],
        x_coord=best_column, y_coord=best_row,
        marker_type=BEST_MODEL_MARKER_TYPE,
        marker_size=BEST_MODEL_MARKER_SIZE,
        marker_width=BEST_MODEL_MARKER_WIDTH
    )

    # Mark every model whose score is NaN.
    corrupt_model_indices = numpy.where(
        numpy.isnan(numpy.ravel(score_matrix))
    )[0]

    for this_linear_index in corrupt_model_indices:
        this_row, this_column, this_k, this_m = numpy.unravel_index(
            this_linear_index, score_matrix.shape)

        _draw_marker(
            axes_object=axes_object_matrix[this_k, this_m],
            x_coord=this_column, y_coord=this_row,
            marker_type=CORRUPT_MODEL_MARKER_TYPE,
            marker_size=CORRUPT_MODEL_MARKER_SIZE,
            marker_width=CORRUPT_MODEL_MARKER_WIDTH
        )

    if not is_score_bias:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix, colour_map_object=colour_map_object,
            min_value=min_colour_value, max_value=max_colour_value,
            orientation_string='vertical',
            extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE
        )
    else:
        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix, colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='vertical',
            extend_min=False, extend_max=True, font_size=DEFAULT_FONT_SIZE
        )

        # Re-label the ticks with one decimal place.
        tick_values = colour_bar_object.get_ticks()
        tick_strings = ['{0:.1f}'.format(v) for v in tick_values]
        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(tick_strings)

    colour_bar_object.set_label(colour_bar_label)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _run(exp1_permutation_dir_name, exp2_permutation_dir_name,
         use_forward_test, use_multipass_test, output_file_name):
    """Creates figure showing permutation-test results for both models.

    This is effectively the main method.  The figure is a 2-by-2 panel grid:
    top row = experiment 1, bottom row = experiment 2; left column = heating
    rates only, right column = fluxes only.  For each panel, one results file
    is read, converted with `_results_to_gg_format` and plotted.

    :param exp1_permutation_dir_name: See documentation at top of file.
    :param exp2_permutation_dir_name: Same.
    :param use_forward_test: Same.
    :param use_multipass_test: Same.
    :param output_file_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(file_name=output_file_name)

    direction_string = 'forward' if use_forward_test else 'backwards'

    exp1_flux_file_name = '{0:s}/{1:s}_perm_test_fluxes-only.nc'.format(
        exp1_permutation_dir_name, direction_string)
    exp1_heating_rate_file_name = '{0:s}/{1:s}_perm_test_hr-only.nc'.format(
        exp1_permutation_dir_name, direction_string)
    exp2_flux_file_name = '{0:s}/{1:s}_perm_test_fluxes-only.nc'.format(
        exp2_permutation_dir_name, direction_string)
    exp2_heating_rate_file_name = '{0:s}/{1:s}_perm_test_hr-only.nc'.format(
        exp2_permutation_dir_name, direction_string)

    heating_rate_x_label = r'Dual-weighted MSE (K$^3$ day$^{-3}$)'

    # NOTE(review): the flux panels are labelled in K day^-1, which looks like
    # heating-rate units rather than flux units -- string kept as-is; confirm.
    flux_x_label = r'MSE (K day$^{-1}$)'

    # One entry per panel, in plotting order:
    # (input file, panel row, panel column, letter, title, x-axis label).
    panel_specs = [
        (exp1_heating_rate_file_name, 0, 0, '(a)',
         'Exp 1, heating rates only', heating_rate_x_label),
        (exp1_flux_file_name, 0, 1, '(b)',
         'Exp 1, fluxes only', flux_x_label),
        (exp2_heating_rate_file_name, 1, 0, '(c)',
         'Exp 2, heating rates only', heating_rate_x_label),
        (exp2_flux_file_name, 1, 1, '(d)',
         'Exp 2, fluxes only', flux_x_label)
    ]

    # Same plotting routine for all four panels.
    if use_multipass_test:
        plot_function = permutation_plotting.plot_multipass_test
    else:
        plot_function = permutation_plotting.plot_single_pass_test

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=2, num_columns=2, shared_x_axis=False, shared_y_axis=True,
        keep_aspect_ratio=False, horizontal_spacing=0.25,
        vertical_spacing=0.25
    )

    for (this_file_name, this_row, this_column, this_letter_string,
         this_title_string, this_x_label) in panel_specs:

        print('Reading data from: "{0:s}"...'.format(this_file_name))
        this_permutation_dict = ml4rt_permutation.read_file(this_file_name)
        this_permutation_dict = _results_to_gg_format(this_permutation_dict)

        this_axes_object = axes_object_matrix[this_row, this_column]

        plot_function(
            permutation_dict=this_permutation_dict,
            axes_object=this_axes_object,
            plot_percent_increase=False, confidence_level=CONFIDENCE_LEVEL
        )

        plotting_utils.label_axes(
            axes_object=this_axes_object, label_string=this_letter_string,
            font_size=30, x_coord_normalized=0.1, y_coord_normalized=1.01
        )

        this_axes_object.set_title(this_title_string)
        this_axes_object.set_xlabel(this_x_label)

        # Set an empty y-axis label on every panel.
        this_axes_object.set_ylabel('')

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)