def _run():
    """Analyzes backwards-optimization experiment.

    This is effectively the main method.  For every pair of (L2 weight,
    min-max weight), it reads the corresponding backwards-optimization file,
    stores the mean final activation in a grid (rows = L2 weights, columns =
    min-max weights), plots that grid with a colour bar fixed to [0, 1], and
    saves the figure.
    """

    activation_matrix = numpy.full(
        (len(L2_WEIGHTS), len(MINMAX_WEIGHTS)), numpy.nan
    )

    for row_index, this_l2_weight in enumerate(L2_WEIGHTS):
        for column_index, this_minmax_weight in enumerate(MINMAX_WEIGHTS):
            bwo_file_name = (
                '{0:s}/bwo_pmm_l2-weight={1:.10f}_minmax-weight={2:.10f}.p'
            ).format(
                TOP_EXPERIMENT_DIR_NAME, this_l2_weight, this_minmax_weight
            )

            print('Reading data from: "{0:s}"...'.format(bwo_file_name))
            bwo_dict = backwards_opt.read_file(bwo_file_name)[0]
            activation_matrix[row_index, column_index] = bwo_dict[
                backwards_opt.MEAN_FINAL_ACTIVATION_KEY
            ]

    # Tick labels show the base-10 logarithm of each weight.
    x_tick_labels = [
        '{0:.1f}'.format(v) for v in numpy.log10(MINMAX_WEIGHTS)
    ]
    y_tick_labels = ['{0:.1f}'.format(v) for v in numpy.log10(L2_WEIGHTS)]

    axes_object = model_evaluation.plot_hyperparam_grid(
        score_matrix=activation_matrix,
        colour_map_object=COLOUR_MAP_OBJECT,
        min_colour_value=0., max_colour_value=1.
    )
    axes_object.set_xticklabels(x_tick_labels, rotation=90.)
    axes_object.set_yticklabels(y_tick_labels)
    axes_object.set_xlabel(r'Min-max weight (log$_{10}$)')
    axes_object.set_ylabel(r'L$_2$ weight (log$_{10}$)')

    plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=activation_matrix,
        colour_map_object=COLOUR_MAP_OBJECT, min_value=0., max_value=1.,
        orientation_string='vertical', extend_min=False, extend_max=False,
        font_size=FONT_SIZE
    )

    output_file_name = '{0:s}/mean_final_activations.jpg'.format(
        TOP_EXPERIMENT_DIR_NAME
    )
    print('Saving figure to: "{0:s}"...'.format(output_file_name))

    # Saves and closes the current pyplot figure.
    pyplot.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close()
def _plot_feature_maps_one_layer(
        feature_matrix, example_id_strings, layer_name, output_dir_name):
    """Plots feature maps for one layer.

    One figure (all channels plus a shared horizontal colour bar) is created
    and saved for each example.

    E = number of examples
    H = number of heights
    C = number of channels

    :param feature_matrix: E-by-H-by-C numpy array of feature maps.
    :param example_id_strings: length-E list of example IDs.
    :param layer_name: Name of layer that generated feature maps.
    :param output_dir_name: Name of output directory.  Figures will be saved
        here as "<example ID>.jpg".
    """

    error_checking.assert_is_numpy_array(feature_matrix, num_dimensions=3)

    # One colour scheme, symmetric about zero, is shared by every example:
    # its limit is the 99th percentile of absolute values over the whole
    # layer.
    # TODO(thunderhoser): Maybe define colour limits differently?
    max_colour_value = numpy.percentile(numpy.absolute(feature_matrix), 99)
    min_colour_value = -1 * max_colour_value

    for example_index in range(feature_matrix.shape[0]):
        example_feature_matrix = feature_matrix[example_index, ...]

        figure_object, axes_object_matrix = (
            feature_map_plotting.plot_many_1d_feature_maps(
                feature_matrix=example_feature_matrix,
                colour_map_object=COLOUR_MAP_OBJECT,
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value
            )
        )

        plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=example_feature_matrix,
            colour_map_object=COLOUR_MAP_OBJECT,
            min_value=min_colour_value, max_value=max_colour_value,
            orientation_string='horizontal', padding=0.01,
            extend_min=True, extend_max=True
        )

        example_id_string = example_id_strings[example_index]
        figure_object.suptitle(
            'Layer "{0:s}", example "{1:s}"'.format(
                layer_name, example_id_string
            ),
            fontsize=25
        )

        output_file_name = '{0:s}/{1:s}.jpg'.format(
            output_dir_name, example_id_string
        )
        print('Saving figure to: "{0:s}"...'.format(output_file_name))
        figure_object.savefig(
            output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_object)
def _add_colour_bar(axes_object, basemap_object, colour_map_object,
                    colour_norm_object):
    """Adds colour bar to figure.

    The colour bar spans [vmin, vmax] of `colour_norm_object`.  It is vertical
    when the map covers more degrees of latitude than of longitude, and
    horizontal otherwise.  Tick labels are times formatted with
    `COLOUR_BAR_TIME_FORMAT`.

    :param axes_object: See input doc for `plot_storm_tracks`.
    :param basemap_object: Same.
    :param colour_map_object: See output doc for `_process_colour_args`.
    :param colour_norm_object: Same.
    :return: colour_bar_object: Handle for colour bar.
    """

    latitude_range_deg = basemap_object.urcrnrlat - basemap_object.llcrnrlat
    longitude_range_deg = basemap_object.urcrnrlon - basemap_object.llcrnrlon
    is_vertical = latitude_range_deg > longitude_range_deg

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=numpy.array([0, 1e12], dtype=int),
        colour_map_object=colour_map_object,
        min_value=colour_norm_object.vmin, max_value=colour_norm_object.vmax,
        orientation_string='vertical' if is_vertical else 'horizontal',
        padding=None if is_vertical else 0.05,
        extend_min=False, extend_max=False, fraction_of_axis_length=0.9,
        font_size=COLOUR_BAR_FONT_SIZE
    )

    # Ticks are in data units, i.e. Unix times; round them to whole seconds
    # before converting to strings.
    rounded_times_unix_sec = numpy.round(
        colour_bar_object.get_ticks()
    ).astype(int)
    time_labels = [
        time_conversion.unix_sec_to_string(this_time, COLOUR_BAR_TIME_FORMAT)
        for this_time in rounded_times_unix_sec
    ]

    colour_bar_object.set_ticks(rounded_times_unix_sec)
    colour_bar_object.set_ticklabels(time_labels)
    return colour_bar_object
def _plot_one_feature_map(feature_matrix_2d, max_colour_value, plot_colour_bar,
                          axes_object):
    """Plots one feature map.

    The colour scheme spans [-max_colour_value, max_colour_value], axis
    limits match the grid size, and axis ticks are removed.

    M = number of rows in grid
    N = number of columns in grid

    :param feature_matrix_2d: M-by-N numpy array of feature values.
    :param max_colour_value: Max value in colour scheme.
    :param plot_colour_bar: Boolean flag.  If True, a horizontal colour bar
        (ticks labelled to one decimal place) is added below the map.
    :param axes_object: Will plot on these axes (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    min_colour_value = -1 * max_colour_value

    axes_object.pcolormesh(
        feature_matrix_2d, cmap=COLOUR_MAP_OBJECT,
        vmin=min_colour_value, vmax=max_colour_value,
        shading='flat', edgecolors='None'
    )

    axes_object.set_xlim(0., feature_matrix_2d.shape[1])
    axes_object.set_ylim(0., feature_matrix_2d.shape[0])
    axes_object.set_xticks([])
    axes_object.set_yticks([])

    if plot_colour_bar:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object, data_matrix=feature_matrix_2d,
            colour_map_object=COLOUR_MAP_OBJECT,
            min_value=min_colour_value, max_value=max_colour_value,
            orientation_string='horizontal', padding=0.015,
            fraction_of_axis_length=0.9, extend_min=True, extend_max=True,
            font_size=DEFAULT_FONT_SIZE
        )

        # The bar is horizontal, so its ticks live on the x-axis of its axes.
        bar_ticks = colour_bar_object.ax.get_xticks()
        bar_labels = ['{0:.1f}'.format(value) for value in bar_ticks]
        colour_bar_object.set_ticks(bar_ticks)
        colour_bar_object.set_ticklabels(bar_labels)
def _plot_saliency_vector_p_vector_t(
        saliency_matrix, predictor_names, target_names, height_labels,
        example_id_string, colour_map_object, max_colour_percentile,
        output_dir_name):
    """Plots saliency for one example: vector predictors, vector targets.

    One figure is saved per (predictor, target) pair.  The x-axis is
    predictor height, the y-axis is target height, and a dashed reference line
    is drawn along the diagonal of the axes.

    P = number of predictor variables
    T = number of target variables
    H = number of heights

    :param saliency_matrix: H-by-P-by-H-by-T numpy array of saliency values.
        Axis 0 is plotted as predictor height; axis 2 as target height.
    :param predictor_names: length-P list of predictor names.
    :param target_names: length-T list of target names.
    :param height_labels: length-H list of height labels (strings).
    :param example_id_string: Example ID.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    predictor_names_verbose = [
        PREDICTOR_NAME_TO_VERBOSE[n] for n in predictor_names
    ]
    target_names_verbose = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    num_heights = len(height_labels)
    height_tick_values = numpy.linspace(
        0, num_heights - 1, num=num_heights, dtype=float
    )

    for j, predictor_name in enumerate(predictor_names):
        for k, target_name in enumerate(target_names):
            this_saliency_matrix = saliency_matrix[:, j, :, k]

            # Colour scheme is symmetric about zero; the 0.001 floor keeps the
            # range from collapsing when saliency is (near) zero everywhere.
            max_colour_value = numpy.maximum(
                numpy.percentile(
                    numpy.absolute(this_saliency_matrix),
                    max_colour_percentile
                ),
                0.001
            )
            min_colour_value = -1 * max_colour_value

            figure_object, axes_object = pyplot.subplots(
                1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
            )
            axes_object.imshow(
                numpy.transpose(this_saliency_matrix), cmap=colour_map_object,
                vmin=min_colour_value, vmax=max_colour_value, origin='lower'
            )

            axes_object.set_xticks(height_tick_values)
            axes_object.set_yticks(height_tick_values)
            axes_object.set_xticklabels(
                height_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.
            )
            axes_object.set_yticklabels(
                height_labels, fontsize=TICK_LABEL_FONT_SIZE
            )
            axes_object.set_xlabel('Predictor height (km AGL)')
            axes_object.set_ylabel('Target height (km AGL)')

            axes_object.plot(
                axes_object.get_xlim(), axes_object.get_ylim(),
                color=REFERENCE_LINE_COLOUR, linestyle='dashed',
                linewidth=REFERENCE_LINE_WIDTH
            )

            colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object,
                data_matrix=this_saliency_matrix,
                colour_map_object=colour_map_object,
                min_value=min_colour_value, max_value=max_colour_value,
                orientation_string='horizontal', padding=0.1,
                extend_min=True, extend_max=True,
                fraction_of_axis_length=0.8, font_size=DEFAULT_FONT_SIZE
            )

            colour_tick_values = colour_bar_object.get_ticks()
            colour_tick_strings = [
                '{0:.1f}'.format(v) for v in colour_tick_values
            ]
            colour_bar_object.set_ticks(colour_tick_values)
            colour_bar_object.set_ticklabels(colour_tick_strings)

            axes_object.set_title(
                'Saliency for {0:s} with respect to {1:s}'.format(
                    target_names_verbose[k], predictor_names_verbose[j]
                ),
                fontsize=DEFAULT_FONT_SIZE
            )

            # Underscores in file-name parts are replaced with hyphens so that
            # underscores only separate the parts.
            output_file_name = '{0:s}/{1:s}_{2:s}_{3:s}.jpg'.format(
                output_dir_name, example_id_string.replace('_', '-'),
                predictor_name.replace('_', '-'),
                target_name.replace('_', '-')
            )

            print('Saving figure to: "{0:s}"...'.format(output_file_name))
            figure_object.savefig(
                output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
                bbox_inches='tight'
            )
            pyplot.close(figure_object)
def _plot_saliency_scalar_p_scalar_t(
        saliency_matrix, predictor_names, target_names, example_id_string,
        colour_map_object, max_colour_percentile, output_dir_name):
    """Plots saliency for one example: scalar predictors, scalar targets.

    All (predictor, target) pairs are shown in one figure, with predictors on
    the x-axis and targets on the y-axis.

    P = number of predictor variables
    T = number of target variables

    :param saliency_matrix: P-by-T numpy array of saliency values.
    :param predictor_names: length-P list of predictor names.
    :param target_names: length-T list of target names.
    :param example_id_string: Example ID.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    predictor_names_verbose = [
        PREDICTOR_NAME_TO_VERBOSE[n] for n in predictor_names
    ]
    target_names_verbose = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    # Colour scheme is symmetric about zero; the 0.001 floor keeps the range
    # from collapsing when saliency is (near) zero everywhere.
    max_colour_value = numpy.maximum(
        numpy.percentile(
            numpy.absolute(saliency_matrix), max_colour_percentile
        ),
        0.001
    )
    min_colour_value = -1 * max_colour_value

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
    )
    axes_object.imshow(
        numpy.transpose(saliency_matrix), cmap=colour_map_object,
        vmin=min_colour_value, vmax=max_colour_value, origin='lower'
    )

    num_predictors = len(predictor_names)
    num_targets = len(target_names)
    axes_object.set_xticks(
        numpy.linspace(0, num_predictors - 1, num=num_predictors, dtype=float)
    )
    axes_object.set_yticks(
        numpy.linspace(0, num_targets - 1, num=num_targets, dtype=float)
    )

    # Tick labels are the verbose names with the first letter capitalized.
    predictor_labels = [n[0].upper() + n[1:] for n in predictor_names_verbose]
    target_labels = [n[0].upper() + n[1:] for n in target_names_verbose]

    axes_object.set_xticklabels(
        predictor_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.
    )
    axes_object.set_yticklabels(target_labels, fontsize=TICK_LABEL_FONT_SIZE)
    axes_object.set_xlabel('Predictor')
    axes_object.set_ylabel('Target')

    # Colour bar goes along the longer dimension of the grid.
    is_horizontal = num_predictors >= num_targets
    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=saliency_matrix,
        colour_map_object=colour_map_object,
        min_value=min_colour_value, max_value=max_colour_value,
        orientation_string='horizontal' if is_horizontal else 'vertical',
        padding=0.1 if is_horizontal else 0.01,
        extend_min=True, extend_max=True,
        fraction_of_axis_length=0.8, font_size=DEFAULT_FONT_SIZE
    )

    colour_tick_values = colour_bar_object.get_ticks()
    colour_tick_strings = ['{0:.1f}'.format(v) for v in colour_tick_values]
    colour_bar_object.set_ticks(colour_tick_values)
    colour_bar_object.set_ticklabels(colour_tick_strings)

    axes_object.set_title(
        'Saliency for scalar targets with respect to scalar predictors',
        fontsize=DEFAULT_FONT_SIZE
    )

    output_file_name = '{0:s}/{1:s}_scalars.jpg'.format(
        output_dir_name, example_id_string.replace('_', '-')
    )
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_one_value(data_matrix, grid_metadata_dict, colour_map_object,
                    min_colour_value, max_colour_value, plot_cbar_min_arrow,
                    plot_cbar_max_arrow, log_scale=False):
    """Plots one value (score, num examples, or num positive examples).

    The value is drawn as a gridded field on a basemap with coastlines,
    country borders, state/province borders, parallels, and meridians, plus a
    horizontal colour bar.  NaN grid cells are masked (left undrawn).

    M = number of rows in grid
    N = number of columns in grid

    :param data_matrix: M-by-N numpy array of values to plot.
    :param grid_metadata_dict: Dictionary returned by
        `grids.read_equidistant_metafile`.
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param plot_cbar_min_arrow: Boolean flag.  If True, will plot arrow at
        bottom of colour bar (to signify that lower values are possible).
    :param plot_cbar_max_arrow: Boolean flag.  If True, will plot arrow at top
        of colour bar (to signify that higher values are possible).
    :param log_scale: Boolean flag (True if `data_matrix` contains data in log
        scale).  If True, colour-bar tick labels are 10**tick, rounded to an
        integer.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
    )
    basemap_object, x_matrix_metres, y_matrix_metres = _get_basemap(
        grid_metadata_dict
    )

    # Grid spacing is inferred from the first row/column of basemap
    # coordinates, i.e. the grid is treated as evenly spaced.
    num_grid_rows = data_matrix.shape[0]
    num_grid_columns = data_matrix.shape[1]
    x_spacing_metres = (
        (x_matrix_metres[0, -1] - x_matrix_metres[0, 0]) /
        (num_grid_columns - 1)
    )
    y_spacing_metres = (
        (y_matrix_metres[-1, 0] - y_matrix_metres[0, 0]) /
        (num_grid_rows - 1)
    )

    edge_value_matrix, edge_x_coords_metres, edge_y_coords_metres = (
        grids.xy_field_grid_points_to_edges(
            field_matrix=data_matrix,
            x_min_metres=x_matrix_metres[0, 0],
            y_min_metres=y_matrix_metres[0, 0],
            x_spacing_metres=x_spacing_metres,
            y_spacing_metres=y_spacing_metres
        )
    )
    edge_value_matrix = numpy.ma.masked_where(
        numpy.isnan(edge_value_matrix), edge_value_matrix
    )

    border_plotting_functions = (
        plotting_utils.plot_coastlines,
        plotting_utils.plot_countries,
        plotting_utils.plot_states_and_provinces
    )
    for this_function in border_plotting_functions:
        this_function(
            basemap_object=basemap_object, axes_object=axes_object,
            line_colour=BORDER_COLOUR
        )

    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS
    )
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS
    )

    # Very low zorder keeps the field underneath the borders and grid lines.
    basemap_object.pcolormesh(
        edge_x_coords_metres, edge_y_coords_metres, edge_value_matrix,
        cmap=colour_map_object, vmin=min_colour_value, vmax=max_colour_value,
        shading='flat', edgecolors='None', axes=axes_object, zorder=-1e12
    )

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=data_matrix,
        colour_map_object=colour_map_object,
        min_value=min_colour_value, max_value=max_colour_value,
        orientation_string='horizontal', extend_min=plot_cbar_min_arrow,
        extend_max=plot_cbar_max_arrow, padding=0.05
    )

    tick_values = colour_bar_object.get_ticks()

    # Log-scale data are labelled in original units; large values are shown
    # as integers; small values with two decimal places.
    if log_scale:
        tick_strings = [
            '{0:d}'.format(int(numpy.round(10 ** v))) for v in tick_values
        ]
    elif numpy.nanmax(data_matrix) >= 6:
        tick_strings = [
            '{0:d}'.format(int(numpy.round(v))) for v in tick_values
        ]
    else:
        tick_strings = ['{0:.2f}'.format(v) for v in tick_values]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_strings)

    return figure_object, axes_object
def _plot_2d_radar_saliency(
        saliency_matrix, colour_map_object, max_colour_value,
        half_num_contours, label_colour_bars, colour_bar_length,
        figure_objects, axes_object_matrices, model_metadata_dict,
        output_dir_name, significance_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots saliency map for 2-D radar data.

    Saliency is drawn as contours (optionally with significance markers) on
    existing radar panels, a colour bar for absolute saliency is added, and
    the figure is saved and closed.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    C = number of radar channels

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param saliency_matrix: M-by-N-by-C numpy array of saliency values.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_value: Same.  If None, the `MAX_COLOUR_PERCENTILE`th
        percentile of absolute saliency is used.
    :param half_num_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param significance_matrix: M-by-N-by-C numpy array of Boolean flags,
        indicating where differences with some other saliency map are
        significant.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    if max_colour_value is None:
        max_colour_value = numpy.percentile(
            numpy.absolute(saliency_matrix), MAX_COLOUR_PERCENTILE
        )

    # Both storm ID and time missing means a composite (PMM) is being plotted.
    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    # With a 2D/3D model the 2-D fields live on the second figure and the
    # file name is tagged with radar field "shear"; otherwise there is one
    # figure and no field tag.
    figure_index = 1 if conv_2d3d else 0
    radar_field_name = 'shear' if conv_2d3d else None
    axes_object_matrix = axes_object_matrices[figure_index]

    # Rows are flipped before plotting (presumably to match the orientation
    # used by the underlying radar plots -- confirm against those helpers).
    saliency_plotting.plot_many_2d_grids_with_contours(
        saliency_matrix_3d=numpy.flip(saliency_matrix, axis=0),
        axes_object_matrix=axes_object_matrix,
        colour_map_object=colour_map_object,
        max_absolute_contour_level=max_colour_value,
        contour_interval=max_colour_value / half_num_contours,
        row_major=False
    )

    if significance_matrix is not None:
        significance_plotting.plot_many_2d_grids_without_coords(
            significance_matrix=numpy.flip(significance_matrix, axis=0),
            axes_object_matrix=axes_object_matrix, row_major=False
        )

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix,
        data_matrix=saliency_matrix,
        colour_map_object=colour_map_object,
        min_value=0., max_value=max_colour_value,
        orientation_string='horizontal',
        fraction_of_axis_length=colour_bar_length / (1 + int(conv_2d3d)),
        extend_min=False, extend_max=True,
        font_size=COLOUR_BAR_FONT_SIZE
    )

    if label_colour_bars:
        colour_bar_object.set_label(
            'Absolute saliency', fontsize=COLOUR_BAR_FONT_SIZE
        )

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False,
        pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name
    )

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object = figure_objects[figure_index]
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_one_score(score_matrix, colour_map_object, min_colour_value,
                    max_colour_value, colour_bar_label, is_score_bias,
                    best_model_index, output_file_name):
    """Plots one score.

    One panel is drawn per (dense-layer count, data-augmentation flag) pair,
    each showing a dropout-rate-by-L2-weight grid.  The best model is marked
    with `BEST_MODEL_MARKER_TYPE`; models with NaN score ("corrupt") are
    marked with `CORRUPT_MODEL_MARKER_TYPE`.

    :param score_matrix: 4-D numpy array of scores, where the first axis
        represents dropout rate; second represents L2 weight; third represents
        num dense layers; and fourth is data augmentation (yes or no).
    :param colour_map_object: See documentation at top of file.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param colour_bar_label: Label string for colour bar.
    :param is_score_bias: Boolean flag.  If True, score to be plotted is
        frequency bias, which changes settings for colour scheme.
    :param best_model_index: Linear index of best model.
    :param output_file_name: Path to output file (figure will be saved here).
    """

    colour_norm_object = None
    if is_score_bias:
        colour_map_object, colour_norm_object = _get_bias_colour_scheme(
            max_value=max_colour_value
        )

    num_dense_layer_counts = len(DENSE_LAYER_COUNTS)
    num_data_aug_flags = len(DATA_AUGMENTATION_FLAGS)

    # Panels are stacked in a single column, then reshaped so that they can
    # be indexed as [dense-layer-count index, data-augmentation index].
    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_dense_layer_counts * num_data_aug_flags, num_columns=1,
        horizontal_spacing=0.15, vertical_spacing=0.15,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True
    )
    axes_object_matrix = numpy.reshape(
        axes_object_matrix, (num_dense_layer_counts, num_data_aug_flags)
    )

    x_tick_labels = ['{0:.1f}'.format(w) for w in numpy.log10(L2_WEIGHTS)]
    y_tick_labels = ['{0:.3f}'.format(d) for d in DROPOUT_RATES]
    best_indices = numpy.unravel_index(best_model_index, score_matrix.shape)

    for k in range(num_dense_layer_counts):
        for m in range(num_data_aug_flags):
            panel_axes_object = axes_object_matrix[k, m]

            model_eval.plot_hyperparam_grid(
                score_matrix=score_matrix[..., k, m],
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value,
                colour_map_object=colour_map_object,
                colour_norm_object=colour_norm_object,
                axes_object=panel_axes_object
            )

            panel_axes_object.set_xticklabels(
                x_tick_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.
            )
            panel_axes_object.set_yticklabels(
                y_tick_labels, fontsize=TICK_LABEL_FONT_SIZE
            )
            panel_axes_object.set_ylabel(
                'Dropout rate', fontsize=TICK_LABEL_FONT_SIZE
            )

            # Only the last panel keeps an x-axis label; the others lose
            # their x-ticks.
            is_last_panel = (
                k == num_dense_layer_counts - 1 and
                m == num_data_aug_flags - 1
            )
            if is_last_panel:
                panel_axes_object.set_xlabel(r'L$_2$ weight (log$_{10}$)')
            else:
                panel_axes_object.set_xticks([], [])

            num_dense_layers = DENSE_LAYER_COUNTS[k]
            panel_axes_object.set_title(
                '{0:d} dense layer{1:s}, DA {2:s}'.format(
                    num_dense_layers, 's' if num_dense_layers > 1 else '',
                    'on' if DATA_AUGMENTATION_FLAGS[m] else 'off'
                )
            )

    # Grid x-coordinate is the L2-weight index; y-coordinate is dropout index.
    axes_object_matrix[best_indices[2], best_indices[3]].plot(
        best_indices[1], best_indices[0], linestyle='None',
        marker=BEST_MODEL_MARKER_TYPE, markersize=BEST_MODEL_MARKER_SIZE,
        markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR,
        markeredgewidth=BEST_MODEL_MARKER_WIDTH
    )

    for i, j, k, m in numpy.argwhere(numpy.isnan(score_matrix)):
        axes_object_matrix[k, m].plot(
            j, i, linestyle='None',
            marker=CORRUPT_MODEL_MARKER_TYPE,
            markersize=CORRUPT_MODEL_MARKER_SIZE,
            markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR,
            markeredgewidth=CORRUPT_MODEL_MARKER_WIDTH
        )

    if not is_score_bias:
        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            min_value=min_colour_value, max_value=max_colour_value,
            orientation_string='vertical',
            extend_min=True, extend_max=True, font_size=DEFAULT_FONT_SIZE
        )
    else:
        colour_bar_object = plotting_utils.plot_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=score_matrix,
            colour_map_object=colour_map_object,
            colour_norm_object=colour_norm_object,
            orientation_string='vertical',
            extend_min=False, extend_max=True, font_size=DEFAULT_FONT_SIZE
        )

        bias_tick_values = colour_bar_object.get_ticks()
        bias_tick_strings = ['{0:.1f}'.format(v) for v in bias_tick_values]
        colour_bar_object.set_ticks(bias_tick_values)
        colour_bar_object.set_ticklabels(bias_tick_strings)

    colour_bar_object.set_label(colour_bar_label)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _add_colour_bar(figure_file_name, colour_map_object, max_colour_value,
                    temporary_dir_name):
    """Adds colour bar to saved image file.

    The colour bar is drawn, with axes hidden, in a separate figure sized to
    the image's pixel dimensions at `FIGURE_RESOLUTION_DPI`.  That figure is
    saved to a temporary file, which is concatenated with the image
    (1 row, 2 columns).  The result overwrites `figure_file_name`, the
    temporary file is deleted, and whitespace around the result is trimmed.

    :param figure_file_name: Path to saved image file.  Colour bar will be
        added to this image.
    :param colour_map_object: Colour scheme (instance of `matplotlib.pyplot.cm`
        or similar).
    :param max_colour_value: Max value in colour scheme.  The colour bar spans
        [0, max_colour_value].
    :param temporary_dir_name: Name of temporary output directory.
    """

    # Only the image size is needed.  The context manager releases the file
    # handle immediately, instead of leaving it open while ImageMagick
    # overwrites this same file below.
    with Image.open(figure_file_name) as image_object:
        figure_width_px, figure_height_px = image_object.size

    figure_width_inches = float(figure_width_px) / FIGURE_RESOLUTION_DPI
    figure_height_inches = float(figure_height_px) / FIGURE_RESOLUTION_DPI

    extra_figure_object, extra_axes_object = pyplot.subplots(
        1, 1, figsize=(figure_width_inches, figure_height_inches)
    )
    extra_axes_object.axis('off')

    dummy_values = numpy.array([0., max_colour_value])
    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=extra_axes_object, data_matrix=dummy_values,
        colour_map_object=colour_map_object,
        min_value=0., max_value=max_colour_value,
        orientation_string='vertical', fraction_of_axis_length=1.25,
        extend_min=False, extend_max=True,
        font_size=COLOUR_BAR_FONT_SIZE, aspect_ratio=50.
    )

    # Smaller colour ranges get more decimal places so that tick labels stay
    # distinct.
    tick_values = colour_bar_object.get_ticks()

    if max_colour_value <= 0.005:
        tick_strings = ['{0:.4f}'.format(v) for v in tick_values]
    elif max_colour_value <= 0.05:
        tick_strings = ['{0:.3f}'.format(v) for v in tick_values]
    else:
        tick_strings = ['{0:.2f}'.format(v) for v in tick_values]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_strings)

    extra_file_name = '{0:s}/saliency_colour-bar.jpg'.format(
        temporary_dir_name
    )
    print('Saving colour bar to: "{0:s}"...'.format(extra_file_name))

    extra_figure_object.savefig(
        extra_file_name, dpi=FIGURE_RESOLUTION_DPI,
        pad_inches=0, bbox_inches='tight'
    )
    pyplot.close(extra_figure_object)

    print('Concatenating colour bar to: "{0:s}"...'.format(figure_file_name))

    imagemagick_utils.concatenate_images(
        input_file_names=[figure_file_name, extra_file_name],
        output_file_name=figure_file_name,
        num_panel_rows=1, num_panel_columns=2,
        extra_args_string='-gravity Center'
    )

    os.remove(extra_file_name)
    imagemagick_utils.trim_whitespace(
        input_file_name=figure_file_name, output_file_name=figure_file_name
    )
def plot_storm_tracks(
        storm_object_table, axes_object, basemap_object,
        colour_map_object='random', line_colour=DEFAULT_TRACK_COLOUR,
        line_width=DEFAULT_TRACK_WIDTH,
        start_marker_type=DEFAULT_START_MARKER_TYPE,
        end_marker_type=DEFAULT_END_MARKER_TYPE,
        start_marker_size=DEFAULT_START_MARKER_SIZE,
        end_marker_size=DEFAULT_END_MARKER_SIZE):
    """Plots one or more storm tracks on the same map.

    :param storm_object_table: See doc for `plot_storm_outlines`.
    :param axes_object: Same.
    :param basemap_object: Same.
    :param colour_map_object: There are 3 cases.

    If "random", each track will be plotted in a random colour from
    `get_storm_track_colours`.

    If None, each track will be plotted in `line_colour` (the next input arg).

    If real colour map (instance of `matplotlib.pyplot.cm`), track segments
    will be coloured by time, according to this colour map, and a colour bar
    with time labels (`COLOUR_BAR_TIME_FORMAT`) will be added.

    :param line_colour: [used only if `colour_map_object is None`]
        length-3 numpy array with (R, G, B).  Will be used for all tracks.
    :param line_width: Width of each storm track.
    :param start_marker_type: Marker type for beginning of track (in any
        format accepted by `matplotlib.lines`).  If `start_marker_type is
        None`, markers will not be used to show beginning of each track --
        except for tracks containing only one storm object, which always get a
        marker (default type and size) so that they remain visible.
    :param end_marker_type: Same but for end of track.
    :param start_marker_size: Size of each start-point marker.
    :param end_marker_size: Size of each end-point marker.
    """

    plot_start_markers = start_marker_type is not None
    plot_end_markers = end_marker_type is not None

    # Single-object tracks are always marked (see below), so a valid marker
    # type and size are needed even when markers are otherwise turned off.
    if start_marker_type is None:
        start_marker_type = DEFAULT_START_MARKER_TYPE
        start_marker_size = DEFAULT_START_MARKER_SIZE

    if end_marker_type is None:
        end_marker_type = DEFAULT_END_MARKER_TYPE
        end_marker_size = DEFAULT_END_MARKER_SIZE

    x_coords_metres, y_coords_metres = basemap_object(
        storm_object_table[tracking_utils.CENTROID_LONGITUDE_COLUMN].values,
        storm_object_table[tracking_utils.CENTROID_LATITUDE_COLUMN].values
    )
    storm_object_table = storm_object_table.assign(**{
        tracking_utils.CENTROID_X_COLUMN: x_coords_metres,
        tracking_utils.CENTROID_Y_COLUMN: y_coords_metres
    })

    rgb_matrix = None
    num_colours = None
    colour_norm_object = None

    if colour_map_object is None:
        error_checking.assert_is_numpy_array(
            line_colour, exact_dimensions=numpy.array([3], dtype=int)
        )
        rgb_matrix = numpy.reshape(line_colour, (1, 3))
        num_colours = rgb_matrix.shape[0]
    elif (isinstance(colour_map_object, str)
          and colour_map_object == 'random'):
        rgb_matrix = get_storm_track_colours()
        num_colours = rgb_matrix.shape[0]
        colour_map_object = None
    else:
        valid_times_unix_sec = (
            storm_object_table[tracking_utils.VALID_TIME_COLUMN].values
        )
        colour_norm_object = pyplot.Normalize(
            numpy.min(valid_times_unix_sec), numpy.max(valid_times_unix_sec)
        )

    track_primary_id_strings, object_to_track_indices = numpy.unique(
        storm_object_table[tracking_utils.PRIMARY_ID_COLUMN].values,
        return_inverse=True
    )
    num_tracks = len(track_primary_id_strings)

    for k in range(num_tracks):
        if colour_map_object is None:
            # Colours are reused cyclically when there are more tracks than
            # colours.
            this_colour = plotting_utils.colour_from_numpy_to_tuple(
                rgb_matrix[numpy.mod(k, num_colours), :]
            )
        else:
            this_colour = None

        these_object_indices = numpy.where(object_to_track_indices == k)[0]

        for i in these_object_indices:
            these_next_indices = temporal_tracking.find_immediate_successors(
                storm_object_table=storm_object_table, target_row=i
            )

            for j in these_next_indices:
                these_x_coords_metres = storm_object_table[
                    tracking_utils.CENTROID_X_COLUMN].values[[i, j]]
                these_y_coords_metres = storm_object_table[
                    tracking_utils.CENTROID_Y_COLUMN].values[[i, j]]

                if colour_map_object is None:
                    axes_object.plot(
                        these_x_coords_metres, these_y_coords_metres,
                        color=this_colour, linestyle='solid',
                        linewidth=line_width
                    )
                else:
                    # Colour-by-time: draw the segment as a one-segment
                    # LineCollection coloured by the mean valid time of its
                    # two endpoints.
                    this_point_matrix = numpy.array(
                        [these_x_coords_metres, these_y_coords_metres]
                    ).T.reshape(-1, 1, 2)
                    this_segment_matrix = numpy.concatenate(
                        [this_point_matrix[:-1], this_point_matrix[1:]],
                        axis=1
                    )
                    this_time_unix_sec = numpy.mean(storm_object_table[
                        tracking_utils.VALID_TIME_COLUMN].values[[i, j]])

                    this_line_collection_object = LineCollection(
                        this_segment_matrix, cmap=colour_map_object,
                        norm=colour_norm_object
                    )
                    this_line_collection_object.set_array(
                        numpy.array([this_time_unix_sec])
                    )
                    this_line_collection_object.set_linewidth(line_width)
                    axes_object.add_collection(this_line_collection_object)

            these_prev_indices = temporal_tracking.find_immediate_predecessors(
                storm_object_table=storm_object_table, target_row=i
            )

            # Start marker: object with no predecessor, or the only object in
            # its track.
            plot_this_start_marker = (
                (plot_start_markers and len(these_prev_indices) == 0)
                or len(these_object_indices) == 1
            )

            if plot_this_start_marker:
                if colour_map_object is not None:
                    this_colour = colour_map_object(
                        colour_norm_object(storm_object_table[
                            tracking_utils.VALID_TIME_COLUMN].values[i])
                    )

                # "x" markers are drawn only by their edges, so thicken them.
                this_edge_width = 2 if start_marker_type == 'x' else 1

                axes_object.plot(
                    storm_object_table[
                        tracking_utils.CENTROID_X_COLUMN].values[i],
                    storm_object_table[
                        tracking_utils.CENTROID_Y_COLUMN].values[i],
                    linestyle='None', marker=start_marker_type,
                    markerfacecolor=this_colour, markeredgecolor=this_colour,
                    markersize=start_marker_size,
                    markeredgewidth=this_edge_width
                )

            # End marker: object with no successor, or the only object in its
            # track.
            plot_this_end_marker = (
                (plot_end_markers and len(these_next_indices) == 0)
                or len(these_object_indices) == 1
            )

            if plot_this_end_marker:
                if colour_map_object is not None:
                    this_colour = colour_map_object(
                        colour_norm_object(storm_object_table[
                            tracking_utils.VALID_TIME_COLUMN].values[i])
                    )

                this_edge_width = 2 if end_marker_type == 'x' else 1

                axes_object.plot(
                    storm_object_table[
                        tracking_utils.CENTROID_X_COLUMN].values[i],
                    storm_object_table[
                        tracking_utils.CENTROID_Y_COLUMN].values[i],
                    linestyle='None', marker=end_marker_type,
                    markerfacecolor=this_colour, markeredgecolor=this_colour,
                    markersize=end_marker_size,
                    markeredgewidth=this_edge_width
                )

    if colour_map_object is None:
        return

    latitude_range_deg = basemap_object.urcrnrlat - basemap_object.llcrnrlat
    longitude_range_deg = basemap_object.urcrnrlon - basemap_object.llcrnrlon

    if latitude_range_deg > longitude_range_deg:
        orientation_string = 'vertical'
    else:
        orientation_string = 'horizontal'

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object,
        data_matrix=storm_object_table[
            tracking_utils.VALID_TIME_COLUMN].values,
        colour_map_object=colour_map_object,
        min_value=colour_norm_object.vmin, max_value=colour_norm_object.vmax,
        orientation_string=orientation_string,
        extend_min=False, extend_max=False, fraction_of_axis_length=0.9,
        font_size=COLOUR_BAR_FONT_SIZE
    )

    # The colour bar spans [vmin, vmax] of `colour_norm_object`, i.e. valid
    # times, and `get_ticks` returns tick locations in those data units, so
    # the ticks are already Unix times and can be labelled directly.
    tick_times_unix_sec = numpy.round(
        colour_bar_object.get_ticks()
    ).astype(int)
    tick_time_strings = [
        time_conversion.unix_sec_to_string(t, COLOUR_BAR_TIME_FORMAT)
        for t in tick_times_unix_sec
    ]

    colour_bar_object.set_ticks(tick_times_unix_sec)
    colour_bar_object.set_ticklabels(tick_time_strings)
def _plot_rapruc_one_example(
        full_storm_id_string, storm_time_unix_sec, top_tracking_dir_name,
        latitude_buffer_deg, longitude_buffer_deg, lead_time_seconds,
        field_name_grib1, output_dir_name, rap_file_name=None,
        ruc_file_name=None):
    """Plots RAP or RUC field for one example.

    The map shows `field_name_grib1` as a colour fill and earth-relative wind
    barbs. It also marks the storm's centroid at lead time 0 and its
    centroid extrapolated by `lead_time_seconds`. The domain is the bounding
    box of those two centroids, padded by the latitude/longitude buffers.

    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Valid time.
    :param top_tracking_dir_name: See documentation at top of file.
    :param latitude_buffer_deg: Same.
    :param longitude_buffer_deg: Same.
    :param lead_time_seconds: Same.
    :param field_name_grib1: Same.
    :param output_dir_name: Same.
    :param rap_file_name: Path to file with RAP analysis.
    :param ruc_file_name: [used only if `rap_file_name is None`] Path to file
        with RUC analysis.
    :raises: ValueError: if the tracking file contains no storm object with
        ID `full_storm_id_string`.
    """

    tracking_file_name = tracking_io.find_file(
        top_tracking_dir_name=top_tracking_dir_name,
        tracking_scale_metres2=DUMMY_TRACKING_SCALE_METRES2,
        source_name=tracking_utils.SEGMOTION_NAME,
        valid_time_unix_sec=storm_time_unix_sec,
        spc_date_string=
        time_conversion.time_to_spc_date_string(storm_time_unix_sec),
        raise_error_if_missing=True
    )

    print('Reading data from: "{0:s}"...'.format(tracking_file_name))
    storm_object_table = tracking_io.read_file(tracking_file_name)
    storm_object_table = storm_object_table.loc[
        storm_object_table[tracking_utils.FULL_ID_COLUMN] ==
        full_storm_id_string
    ]

    # Without this check, a missing storm only surfaces later as an opaque
    # IndexError when the centroids are read.
    if storm_object_table.empty:
        raise ValueError(
            'Cannot find storm "{0:s}" in tracking file: "{1:s}"'.format(
                full_storm_id_string, tracking_file_name
            )
        )

    # Target points at lead times 0 and `lead_time_seconds`.  Rows 0 and 1 are
    # read below as the original and extrapolated centroids.
    extrap_times_sec = numpy.array([0, lead_time_seconds], dtype=int)
    storm_object_table = soundings._create_target_points_for_interp(
        storm_object_table=storm_object_table,
        lead_times_seconds=extrap_times_sec
    )

    centroid_latitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LATITUDE_COLUMN].values
    centroid_longitudes_deg = storm_object_table[
        tracking_utils.CENTROID_LONGITUDE_COLUMN].values

    orig_latitude_deg = centroid_latitudes_deg[0]
    orig_longitude_deg = centroid_longitudes_deg[0]
    extrap_latitude_deg = centroid_latitudes_deg[1]
    extrap_longitude_deg = centroid_longitudes_deg[1]

    if rap_file_name is None:
        grib_file_name = ruc_file_name
        model_name = nwp_model_utils.RUC_MODEL_NAME
    else:
        grib_file_name = rap_file_name
        model_name = nwp_model_utils.RAP_MODEL_NAME

    # Grid ID is the second underscore-separated token of the file name.
    pathless_grib_file_name = os.path.split(grib_file_name)[-1]
    grid_name = pathless_grib_file_name.split('_')[1]

    # wgrib/wgrib2 executables live in different places on different hosts.
    if 'casper' in socket.gethostname():
        wgrib_exe_name = '/glade/work/ryanlage/wgrib/wgrib'
        wgrib2_exe_name = '/glade/work/ryanlage/wgrib2/wgrib2/wgrib2'
    else:
        wgrib_exe_name = '/condo/swatwork/ralager/wgrib/wgrib'
        wgrib2_exe_name = '/condo/swatwork/ralager/grib2/wgrib2/wgrib2'

    def _read_grib_field(this_field_name_grib1):
        """Reads one field from `grib_file_name`, announcing the read first."""

        print('Reading field "{0:s}" from: "{1:s}"...'.format(
            this_field_name_grib1, grib_file_name
        ))

        return nwp_model_io.read_field_from_grib_file(
            grib_file_name=grib_file_name,
            field_name_grib1=this_field_name_grib1,
            model_name=model_name, grid_id=grid_name,
            wgrib_exe_name=wgrib_exe_name, wgrib2_exe_name=wgrib2_exe_name
        )

    main_field_matrix = _read_grib_field(field_name_grib1)

    # Winds use the main field's level string, with any "2 m" replaced by
    # "10 m" (presumably so 2-m fields pair with 10-m winds).
    u_wind_name_grib1 = 'UGRD:{0:s}'.format(
        field_name_grib1.split(':')[-1]
    ).replace('2 m', '10 m')
    u_wind_matrix_m_s01 = _read_grib_field(u_wind_name_grib1)

    v_wind_name_grib1 = 'VGRD:{0:s}'.format(
        u_wind_name_grib1.split(':')[-1]
    )
    v_wind_matrix_m_s01 = _read_grib_field(v_wind_name_grib1)

    latitude_matrix_deg, longitude_matrix_deg = (
        nwp_model_utils.get_latlng_grid_point_matrices(
            model_name=model_name, grid_name=grid_name)
    )

    cosine_matrix, sine_matrix = nwp_model_utils.get_wind_rotation_angles(
        latitudes_deg=latitude_matrix_deg, longitudes_deg=longitude_matrix_deg,
        model_name=model_name
    )

    u_wind_matrix_m_s01, v_wind_matrix_m_s01 = (
        nwp_model_utils.rotate_winds_to_earth_relative(
            u_winds_grid_relative_m_s01=u_wind_matrix_m_s01,
            v_winds_grid_relative_m_s01=v_wind_matrix_m_s01,
            rotation_angle_cosines=cosine_matrix,
            rotation_angle_sines=sine_matrix)
    )

    # Plotting domain: bounding box of both centroids, plus buffers.
    min_plot_latitude_deg = (
        min([orig_latitude_deg, extrap_latitude_deg]) - latitude_buffer_deg
    )
    max_plot_latitude_deg = (
        max([orig_latitude_deg, extrap_latitude_deg]) + latitude_buffer_deg
    )
    min_plot_longitude_deg = (
        min([orig_longitude_deg, extrap_longitude_deg]) - longitude_buffer_deg
    )
    max_plot_longitude_deg = (
        max([orig_longitude_deg, extrap_longitude_deg]) + longitude_buffer_deg
    )

    row_limits, column_limits = nwp_plotting.latlng_limits_to_rowcol_limits(
        min_latitude_deg=min_plot_latitude_deg,
        max_latitude_deg=max_plot_latitude_deg,
        min_longitude_deg=min_plot_longitude_deg,
        max_longitude_deg=max_plot_longitude_deg,
        model_name=model_name, grid_id=grid_name
    )

    # Row/column limits are inclusive, hence the +1 on the upper bounds.
    row_slice = slice(row_limits[0], row_limits[1] + 1)
    column_slice = slice(column_limits[0], column_limits[1] + 1)
    main_field_matrix = main_field_matrix[row_slice, column_slice]
    u_wind_matrix_m_s01 = u_wind_matrix_m_s01[row_slice, column_slice]
    v_wind_matrix_m_s01 = v_wind_matrix_m_s01[row_slice, column_slice]

    _, axes_object, basemap_object = nwp_plotting.init_basemap(
        model_name=model_name, grid_id=grid_name,
        first_row_in_full_grid=row_limits[0],
        last_row_in_full_grid=row_limits[1],
        first_column_in_full_grid=column_limits[0],
        last_column_in_full_grid=column_limits[1]
    )

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_states_and_provinces(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR
    )
    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS
    )
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS
    )

    # Colour limits are symmetric percentiles of the cropped field.
    min_colour_value = numpy.nanpercentile(
        main_field_matrix, 100. - MAX_COLOUR_PERCENTILE
    )
    max_colour_value = numpy.nanpercentile(
        main_field_matrix, MAX_COLOUR_PERCENTILE
    )

    nwp_plotting.plot_subgrid(
        field_matrix=main_field_matrix,
        model_name=model_name, grid_id=grid_name,
        axes_object=axes_object, basemap_object=basemap_object,
        colour_map_object=COLOUR_MAP_OBJECT, min_colour_value=min_colour_value,
        max_colour_value=max_colour_value,
        first_row_in_full_grid=row_limits[0],
        first_column_in_full_grid=column_limits[0]
    )

    nwp_plotting.plot_wind_barbs_on_subgrid(
        u_wind_matrix_m_s01=u_wind_matrix_m_s01,
        v_wind_matrix_m_s01=v_wind_matrix_m_s01,
        model_name=model_name, grid_id=grid_name,
        axes_object=axes_object, basemap_object=basemap_object,
        first_row_in_full_grid=row_limits[0],
        first_column_in_full_grid=column_limits[0],
        plot_every_k_rows=PLOT_EVERY_KTH_WIND_BARB,
        plot_every_k_columns=PLOT_EVERY_KTH_WIND_BARB,
        barb_length=WIND_BARB_LENGTH, empty_barb_radius=EMPTY_WIND_BARB_RADIUS,
        fill_empty_barb=True, colour_map=WIND_COLOUR_MAP_OBJECT,
        colour_minimum_kt=MIN_WIND_SPEED_KT, colour_maximum_kt=MAX_WIND_SPEED_KT
    )

    orig_x_metres, orig_y_metres = basemap_object(
        orig_longitude_deg, orig_latitude_deg
    )
    axes_object.plot(
        orig_x_metres, orig_y_metres, linestyle='None',
        marker=ORIGIN_MARKER_TYPE, markersize=ORIGIN_MARKER_SIZE,
        markeredgewidth=ORIGIN_MARKER_EDGE_WIDTH,
        markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR
    )

    extrap_x_metres, extrap_y_metres = basemap_object(
        extrap_longitude_deg, extrap_latitude_deg
    )
    axes_object.plot(
        extrap_x_metres, extrap_y_metres, linestyle='None',
        marker=EXTRAP_MARKER_TYPE, markersize=EXTRAP_MARKER_SIZE,
        markeredgewidth=EXTRAP_MARKER_EDGE_WIDTH,
        markerfacecolor=MARKER_COLOUR, markeredgecolor=MARKER_COLOUR
    )

    plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=main_field_matrix,
        colour_map_object=COLOUR_MAP_OBJECT,
        min_value=min_colour_value, max_value=max_colour_value,
        orientation_string='vertical'
    )

    output_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
        output_dir_name, full_storm_id_string.replace('_', '-'),
        time_conversion.unix_sec_to_string(
            storm_time_unix_sec, FILE_NAME_TIME_FORMAT
        )
    )

    # Save and close this map's own figure, not whatever pyplot considers
    # "current".
    figure_object = axes_object.get_figure()
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_one_example(
        input_feature_matrix, feature_matrix_after_conv,
        feature_matrix_after_activn, feature_matrix_after_bn,
        feature_matrix_after_pooling, output_file_name):
    """Plots entire figure for one example (storm object).

    The figure has one row per output channel and one column per stage, in
    this order:
        - input (only its first channel is shown);
        - after convolution;
        - after activation;
        - after batch normalization;
        - after pooling.
    Convolution and activation share one colour scale, and every other column
    has its own. Panels are lettered (a), (b), ... column by column.

    :param input_feature_matrix: numpy array with input features.  Slicing the
        last axis (`[..., k]`) must give a 2-D feature map.
    :param feature_matrix_after_conv: numpy array with features after
        convolution; last axis is the output channel.
    :param feature_matrix_after_activn: Same but after activation.
    :param feature_matrix_after_bn: Same but after batch normalization.
    :param feature_matrix_after_pooling: Same but after pooling.
    :param output_file_name: Path to output file.  Figure will be saved here.
    """

    num_output_channels = feature_matrix_after_conv.shape[-1]

    figure_object, axes_object_matrix = plotting_utils.create_paneled_figure(
        num_rows=num_output_channels, num_columns=NUM_PANEL_COLUMNS,
        horizontal_spacing=0., vertical_spacing=0.,
        shared_x_axis=False, shared_y_axis=False, keep_aspect_ratio=True)

    def _label_panel(axes_object, letter_label):
        """Writes "(letter)" on one panel at normalized coords (0.125, 1.025)."""

        plotting_utils.label_axes(
            axes_object=axes_object,
            label_string='({0:s})'.format(letter_label),
            font_size=TITLE_FONT_SIZE,
            x_coord_normalized=0.125, y_coord_normalized=1.025
        )

    def _plot_stage_column(feature_matrix, column_index, title_string,
                           max_colour_value, letter_label):
        """Plots every channel of one stage down one column.

        :return: letter_label: Last letter used, so lettering can continue.
        """

        axes_object_matrix[0, column_index].set_title(
            title_string, fontsize=TITLE_FONT_SIZE)

        for k in range(num_output_channels):
            _plot_one_feature_map(
                feature_matrix_2d=feature_matrix[..., k],
                max_colour_value=max_colour_value,
                # Only the bottom panel of each column gets a colour bar.
                plot_colour_bar=k == num_output_channels - 1,
                axes_object=axes_object_matrix[k, column_index]
            )

            letter_label = chr(ord(letter_label) + 1)
            _label_panel(axes_object_matrix[k, column_index], letter_label)

        return letter_label

    # Input column: only the first channel is plotted; other panels blank.
    max_colour_value = numpy.percentile(
        numpy.absolute(input_feature_matrix), MAX_COLOUR_PERCENTILE
    )

    axes_object_matrix[0, 0].set_title('Input', fontsize=TITLE_FONT_SIZE)
    _plot_one_feature_map(
        feature_matrix_2d=input_feature_matrix[..., 0],
        max_colour_value=max_colour_value, plot_colour_bar=False,
        axes_object=axes_object_matrix[0, 0]
    )
    for k in range(1, num_output_channels):
        axes_object_matrix[k, 0].axis('off')

    # The input colour bar hangs off the bottom panel of the input column.
    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix[num_output_channels - 1, 0],
        data_matrix=input_feature_matrix[..., 0],
        colour_map_object=COLOUR_MAP_OBJECT,
        min_value=-1 * max_colour_value, max_value=max_colour_value,
        orientation_string='horizontal', padding=0.015,
        fraction_of_axis_length=0.9, extend_min=True, extend_max=True,
        font_size=DEFAULT_FONT_SIZE)

    tick_values = colour_bar_object.ax.get_xticks()
    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(
        ['{0:.1f}'.format(x) for x in tick_values]
    )

    letter_label = 'a'
    _label_panel(axes_object_matrix[0, 0], letter_label)

    # Convolution and activation share one colour scale.
    max_colour_value = numpy.percentile(
        numpy.absolute(numpy.stack(
            (feature_matrix_after_conv, feature_matrix_after_activn), axis=0
        )),
        MAX_COLOUR_PERCENTILE
    )
    letter_label = _plot_stage_column(
        feature_matrix_after_conv, 1, ' After convolution', max_colour_value,
        letter_label)
    letter_label = _plot_stage_column(
        feature_matrix_after_activn, 2, ' After activation', max_colour_value,
        letter_label)

    max_colour_value = numpy.percentile(
        numpy.absolute(feature_matrix_after_bn), MAX_COLOUR_PERCENTILE
    )
    letter_label = _plot_stage_column(
        feature_matrix_after_bn, 3, ' After batch norm', max_colour_value,
        letter_label)

    max_colour_value = numpy.percentile(
        numpy.absolute(feature_matrix_after_pooling), MAX_COLOUR_PERCENTILE
    )
    _plot_stage_column(
        feature_matrix_after_pooling, 4, 'After pooling', max_colour_value,
        letter_label)

    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_gradcam_one_example(
        gradcam_dict, example_index, model_metadata_dict, colour_map_object,
        max_colour_percentile, output_dir_name):
    """Saves one class-activation-map figure per target variable for one example.

    Each figure is an image with predictor height on the x-axis and target
    height on the y-axis, plus a dashed diagonal reference line.  Only every
    fourth height gets a tick label.

    :param gradcam_dict: Dictionary read by `gradcam.read_all_targets_file`.
    :param example_index: Array index of the example whose maps are plotted.
    :param model_metadata_dict: Dictionary read by `neural_net.read_metafile`.
    :param colour_map_object: See documentation at top of file.
    :param max_colour_percentile: Same.
    :param output_dir_name: Same.
    """

    option_dict = model_metadata_dict[neural_net.TRAINING_OPTIONS_KEY]
    target_names = option_dict[neural_net.VECTOR_TARGET_NAMES_KEY]
    verbose_target_names = [TARGET_NAME_TO_VERBOSE[n] for n in target_names]

    all_height_labels = profile_plotting.create_height_labels(
        tick_values_km_agl=METRES_TO_KM * option_dict[neural_net.HEIGHTS_KEY],
        use_log_scale=False
    )

    # Keep every 4th label and blank out the rest to avoid crowding.
    height_labels = [
        this_label if k % 4 == 0 else ' '
        for k, this_label in enumerate(all_height_labels)
    ]

    example_id_string = gradcam_dict[gradcam.EXAMPLE_IDS_KEY][example_index]
    example_cam_matrix = (
        gradcam_dict[gradcam.CLASS_ACTIVATIONS_KEY][example_index, ...]
    )

    # One tick per height, at the image's pixel centres; same for all targets.
    num_heights = len(height_labels)
    height_tick_positions = numpy.linspace(
        0, num_heights - 1, num=num_heights, dtype=float
    )

    for k, (target_name, verbose_target_name) in enumerate(
            zip(target_names, verbose_target_names)):
        target_cam_matrix = example_cam_matrix[..., k]

        # Floor keeps the colour scale valid when the map is (nearly) zero.
        max_colour_value = numpy.maximum(
            numpy.percentile(target_cam_matrix, max_colour_percentile), 0.001
        )

        figure_object, axes_object = pyplot.subplots(
            1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
        )

        axes_object.imshow(
            numpy.transpose(target_cam_matrix), cmap=colour_map_object,
            vmin=0., vmax=max_colour_value, origin='lower'
        )

        axes_object.set_xticks(height_tick_positions)
        axes_object.set_yticks(height_tick_positions)
        axes_object.set_xticklabels(
            height_labels, fontsize=TICK_LABEL_FONT_SIZE, rotation=90.
        )
        axes_object.set_yticklabels(
            height_labels, fontsize=TICK_LABEL_FONT_SIZE
        )
        axes_object.set_xlabel('Predictor height (km AGL)')
        axes_object.set_ylabel('Target height (km AGL)')

        # Diagonal from corner to corner of the current axis limits.
        axes_object.plot(
            axes_object.get_xlim(), axes_object.get_ylim(),
            color=REFERENCE_LINE_COLOUR, linestyle='dashed',
            linewidth=REFERENCE_LINE_WIDTH
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object, data_matrix=target_cam_matrix,
            colour_map_object=colour_map_object, min_value=0.,
            max_value=max_colour_value, orientation_string='horizontal',
            padding=0.1, extend_min=False, extend_max=True,
            fraction_of_axis_length=0.8, font_size=DEFAULT_FONT_SIZE
        )

        colour_bar_ticks = colour_bar_object.get_ticks()
        colour_bar_object.set_ticks(colour_bar_ticks)
        colour_bar_object.set_ticklabels(
            ['{0:.1f}'.format(v) for v in colour_bar_ticks]
        )

        axes_object.set_title(
            'Class-activation map for {0:s}'.format(verbose_target_name),
            fontsize=DEFAULT_FONT_SIZE
        )

        output_file_name = '{0:s}/{1:s}_{2:s}.jpg'.format(
            output_dir_name, example_id_string.replace('_', '-'),
            target_name.replace('_', '-')
        )

        print('Saving figure to: "{0:s}"...'.format(output_file_name))
        figure_object.savefig(
            output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_object)
def _plot_2d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Draws a class-activation map over the 2-D radar panels and saves it.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    F = number of radar fields

    Unguided CAM (`cam_matrix`) is contoured in log10 space and repeated into
    every channel panel.  Guided CAM (`guided_cam_matrix`) is contoured as is.
    If `model_metadata_dict[cnn.CONV_2D3D_KEY]` is true, the second figure is
    used and the output file is tagged with radar field "shear".  Otherwise
    the first figure is used.

    When plotting a composite rather than a single example (storm object),
    `full_storm_id_string` and `storm_time_unix_sec` may both be None; the
    output file is then named in PMM mode.

    :param colour_map_object: See doc for `_plot_3d_radar_cam`.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_2d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: See doc for `_plot_3d_radar_cam`.
    :param output_dir_name: Same.
    :param cam_matrix: M-by-N numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    figure_index = 1 if conv_2d3d else 0
    radar_field_name = 'shear' if conv_2d3d else None

    # Number of channel panels: one per layer operation if any, otherwise one
    # per radar field.
    layer_operation_dicts = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    if layer_operation_dicts is not None:
        num_channels = len(layer_operation_dicts)
    else:
        num_channels = len(
            model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY][
                trainval_io.RADAR_FIELDS_KEY]
        )

    log10_min_unguided = numpy.log10(min_unguided_value)
    log10_max_unguided = numpy.log10(max_unguided_value)
    log10_contour_interval = (
        (log10_max_unguided - log10_min_unguided) /
        (num_unguided_contours - 1)
    )

    axes_object_matrix = axes_object_matrices[figure_index]

    if cam_matrix is not None:
        # Same unguided map for every channel: add a channel axis, then tile.
        cam_matrix_log10 = numpy.repeat(
            numpy.log10(numpy.expand_dims(cam_matrix, axis=-1)),
            repeats=num_channels, axis=-1
        )

        cam_plotting.plot_many_2d_grids(
            class_activation_matrix_3d=numpy.flip(cam_matrix_log10, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            min_contour_level=log10_min_unguided,
            max_contour_level=log10_max_unguided,
            contour_interval=log10_contour_interval, row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=cam_matrix_log10,
            colour_map_object=colour_map_object,
            min_value=log10_min_unguided, max_value=log10_max_unguided,
            orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_length / (1 + int(conv_2d3d)),
            extend_min=True, extend_max=True, font_size=COLOUR_BAR_FONT_SIZE
        )

        # Ticks are in log10 space; label them with the linear values.
        tick_values = colour_bar_object.get_ticks()
        colour_bar_object.set_ticks(tick_values)
        colour_bar_object.set_ticklabels(
            ['{0:.2f}'.format(10 ** v)[:4] for v in tick_values]
        )

        colour_bar_label = 'Class activation'
    else:
        saliency_plotting.plot_many_2d_grids_with_contours(
            saliency_matrix_3d=numpy.flip(guided_cam_matrix, axis=0),
            axes_object_matrix=axes_object_matrix,
            colour_map_object=colour_map_object,
            max_absolute_contour_level=max_guided_value,
            contour_interval=max_guided_value / half_num_guided_contours,
            row_major=False
        )

        colour_bar_object = plotting_utils.plot_linear_colour_bar(
            axes_object_or_matrix=axes_object_matrix,
            data_matrix=guided_cam_matrix,
            colour_map_object=colour_map_object,
            min_value=0., max_value=max_guided_value,
            orientation_string='horizontal',
            fraction_of_axis_length=colour_bar_length / (1 + int(conv_2d3d)),
            extend_min=False, extend_max=True, font_size=COLOUR_BAR_FONT_SIZE
        )

        colour_bar_label = 'Absolute guided class activation'

    if label_colour_bars:
        colour_bar_object.set_label(
            colour_bar_label, fontsize=COLOUR_BAR_FONT_SIZE
        )

    output_file_name = plot_examples.metadata_to_file_name(
        output_dir_name=output_dir_name, is_sounding=False, pmm_flag=pmm_flag,
        full_storm_id_string=full_storm_id_string,
        storm_time_unix_sec=storm_time_unix_sec,
        radar_field_name=radar_field_name
    )

    figure_object = figure_objects[figure_index]
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(
        output_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
        bbox_inches='tight'
    )
    pyplot.close(figure_object)
def _plot_3d_radar_cam(
        colour_map_object, min_unguided_value, max_unguided_value,
        num_unguided_contours, max_guided_value, half_num_guided_contours,
        label_colour_bars, colour_bar_length, figure_objects,
        axes_object_matrices, model_metadata_dict, output_dir_name,
        cam_matrix=None, guided_cam_matrix=None, full_storm_id_string=None,
        storm_time_unix_sec=None):
    """Plots class-activation map for 3-D radar data.

    One figure is saved per radar field.  If
    `model_metadata_dict[cnn.CONV_2D3D_KEY]` is true, only the first figure is
    saved, tagged "reflectivity".  Unguided CAM is contoured in log10 space,
    and guided CAM is contoured as is.

    M = number of rows in spatial grid
    N = number of columns in spatial grid
    H = number of heights in spatial grid
    F = number of radar fields

    If this method is plotting a composite rather than single example (storm
    object), `full_storm_id_string` and `storm_time_unix_sec` can be None.

    :param colour_map_object: See documentation at top of file.
    :param min_unguided_value: Same.
    :param max_unguided_value: Same.
    :param num_unguided_contours: Same.
    :param max_guided_value: Same.
    :param half_num_guided_contours: Same.
    :param label_colour_bars: Same.
    :param colour_bar_length: Same.
    :param figure_objects: See doc for
        `plot_input_examples._plot_3d_radar_scan`.
    :param axes_object_matrices: Same.
    :param model_metadata_dict: Dictionary returned by
        `cnn.read_model_metadata`.
    :param output_dir_name: Path to output directory.  Figure(s) will be saved
        here.
    :param cam_matrix: M-by-N-by-H numpy array of unguided class activations.
    :param guided_cam_matrix: [used only if `cam_matrix is None`]
        M-by-N-by-H-by-F numpy array of guided class activations.
    :param full_storm_id_string: Full storm ID.
    :param storm_time_unix_sec: Storm time.
    """

    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None
    conv_2d3d = model_metadata_dict[cnn.CONV_2D3D_KEY]

    if conv_2d3d:
        num_figures_to_plot = 1
        radar_field_names = ['reflectivity']
    else:
        num_figures_to_plot = len(figure_objects)
        training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
        radar_field_names = training_option_dict[trainval_io.RADAR_FIELDS_KEY]

    min_unguided_value_log10 = numpy.log10(min_unguided_value)
    max_unguided_value_log10 = numpy.log10(max_unguided_value)
    contour_interval_log10 = (
        (max_unguided_value_log10 - min_unguided_value_log10) /
        (num_unguided_contours - 1)
    )

    # The unguided map does not depend on the radar field, so its log10 (and
    # flipped copy) is computed once rather than once per figure.
    if cam_matrix is None:
        cam_matrix_log10 = None
        flipped_cam_matrix_log10 = None
    else:
        cam_matrix_log10 = numpy.log10(cam_matrix)
        flipped_cam_matrix_log10 = numpy.flip(cam_matrix_log10, axis=0)

    for j in range(num_figures_to_plot):
        if cam_matrix is None:
            saliency_plotting.plot_many_2d_grids_with_contours(
                saliency_matrix_3d=numpy.flip(
                    guided_cam_matrix[..., j], axis=0
                ),
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                max_absolute_contour_level=max_guided_value,
                contour_interval=max_guided_value / half_num_guided_contours
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=guided_cam_matrix[..., j],
                colour_map_object=colour_map_object, min_value=0.,
                max_value=max_guided_value, orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=False, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Absolute guided class activation',
                    fontsize=COLOUR_BAR_FONT_SIZE
                )
        else:
            cam_plotting.plot_many_2d_grids(
                class_activation_matrix_3d=flipped_cam_matrix_log10,
                axes_object_matrix=axes_object_matrices[j],
                colour_map_object=colour_map_object,
                min_contour_level=min_unguided_value_log10,
                max_contour_level=max_unguided_value_log10,
                contour_interval=contour_interval_log10
            )

            this_colour_bar_object = plotting_utils.plot_linear_colour_bar(
                axes_object_or_matrix=axes_object_matrices[j],
                data_matrix=cam_matrix_log10,
                colour_map_object=colour_map_object,
                min_value=min_unguided_value_log10,
                max_value=max_unguided_value_log10,
                orientation_string='horizontal',
                fraction_of_axis_length=colour_bar_length,
                extend_min=True, extend_max=True,
                font_size=COLOUR_BAR_FONT_SIZE
            )

            # Ticks are in log10 space; label them with linear values.
            # NOTE(review): truncating to 4 characters shows e.g. 0.003 as
            # "0.00" and 100 as "100." -- confirm this is intended.
            these_tick_values = this_colour_bar_object.get_ticks()
            these_tick_strings = [
                '{0:.2f}'.format(10 ** v)[:4] for v in these_tick_values
            ]
            this_colour_bar_object.set_ticks(these_tick_values)
            this_colour_bar_object.set_ticklabels(these_tick_strings)

            if label_colour_bars:
                this_colour_bar_object.set_label(
                    'Class activation', fontsize=COLOUR_BAR_FONT_SIZE
                )

        this_file_name = plot_examples.metadata_to_file_name(
            output_dir_name=output_dir_name, is_sounding=False,
            pmm_flag=pmm_flag, full_storm_id_string=full_storm_id_string,
            storm_time_unix_sec=storm_time_unix_sec,
            radar_field_name=radar_field_names[j]
        )

        print('Saving figure to: "{0:s}"...'.format(this_file_name))
        figure_objects[j].savefig(
            this_file_name, dpi=FIGURE_RESOLUTION_DPI, pad_inches=0,
            bbox_inches='tight'
        )
        pyplot.close(figure_objects[j])
def plot_saliency_for_sounding(
        saliency_matrix, sounding_field_names, pressure_levels_mb,
        colour_map_object, max_absolute_colour_value,
        min_font_size=DEFAULT_MIN_SOUNDING_FONT_SIZE,
        max_font_size=DEFAULT_MAX_SOUNDING_FONT_SIZE):
    """Plots saliency for one sounding.

    P = number of pressure levels
    F = number of fields

    Each non-wind field gets one column of symbols: "+" for non-negative
    saliency and "_" for negative saliency. Colour and font size come from
    `_saliency_to_colour_and_size`. If both u- and v-wind are present, the
    fields in `WIND_COMPONENT_NAMES` are replaced by a single wind column.
    In that column, each pressure level has a fixed-length barb pointing along
    the (u, v) saliency vector and coloured by its magnitude.

    :param saliency_matrix: P-by-F numpy array of saliency values.
    :param sounding_field_names: length-F list of field names.
    :param pressure_levels_mb: length-P list of pressure levels (millibars).
    :param colour_map_object: See doc for `plot_2d_grid`.
    :param max_absolute_colour_value: Same.
    :param min_font_size: Same.
    :param max_font_size: Same.
    """

    error_checking.assert_is_geq(max_absolute_colour_value, 0.)
    max_absolute_colour_value = max([max_absolute_colour_value, 0.001])

    error_checking.assert_is_greater_numpy_array(pressure_levels_mb, 0.)
    error_checking.assert_is_numpy_array(pressure_levels_mb, num_dimensions=1)

    error_checking.assert_is_list(sounding_field_names)
    error_checking.assert_is_numpy_array(
        numpy.array(sounding_field_names), num_dimensions=1)

    num_pressure_levels = len(pressure_levels_mb)
    num_sounding_fields = len(sounding_field_names)

    error_checking.assert_is_numpy_array_without_nan(saliency_matrix)
    error_checking.assert_is_numpy_array(
        saliency_matrix,
        exact_dimensions=numpy.array(
            [num_pressure_levels, num_sounding_fields])
    )

    try:
        u_wind_index = sounding_field_names.index(soundings.U_WIND_NAME)
        v_wind_index = sounding_field_names.index(soundings.V_WIND_NAME)
        plot_wind_barbs = True
    except ValueError:
        plot_wind_barbs = False

    if plot_wind_barbs:
        u_wind_saliency_values = saliency_matrix[:, u_wind_index]
        v_wind_saliency_values = saliency_matrix[:, v_wind_index]
        wind_saliency_magnitudes = numpy.sqrt(
            u_wind_saliency_values ** 2 + v_wind_saliency_values ** 2)

        colour_norm_object = pyplot.Normalize(
            vmin=0., vmax=max_absolute_colour_value)

        # Colour map returns RGBA; drop the alpha channel.
        rgb_matrix_for_wind = colour_map_object(
            colour_norm_object(wind_saliency_magnitudes)
        )[..., :-1]

        # Replace the wind-component columns with one combined wind column.
        non_wind_flags = numpy.array(
            [f not in WIND_COMPONENT_NAMES for f in sounding_field_names],
            dtype=bool)
        non_wind_indices = numpy.where(non_wind_flags)[0]

        saliency_matrix = saliency_matrix[:, non_wind_indices]
        sounding_field_names = [
            sounding_field_names[k] for k in non_wind_indices
        ]
        sounding_field_names.append(WIND_NAME)
        num_sounding_fields = len(sounding_field_names)

    rgb_matrix, font_size_matrix = _saliency_to_colour_and_size(
        saliency_matrix=saliency_matrix, colour_map_object=colour_map_object,
        max_absolute_colour_value=max_absolute_colour_value,
        min_font_size=min_font_size, max_font_size=max_font_size)

    _, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES))
    axes_object.set_facecolor(
        plotting_utils.colour_from_numpy_to_tuple(
            SOUNDING_SALIENCY_BACKGROUND_COLOUR))

    for k in range(num_sounding_fields):
        if sounding_field_names[k] == WIND_NAME:
            for j in range(num_pressure_levels):
                this_vector = numpy.array(
                    [u_wind_saliency_values[j], v_wind_saliency_values[j]])
                this_norm = numpy.linalg.norm(this_vector, ord=2)

                # Barb length is fixed (only direction is kept) and colour
                # encodes magnitude. A zero vector has no direction, so leave
                # it as zero instead of dividing by zero (which gave NaN).
                if this_norm > 0:
                    this_vector = (
                        WIND_SALIENCY_MULTIPLIER * this_vector / this_norm
                    )

                this_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
                    rgb_matrix_for_wind[j, ...])

                axes_object.barbs(
                    k, pressure_levels_mb[j], this_vector[0], this_vector[1],
                    length=WIND_BARB_LENGTH, fill_empty=True, rounding=False,
                    sizes={'emptybarb': EMPTY_WIND_BARB_RADIUS},
                    color=this_colour_tuple)

            continue

        for j in range(num_pressure_levels):
            this_colour_tuple = plotting_utils.colour_from_numpy_to_tuple(
                rgb_matrix[j, k, ...])

            if saliency_matrix[j, k] >= 0:
                axes_object.text(
                    k, pressure_levels_mb[j], '+',
                    fontsize=font_size_matrix[j, k], color=this_colour_tuple,
                    horizontalalignment='center', verticalalignment='center')
            else:
                # Bottom-aligned, presumably so the low-sitting underscore
                # glyph lands on the pressure level.
                axes_object.text(
                    k, pressure_levels_mb[j], '_',
                    fontsize=font_size_matrix[j, k], color=this_colour_tuple,
                    horizontalalignment='center', verticalalignment='bottom')

    # Axis formatting goes through `axes_object` explicitly rather than
    # pyplot's implicit "current axes".
    axes_object.set_xlim(-0.5, num_sounding_fields - 0.5)
    axes_object.set_ylim(100, 1000)
    axes_object.invert_yaxis()
    axes_object.set_yscale('log')
    axes_object.minorticks_off()

    y_tick_locations = numpy.linspace(100, 1000, num=10, dtype=int)
    axes_object.set_yticks(y_tick_locations)
    axes_object.set_yticklabels(
        ['{0:d}'.format(p) for p in y_tick_locations])

    x_tick_locations = numpy.linspace(
        0, num_sounding_fields - 1, num=num_sounding_fields, dtype=float)
    axes_object.set_xticks(x_tick_locations)
    axes_object.set_xticklabels(
        [FIELD_NAME_TO_LATEX_DICT[f] for f in sounding_field_names])

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=saliency_matrix,
        colour_map_object=colour_map_object, min_value=0.,
        max_value=max_absolute_colour_value, orientation_string='vertical',
        extend_min=True, extend_max=True)

    colour_bar_object.set_label('Saliency (absolute value)')
def _plot_data(num_days_matrix, grid_metadata_dict, colour_map_object):
    """Plots, for each grid cell, the number of convective days in domain.

    M = number of rows in grid
    N = number of columns in grid

    Grid cells with a count of zero are masked out and therefore not coloured.

    :param num_days_matrix: M-by-N numpy array with number of convective days
        for which grid cell is in domain.
    :param grid_metadata_dict: Dictionary created by
        `grids.create_equidistant_grid`.
    :param colour_map_object: See documentation at top of file.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    figure_object, axes_object = pyplot.subplots(
        1, 1, figsize=(FIGURE_WIDTH_INCHES, FIGURE_HEIGHT_INCHES)
    )
    basemap_object, x_matrix_metres, y_matrix_metres = _get_basemap(
        grid_metadata_dict
    )

    num_rows = num_days_matrix.shape[0]
    num_columns = num_days_matrix.shape[1]

    # Spacing is inferred from the span between the first and last grid points
    # along each axis, so a regular grid with >= 2 rows and columns is assumed.
    x_min_metres = x_matrix_metres[0, 0]
    y_min_metres = y_matrix_metres[0, 0]
    x_spacing_metres = (
        (x_matrix_metres[0, -1] - x_min_metres) / (num_columns - 1)
    )
    y_spacing_metres = (
        (y_matrix_metres[-1, 0] - y_min_metres) / (num_rows - 1)
    )

    edge_matrix, edge_x_coords_metres, edge_y_coords_metres = (
        grids.xy_field_grid_points_to_edges(
            field_matrix=num_days_matrix,
            x_min_metres=x_min_metres, y_min_metres=y_min_metres,
            x_spacing_metres=x_spacing_metres,
            y_spacing_metres=y_spacing_metres)
    )

    # Zero-count cells are masked so that pcolormesh leaves them unfilled.
    edge_matrix = numpy.ma.masked_where(edge_matrix == 0, edge_matrix)

    for plot_borders in (plotting_utils.plot_coastlines,
                         plotting_utils.plot_countries,
                         plotting_utils.plot_states_and_provinces):
        plot_borders(basemap_object=basemap_object, axes_object=axes_object,
                     line_colour=BORDER_COLOUR)

    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS)
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS)

    max_num_days = numpy.max(num_days_matrix)

    # The very low z-order draws the gridded field underneath the other
    # artists on these axes (borders, parallels, meridians).
    basemap_object.pcolormesh(
        edge_x_coords_metres, edge_y_coords_metres, edge_matrix,
        cmap=colour_map_object, vmin=1, vmax=max_num_days, shading='flat',
        edgecolors='None', axes=axes_object, zorder=-1e12)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=num_days_matrix,
        colour_map_object=colour_map_object, min_value=1,
        max_value=max_num_days, orientation_string='horizontal',
        extend_min=False, extend_max=False, padding=0.05)

    # Counts are whole days, so ticks are labelled as rounded integers.
    tick_values = colour_bar_object.get_ticks()
    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(
        ['{0:d}'.format(int(numpy.round(v))) for v in tick_values]
    )

    axes_object.set_title('Number of convective days by grid cell')
    return figure_object, axes_object
def _plot_and_save_2d_feature_maps(
        feature_matrix_2d, annotation_string_by_channel, num_panel_rows,
        min_colour_value, max_colour_value, font_size, title_string,
        output_file_name):
    """Plots one paneled figure of 2-D feature maps and saves it.

    M = number of spatial rows
    N = number of spatial columns
    C = number of channels

    :param feature_matrix_2d: M-by-N-by-C numpy array of feature maps.
    :param annotation_string_by_channel: length-C list of panel annotations
        (entries may be None or empty strings).
    :param num_panel_rows: Number of panel rows in the figure.
    :param min_colour_value: Minimum value in colour scheme.
    :param max_colour_value: Max value in colour scheme.
    :param font_size: Font size for panel annotations.
    :param title_string: Figure title.
    :param output_file_name: Path to output image file.
    """

    # Rows are flipped before plotting only; the colour bar uses the unflipped
    # data, which has the same values.
    figure_object, axes_object_matrix = (
        feature_map_plotting.plot_many_2d_feature_maps(
            feature_matrix=numpy.flip(feature_matrix_2d, axis=0),
            annotation_string_by_panel=annotation_string_by_channel,
            num_panel_rows=num_panel_rows,
            colour_map_object=pyplot.cm.seismic,
            min_colour_value=min_colour_value,
            max_colour_value=max_colour_value, font_size=font_size)
    )

    plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object_matrix,
        data_matrix=feature_matrix_2d,
        colour_map_object=pyplot.cm.seismic, min_value=min_colour_value,
        max_value=max_colour_value, orientation_string='horizontal',
        extend_min=True, extend_max=True)

    # Use the explicit figure handle rather than pyplot's "current figure", so
    # the right figure is titled, saved and closed.
    figure_object.suptitle(title_string, fontsize=MAIN_FONT_SIZE)
    print('Saving figure to: "{0:s}"...'.format(output_file_name))
    figure_object.savefig(output_file_name, dpi=FIGURE_RESOLUTION_DPI)
    pyplot.close(figure_object)


def _plot_feature_maps_one_layer(feature_matrix, full_id_strings,
                                 storm_times_unix_sec, layer_name,
                                 output_dir_name):
    """Plots all feature maps for one layer.

    E = number of examples (storm objects)
    M = number of spatial rows
    N = number of spatial columns
    H = number of spatial depths (heights)
    C = number of channels

    One figure is saved per storm object for 2-D feature maps, or per storm
    object and height for 3-D feature maps.

    :param feature_matrix: numpy array (E x M x N x C or E x M x N x H x C) of
        feature maps.
    :param full_id_strings: length-E list of full storm IDs.
    :param storm_times_unix_sec: length-E numpy array of storm times.
    :param layer_name: Name of layer.
    :param output_dir_name: Name of output directory for this layer.
    :raises: ValueError: if `feature_matrix` does not have 2 or 3 spatial
        dimensions (i.e., is not 4-D or 5-D).
    """

    num_spatial_dimensions = len(feature_matrix.shape) - 2

    # Fail fast: any other rank would otherwise crash deep inside the loop
    # (e.g. `range(None)` for 1 spatial dimension).
    if num_spatial_dimensions not in (2, 3):
        raise ValueError(
            'feature_matrix must have 2 or 3 spatial dimensions (4 or 5 '
            'dimensions total); got {0:d} dimensions.'.format(
                len(feature_matrix.shape))
        )

    num_storm_objects = feature_matrix.shape[0]
    num_channels = feature_matrix.shape[-1]
    num_panel_rows = int(numpy.round(numpy.sqrt(num_channels)))

    # Annotations and font size shrink as the number of panels grows.
    annotation_string_by_channel = [None] * num_channels

    if num_channels >= NUM_PANELS_FOR_NO_FONT:
        annotation_string_by_channel = [''] * num_channels
        font_size = TINY_FONT_SIZE
    elif num_channels >= NUM_PANELS_FOR_TINY_FONT:
        font_size = TINY_FONT_SIZE
    elif num_channels >= NUM_PANELS_FOR_SMALL_FONT:
        font_size = SMALL_FONT_SIZE
    else:
        font_size = MAIN_FONT_SIZE

    # Colour limits are symmetric about zero and shared by all storm objects.
    max_colour_value = numpy.percentile(numpy.absolute(feature_matrix), 99)
    min_colour_value = -1 * max_colour_value

    for i in range(num_storm_objects):
        this_time_string = time_conversion.unix_sec_to_string(
            storm_times_unix_sec[i], TIME_FORMAT)
        this_id_for_file = full_id_strings[i].replace('_', '-')

        if num_spatial_dimensions == 2:
            this_title_string = 'Layer "{0:s}", storm "{1:s}" at {2:s}'.format(
                layer_name, full_id_strings[i], this_time_string)
            this_file_name = (
                '{0:s}/storm={1:s}_{2:s}_features.jpg'
            ).format(output_dir_name, this_id_for_file, this_time_string)

            _plot_and_save_2d_feature_maps(
                feature_matrix_2d=feature_matrix[i, ...],
                annotation_string_by_channel=annotation_string_by_channel,
                num_panel_rows=num_panel_rows,
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value, font_size=font_size,
                title_string=this_title_string,
                output_file_name=this_file_name)
            continue

        num_heights = feature_matrix.shape[-2]

        for k in range(num_heights):
            this_title_string = (
                'Layer "{0:s}", height {1:d} of {2:d}, storm "{3:s}" at '
                '{4:s}'
            ).format(layer_name, k + 1, num_heights, full_id_strings[i],
                     this_time_string)
            this_file_name = (
                '{0:s}/storm={1:s}_{2:s}_features_height{3:02d}.jpg'
            ).format(output_dir_name, this_id_for_file, this_time_string,
                     k + 1)

            _plot_and_save_2d_feature_maps(
                feature_matrix_2d=feature_matrix[i, :, :, k, :],
                annotation_string_by_channel=annotation_string_by_channel,
                num_panel_rows=num_panel_rows,
                min_colour_value=min_colour_value,
                max_colour_value=max_colour_value, font_size=font_size,
                title_string=this_title_string,
                output_file_name=this_file_name)
def _plot_score_one_field(latitude_matrix_deg, longitude_matrix_deg,
                          score_matrix, colour_map_object, min_colour_value,
                          max_colour_value, taper_cbar_top, taper_cbar_bottom,
                          log_scale=False):
    """Plots one score for one field.

    M = number of rows in grid
    N = number of columns in grid

    :param latitude_matrix_deg: M-by-N numpy array of latitudes (deg N).
    :param longitude_matrix_deg: M-by-N numpy array of longitudes (deg E).
    :param score_matrix: M-by-N numpy array of score values.  NaN values are
        masked out (not coloured).
    :param colour_map_object: Colour scheme (instance of `matplotlib.pyplot.cm`).
    :param min_colour_value: Minimum value in colour bar.
    :param max_colour_value: Max value in colour bar.
    :param taper_cbar_top: Boolean flag.  If True, will taper top of colour bar,
        implying that higher values are possible.
    :param taper_cbar_bottom: Boolean flag.  If True, will taper bottom of
        colour bar, implying that lower values are possible.
    :param log_scale: Boolean flag.  If True, colour-bar tick labels are written
        as 10 ** (tick value), rounded to the nearest integer.  This does not
        change the colour scale itself, so `score_matrix` and the colour limits
        are presumably already in log10 units.
    :return: figure_object: Figure handle (instance of
        `matplotlib.figure.Figure`).
    :return: axes_object: Axes handle (instance of
        `matplotlib.axes._subplots.AxesSubplot`).
    """

    (figure_object, axes_object, basemap_object
    ) = plotting_utils.create_equidist_cylindrical_map(
        min_latitude_deg=latitude_matrix_deg[0, 0],
        max_latitude_deg=latitude_matrix_deg[-1, -1],
        min_longitude_deg=longitude_matrix_deg[0, 0],
        max_longitude_deg=longitude_matrix_deg[-1, -1],
        resolution_string=RESOLUTION_STRING)

    # Spacing comes from the first two rows/columns, so a regular grid with at
    # least 2 rows and 2 columns is assumed.
    latitude_spacing_deg = (
        latitude_matrix_deg[1, 0] - latitude_matrix_deg[0, 0]
    )
    longitude_spacing_deg = (
        longitude_matrix_deg[0, 1] - longitude_matrix_deg[0, 0]
    )

    (score_matrix_at_edges, grid_edge_latitudes_deg,
     grid_edge_longitudes_deg) = grids.latlng_field_grid_points_to_edges(
        field_matrix=score_matrix,
        min_latitude_deg=latitude_matrix_deg[0, 0],
        min_longitude_deg=longitude_matrix_deg[0, 0],
        lat_spacing_deg=latitude_spacing_deg,
        lng_spacing_deg=longitude_spacing_deg)

    score_matrix_at_edges = numpy.ma.masked_where(
        numpy.isnan(score_matrix_at_edges), score_matrix_at_edges)

    plotting_utils.plot_coastlines(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR, line_width=BORDER_WIDTH)
    plotting_utils.plot_countries(
        basemap_object=basemap_object, axes_object=axes_object,
        line_colour=BORDER_COLOUR, line_width=BORDER_WIDTH)
    plotting_utils.plot_parallels(
        basemap_object=basemap_object, axes_object=axes_object,
        num_parallels=NUM_PARALLELS, line_width=0)
    plotting_utils.plot_meridians(
        basemap_object=basemap_object, axes_object=axes_object,
        num_meridians=NUM_MERIDIANS, line_width=0)

    # Draw directly on `axes_object` (not pyplot's current axes).  The very low
    # z-order puts the field underneath the border and grid-line artists.
    axes_object.pcolormesh(
        grid_edge_longitudes_deg, grid_edge_latitudes_deg,
        score_matrix_at_edges, cmap=colour_map_object,
        vmin=min_colour_value, vmax=max_colour_value, shading='flat',
        edgecolors='None', zorder=-1e12)

    colour_bar_object = plotting_utils.plot_linear_colour_bar(
        axes_object_or_matrix=axes_object, data_matrix=score_matrix,
        colour_map_object=colour_map_object, min_value=min_colour_value,
        max_value=max_colour_value, orientation_string='horizontal',
        extend_min=taper_cbar_bottom, extend_max=taper_cbar_top,
        padding=0.05, font_size=COLOUR_BAR_FONT_SIZE)

    tick_values = colour_bar_object.get_ticks()

    if log_scale:
        # Ticks are in log10 units; label them with the un-logged values.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(10 ** v))) for v in tick_values
        ]
    elif numpy.nanmax(numpy.absolute(score_matrix)) >= 6:
        # Large magnitudes: decimals add clutter without information.
        tick_strings = [
            '{0:d}'.format(int(numpy.round(v))) for v in tick_values
        ]
    else:
        tick_strings = ['{0:.2f}'.format(v) for v in tick_values]

    colour_bar_object.set_ticks(tick_values)
    colour_bar_object.set_ticklabels(tick_strings)

    return figure_object, axes_object