Beispiel #1
0
def _pmm_composite_matrix_list(list_of_matrices, max_percentile_level):
    """Runs unthresholded probability-matched means (PMM) on each matrix.

    :param list_of_matrices: List of numpy arrays.  Each is passed unchanged
        as `input_matrix` to `pmm.run_pmm_many_variables`.
    :param max_percentile_level: Passed to `pmm.run_pmm_many_variables`.
    :return: list_of_mean_matrices: List with one element per input matrix,
        namely the first output of `pmm.run_pmm_many_variables`.
    """

    return [
        pmm.run_pmm_many_variables(
            input_matrix=this_matrix,
            max_percentile_level=max_percentile_level)[0]
        for this_matrix in list_of_matrices
    ]


def _run(input_saliency_file_name, input_gradcam_file_name,
         input_bwo_file_name, input_novelty_file_name, max_percentile_level,
         radar_channel_idx_for_thres, threshold_value, threshold_type_string,
         output_file_name):
    """Runs probability-matched means (PMM).

    This is effectively the main method.  Only one input file is used.  If
    several are given, precedence is: saliency, Grad-CAM, backwards
    optimization, novelty detection.  Predictor matrices and the matching
    interpretation output from that file are composited with PMM and written
    to `output_file_name`.

    :param input_saliency_file_name: See documentation at top of file.
    :param input_gradcam_file_name: Same.
    :param input_bwo_file_name: Same.
    :param input_novelty_file_name: Same.
    :param max_percentile_level: Same.
    :param radar_channel_idx_for_thres: Same.  Negative value disables
        thresholding.
    :param threshold_value: Same.
    :param threshold_type_string: Same.
    :param output_file_name: Same.
    :raises: ValueError: if all four input file names are in `NONE_STRINGS`.
    """

    # Keep only the highest-priority input file that was actually specified.
    if input_saliency_file_name not in NONE_STRINGS:
        input_gradcam_file_name = None
        input_bwo_file_name = None
        input_novelty_file_name = None
    elif input_gradcam_file_name not in NONE_STRINGS:
        input_saliency_file_name = None
        input_bwo_file_name = None
        input_novelty_file_name = None
    elif input_bwo_file_name not in NONE_STRINGS:
        input_saliency_file_name = None
        input_gradcam_file_name = None
        input_novelty_file_name = None
    elif input_novelty_file_name not in NONE_STRINGS:
        input_saliency_file_name = None
        input_gradcam_file_name = None
        input_bwo_file_name = None
    else:
        raise ValueError(
            'No input file specified.  At least one of the saliency, '
            'Grad-CAM, backwards-optimization, or novelty-detection files '
            'must be given.')

    # A negative channel index means "no thresholding".
    if radar_channel_idx_for_thres < 0:
        radar_channel_idx_for_thres = None
        threshold_value = None
        threshold_type_string = None

    if input_saliency_file_name is not None:
        print('Reading data from: "{0:s}"...'.format(input_saliency_file_name))

        saliency_dict = saliency_maps.read_standard_file(
            input_saliency_file_name)
        list_of_input_matrices = saliency_dict[
            saliency_maps.INPUT_MATRICES_KEY]

    elif input_gradcam_file_name is not None:
        print('Reading data from: "{0:s}"...'.format(input_gradcam_file_name))

        gradcam_dict = gradcam.read_standard_file(input_gradcam_file_name)
        list_of_input_matrices = gradcam_dict[gradcam.INPUT_MATRICES_KEY]

    elif input_bwo_file_name is not None:
        print('Reading data from: "{0:s}"...'.format(input_bwo_file_name))

        bwo_dictionary = backwards_opt.read_standard_file(input_bwo_file_name)
        list_of_input_matrices = bwo_dictionary[
            backwards_opt.INIT_FUNCTION_KEY]

    else:
        print('Reading data from: "{0:s}"...'.format(input_novelty_file_name))
        novelty_dict = novelty_detection.read_standard_file(
            input_novelty_file_name)

        novel_indices = novelty_dict[novelty_detection.NOVEL_INDICES_KEY]

        # Composite only the trial examples that were flagged as novel.
        list_of_input_matrices = [
            a[novel_indices, ...]
            for a in novelty_dict[novelty_detection.TRIAL_INPUTS_KEY]
        ]

    print('Running PMM on denormalized predictor matrices...')

    list_of_mean_input_matrices = []
    pmm_metadata_dict = None
    threshold_count_matrix = None

    for i, this_input_matrix in enumerate(list_of_input_matrices):
        if i == 0:
            # Thresholding (if enabled) is applied only to the first matrix.
            this_mean_matrix, threshold_count_matrix = (
                pmm.run_pmm_many_variables(
                    input_matrix=this_input_matrix,
                    max_percentile_level=max_percentile_level,
                    threshold_var_index=radar_channel_idx_for_thres,
                    threshold_value=threshold_value,
                    threshold_type_string=threshold_type_string)
            )

            pmm_metadata_dict = pmm.check_input_args(
                input_matrix=this_input_matrix,
                max_percentile_level=max_percentile_level,
                threshold_var_index=radar_channel_idx_for_thres,
                threshold_value=threshold_value,
                threshold_type_string=threshold_type_string)
        else:
            this_mean_matrix = pmm.run_pmm_many_variables(
                input_matrix=this_input_matrix,
                max_percentile_level=max_percentile_level)[0]

        list_of_mean_input_matrices.append(this_mean_matrix)

    if input_saliency_file_name is not None:
        print('Running PMM on saliency matrices...')

        # Assumes one saliency matrix per input matrix (parallel lists).
        list_of_mean_saliency_matrices = _pmm_composite_matrix_list(
            list_of_matrices=saliency_dict[
                saliency_maps.SALIENCY_MATRICES_KEY],
            max_percentile_level=max_percentile_level)

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        saliency_maps.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_saliency_matrices=list_of_mean_saliency_matrices,
            threshold_count_matrix=threshold_count_matrix,
            model_file_name=saliency_dict[saliency_maps.MODEL_FILE_KEY],
            standard_saliency_file_name=input_saliency_file_name,
            pmm_metadata_dict=pmm_metadata_dict)

        return

    if input_gradcam_file_name is not None:
        print('Running PMM on class-activation matrices...')

        list_of_cam_matrices = gradcam_dict[gradcam.CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = gradcam_dict[
            gradcam.GUIDED_CAM_MATRICES_KEY]

        num_input_matrices = len(list_of_input_matrices)
        list_of_mean_cam_matrices = [None] * num_input_matrices
        list_of_mean_guided_cam_matrices = [None] * num_input_matrices

        for i in range(num_input_matrices):
            # Input matrices with no class-activation map stay None.
            if list_of_cam_matrices[i] is None:
                continue

            # Presumably the CAM lacks a channel axis, so a dummy last axis is
            # added for PMM and stripped from the result.
            list_of_mean_cam_matrices[i] = pmm.run_pmm_many_variables(
                input_matrix=numpy.expand_dims(
                    list_of_cam_matrices[i], axis=-1),
                max_percentile_level=max_percentile_level
            )[0][..., 0]

            list_of_mean_guided_cam_matrices[i] = pmm.run_pmm_many_variables(
                input_matrix=list_of_guided_cam_matrices[i],
                max_percentile_level=max_percentile_level)[0]

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        gradcam.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_cam_matrices=list_of_mean_cam_matrices,
            list_of_mean_guided_cam_matrices=list_of_mean_guided_cam_matrices,
            model_file_name=gradcam_dict[gradcam.MODEL_FILE_KEY],
            standard_gradcam_file_name=input_gradcam_file_name,
            pmm_metadata_dict=pmm_metadata_dict)

        return

    if input_bwo_file_name is not None:
        print('Running PMM on backwards-optimization output...')

        # Assumes one optimized matrix per input matrix (parallel lists).
        list_of_mean_optimized_matrices = _pmm_composite_matrix_list(
            list_of_matrices=bwo_dictionary[
                backwards_opt.OPTIMIZED_MATRICES_KEY],
            max_percentile_level=max_percentile_level)

        mean_initial_activation = numpy.mean(
            bwo_dictionary[backwards_opt.INITIAL_ACTIVATIONS_KEY])
        mean_final_activation = numpy.mean(
            bwo_dictionary[backwards_opt.FINAL_ACTIVATIONS_KEY])

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        backwards_opt.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_optimized_matrices=list_of_mean_optimized_matrices,
            mean_initial_activation=mean_initial_activation,
            mean_final_activation=mean_final_activation,
            threshold_count_matrix=threshold_count_matrix,
            model_file_name=bwo_dictionary[backwards_opt.MODEL_FILE_KEY],
            standard_bwo_file_name=input_bwo_file_name,
            pmm_metadata_dict=pmm_metadata_dict)

        return

    print('Running PMM on novelty-detection output...')

    mean_novel_image_matrix_upconv = pmm.run_pmm_many_variables(
        input_matrix=novelty_dict[novelty_detection.NOVEL_IMAGES_UPCONV_KEY],
        max_percentile_level=max_percentile_level)[0]

    mean_novel_image_matrix_upconv_svd = pmm.run_pmm_many_variables(
        input_matrix=novelty_dict[
            novelty_detection.NOVEL_IMAGES_UPCONV_SVD_KEY],
        max_percentile_level=max_percentile_level)[0]

    print('Writing output to: "{0:s}"...'.format(output_file_name))
    novelty_detection.write_pmm_file(
        pickle_file_name=output_file_name,
        mean_novel_image_matrix=list_of_mean_input_matrices[0],
        mean_novel_image_matrix_upconv=mean_novel_image_matrix_upconv,
        mean_novel_image_matrix_upconv_svd=mean_novel_image_matrix_upconv_svd,
        threshold_count_matrix=threshold_count_matrix,
        standard_novelty_file_name=input_novelty_file_name,
        pmm_metadata_dict=pmm_metadata_dict)
Beispiel #2
0
def _run(input_file_name, diff_colour_map_name, max_colour_percentile_for_diff,
         top_output_dir_name):
    """Plots results of backwards optimization.

    This is effectively the main method.  `input_file_name` is first read as
    a standard backwards-optimization file.  If that raises ValueError, it is
    read as a PMM (composite) file instead.  Metadata then come from the
    standard file that the PMM file points to.

    :param input_file_name: See documentation at top of file.
    :param diff_colour_map_name: Same.
    :param max_colour_percentile_for_diff: Same.
    :param top_output_dir_name: Same.
    """

    pmm_flag = False

    error_checking.assert_is_geq(max_colour_percentile_for_diff, 0.)
    error_checking.assert_is_leq(max_colour_percentile_for_diff, 100.)

    # `pyplot.get_cmap` exists in old and new Matplotlib, whereas
    # `pyplot.cm.get_cmap` (= `matplotlib.cm.get_cmap`) was removed in 3.9.
    diff_colour_map_object = pyplot.get_cmap(diff_colour_map_name)

    # Single-argument `print(...)` behaves the same on Python 2 and 3.
    print('Reading data from: "{0:s}"...'.format(input_file_name))

    try:
        backwards_opt_dict = backwards_opt.read_standard_file(input_file_name)
        list_of_optimized_matrices = backwards_opt_dict.pop(
            backwards_opt.OPTIMIZED_MATRICES_KEY)
        list_of_input_matrices = backwards_opt_dict.pop(
            backwards_opt.INIT_FUNCTION_KEY)

        # Anything other than a list (e.g. when the initial matrices were not
        # stored) means there are no input matrices to plot.
        if not isinstance(list_of_input_matrices, list):
            list_of_input_matrices = None

        # After popping the matrices, what remains is metadata.
        bwo_metadata_dict = backwards_opt_dict
        storm_ids = bwo_metadata_dict[backwards_opt.STORM_IDS_KEY]
        storm_times_unix_sec = bwo_metadata_dict[backwards_opt.STORM_TIMES_KEY]

    except ValueError:
        pmm_flag = True
        backwards_opt_dict = backwards_opt.read_pmm_file(input_file_name)

        list_of_input_matrices = backwards_opt_dict.pop(
            backwards_opt.MEAN_INPUT_MATRICES_KEY)
        list_of_optimized_matrices = backwards_opt_dict.pop(
            backwards_opt.MEAN_OPTIMIZED_MATRICES_KEY)

        # Add a leading example axis of length 1 so that PMM composites can be
        # plotted like a single example.
        for i in range(len(list_of_input_matrices)):
            list_of_input_matrices[i] = numpy.expand_dims(
                list_of_input_matrices[i], axis=0)
            list_of_optimized_matrices[i] = numpy.expand_dims(
                list_of_optimized_matrices[i], axis=0)

        original_bwo_file_name = backwards_opt_dict[
            backwards_opt.STANDARD_FILE_NAME_KEY]

        print('Reading metadata from: "{0:s}"...'.format(
            original_bwo_file_name))
        original_bwo_dict = backwards_opt.read_standard_file(
            original_bwo_file_name)

        original_bwo_dict.pop(backwards_opt.OPTIMIZED_MATRICES_KEY)
        original_bwo_dict.pop(backwards_opt.INIT_FUNCTION_KEY)
        bwo_metadata_dict = original_bwo_dict

        storm_ids = None
        storm_times_unix_sec = None

    model_file_name = bwo_metadata_dict[backwards_opt.MODEL_FILE_NAME_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0]
    )

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    sounding_field_names = training_option_dict[trainval_io.SOUNDING_FIELDS_KEY]

    print(SEPARATOR_STRING)

    # Soundings, when used, are the last predictor matrix.
    if sounding_field_names is not None:
        if list_of_input_matrices is None:
            this_input_matrix = None
        else:
            this_input_matrix = list_of_input_matrices[-1]

        _plot_bwo_for_soundings(
            optimized_sounding_matrix=list_of_optimized_matrices[-1],
            training_option_dict=training_option_dict,
            top_output_dir_name=top_output_dir_name, pmm_flag=pmm_flag,
            input_sounding_matrix=this_input_matrix, storm_ids=storm_ids,
            storm_times_unix_sec=storm_times_unix_sec)
        print(SEPARATOR_STRING)

    if model_metadata_dict[cnn.USE_2D3D_CONVOLUTION_KEY]:
        _plot_bwo_for_2d3d_radar(
            list_of_optimized_matrices=list_of_optimized_matrices,
            training_option_dict=training_option_dict,
            diff_colour_map_object=diff_colour_map_object,
            max_colour_percentile_for_diff=max_colour_percentile_for_diff,
            top_output_dir_name=top_output_dir_name, pmm_flag=pmm_flag,
            list_of_input_matrices=list_of_input_matrices,
            storm_ids=storm_ids, storm_times_unix_sec=storm_times_unix_sec)
        return

    if list_of_input_matrices is None:
        this_input_matrix = None
    else:
        this_input_matrix = list_of_input_matrices[0]

    # Dimensions beyond the example and channel axes are spatial.
    num_radar_dimensions = len(list_of_optimized_matrices[0].shape) - 2
    if num_radar_dimensions == 3:
        _plot_bwo_for_3d_radar(
            optimized_radar_matrix=list_of_optimized_matrices[0],
            training_option_dict=training_option_dict,
            diff_colour_map_object=diff_colour_map_object,
            max_colour_percentile_for_diff=max_colour_percentile_for_diff,
            top_output_dir_name=top_output_dir_name, pmm_flag=pmm_flag,
            input_radar_matrix=this_input_matrix,
            storm_ids=storm_ids, storm_times_unix_sec=storm_times_unix_sec)
        return

    _plot_bwo_for_2d_radar(
        optimized_radar_matrix=list_of_optimized_matrices[0],
        model_metadata_dict=model_metadata_dict,
        diff_colour_map_object=diff_colour_map_object,
        max_colour_percentile_for_diff=max_colour_percentile_for_diff,
        top_output_dir_name=top_output_dir_name, pmm_flag=pmm_flag,
        input_radar_matrix=this_input_matrix,
        storm_ids=storm_ids, storm_times_unix_sec=storm_times_unix_sec)
def _run(interpretation_type_string, baseline_file_name, trial_file_name,
         max_pmm_percentile_level, num_iterations, confidence_level,
         output_file_name):
    """Runs Monte Carlo significance test for interpretation output.

    This is effectively the main method.  Baseline and trial files are read
    with the module matching `interpretation_type_string`.  The trial set's
    interpretation matrices are tested against the baseline set's with
    `monte_carlo.run_monte_carlo_test`.  The trial set's input matrices are
    composited with PMM, and everything is written to one PMM file.

    :param interpretation_type_string: See documentation at top of file.
    :param baseline_file_name: Same.
    :param trial_file_name: Same.
    :param max_pmm_percentile_level: Same.
    :param num_iterations: Same.
    :param confidence_level: Same.
    :param output_file_name: Same.
    :raises: ValueError: if
        `interpretation_type_string not in VALID_INTERPRETATION_TYPE_STRINGS`.
    """

    if interpretation_type_string not in VALID_INTERPRETATION_TYPE_STRINGS:
        raise ValueError(
            ('\n{0:s}\nValid interpretation types (listed above) do not '
             'include "{1:s}".').format(
                 str(VALID_INTERPRETATION_TYPE_STRINGS),
                 interpretation_type_string)
        )

    is_saliency = interpretation_type_string == SALIENCY_STRING
    is_gradcam = interpretation_type_string == GRADCAM_STRING

    # Any other valid type is handled as backwards optimization.
    if is_saliency:
        file_io_module = saliency_maps
    elif is_gradcam:
        file_io_module = gradcam
    else:
        file_io_module = backwards_opt

    print('Reading baseline set from: "{0:s}"...'.format(baseline_file_name))
    baseline_dict = file_io_module.read_standard_file(baseline_file_name)

    print('Reading trial set from: "{0:s}"...'.format(trial_file_name))
    trial_dict = file_io_module.read_standard_file(trial_file_name)

    def _test_significance(matrix_key):
        """Runs the Monte Carlo test on one kind of interpretation matrix."""
        return monte_carlo.run_monte_carlo_test(
            list_of_baseline_matrices=baseline_dict[matrix_key],
            list_of_trial_matrices=trial_dict[matrix_key],
            max_pmm_percentile_level=max_pmm_percentile_level,
            num_iterations=num_iterations, confidence_level=confidence_level)

    monte_carlo_dict = None
    cam_monte_carlo_dict = None
    guided_cam_monte_carlo_dict = None

    if is_gradcam:
        # Grad-CAM output has two kinds of map, each tested separately.
        cam_monte_carlo_dict = _test_significance(gradcam.CAM_MATRICES_KEY)
        guided_cam_monte_carlo_dict = _test_significance(
            gradcam.GUIDED_CAM_MATRICES_KEY)

        for this_result_dict in (cam_monte_carlo_dict,
                                 guided_cam_monte_carlo_dict):
            this_result_dict[monte_carlo.BASELINE_FILE_KEY] = (
                baseline_file_name)

        input_matrix_key = gradcam.INPUT_MATRICES_KEY
    else:
        if is_saliency:
            tested_matrix_key = saliency_maps.SALIENCY_MATRICES_KEY
            input_matrix_key = saliency_maps.INPUT_MATRICES_KEY
        else:
            tested_matrix_key = backwards_opt.OPTIMIZED_MATRICES_KEY
            input_matrix_key = backwards_opt.INIT_FUNCTION_KEY

        monte_carlo_dict = _test_significance(tested_matrix_key)
        monte_carlo_dict[monte_carlo.BASELINE_FILE_KEY] = baseline_file_name

    list_of_input_matrices = trial_dict[input_matrix_key]

    print(SEPARATOR_STRING)

    list_of_mean_input_matrices = [
        pmm.run_pmm_many_variables(
            input_matrix=this_input_matrix,
            max_percentile_level=max_pmm_percentile_level
        )[0]
        for this_input_matrix in list_of_input_matrices
    ]

    # No thresholding is used here.
    pmm_metadata_dict = pmm.check_input_args(
        input_matrix=list_of_input_matrices[0],
        max_percentile_level=max_pmm_percentile_level,
        threshold_var_index=None, threshold_value=None,
        threshold_type_string=None)

    trial_pmm_key = monte_carlo.TRIAL_PMM_MATRICES_KEY
    print('Writing results to: "{0:s}"...'.format(output_file_name))

    if is_saliency:
        saliency_maps.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_saliency_matrices=copy.deepcopy(
                monte_carlo_dict[trial_pmm_key]),
            threshold_count_matrix=None,
            model_file_name=trial_dict[saliency_maps.MODEL_FILE_KEY],
            standard_saliency_file_name=trial_file_name,
            pmm_metadata_dict=pmm_metadata_dict,
            monte_carlo_dict=monte_carlo_dict)
    elif is_gradcam:
        gradcam.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_cam_matrices=copy.deepcopy(
                cam_monte_carlo_dict[trial_pmm_key]),
            list_of_mean_guided_cam_matrices=copy.deepcopy(
                guided_cam_monte_carlo_dict[trial_pmm_key]),
            model_file_name=trial_dict[gradcam.MODEL_FILE_KEY],
            standard_gradcam_file_name=trial_file_name,
            pmm_metadata_dict=pmm_metadata_dict,
            cam_monte_carlo_dict=cam_monte_carlo_dict,
            guided_cam_monte_carlo_dict=guided_cam_monte_carlo_dict)
    else:
        backwards_opt.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_optimized_matrices=copy.deepcopy(
                monte_carlo_dict[trial_pmm_key]),
            mean_initial_activation=numpy.mean(
                trial_dict[backwards_opt.INITIAL_ACTIVATIONS_KEY]),
            mean_final_activation=numpy.mean(
                trial_dict[backwards_opt.FINAL_ACTIVATIONS_KEY]),
            threshold_count_matrix=None,
            model_file_name=trial_dict[backwards_opt.MODEL_FILE_KEY],
            standard_bwo_file_name=trial_file_name,
            pmm_metadata_dict=pmm_metadata_dict,
            monte_carlo_dict=monte_carlo_dict)
def _get_bwo_significance_matrix(trial_pmm_matrix, min_matrix, max_matrix,
                                 example_index, num_dimensions_with_examples):
    """Flags grid points where the trial PMM lies outside the Monte Carlo range.

    :param trial_pmm_matrix: numpy array of trial-set PMM values.
    :param min_matrix: numpy array of lower bounds (same shape).
    :param max_matrix: numpy array of upper bounds (same shape).
    :param example_index: Index of the example being plotted.  Used only if
        the three matrices have an example axis.
    :param num_dimensions_with_examples: Number of dimensions in the matching
        predictor matrix, including its example axis.
    :return: significance_matrix: numpy array of Booleans, True where
        `trial_pmm_matrix < min_matrix` or `trial_pmm_matrix > max_matrix`.
    """

    trial_pmm_matrix = numpy.asarray(trial_pmm_matrix)
    min_matrix = numpy.asarray(min_matrix)
    max_matrix = numpy.asarray(max_matrix)

    # The Monte Carlo matrices are presumably PMM composites with no example
    # axis (like the PMM predictor matrices that `_run` must expand_dims).
    # Index by example only when an example axis is actually present.
    # Otherwise `[example_index, ...]` would slice off a spatial row.
    if trial_pmm_matrix.ndim == num_dimensions_with_examples:
        trial_pmm_matrix = trial_pmm_matrix[example_index, ...]
        min_matrix = min_matrix[example_index, ...]
        max_matrix = max_matrix[example_index, ...]

    return numpy.logical_or(
        trial_pmm_matrix < min_matrix, trial_pmm_matrix > max_matrix)


def _run(input_file_name, plot_significance, diff_colour_map_name,
         max_colour_percentile, top_output_dir_name):
    """Plots results of backwards optimization.

    This is effectively the main method.  Predictors before and after
    optimization, and their difference, are plotted into the subdirectories
    "before_optimization", "after_optimization" and "difference" of
    `top_output_dir_name`.

    :param input_file_name: See documentation at top of file.
    :param plot_significance: Same.
    :param diff_colour_map_name: Same.
    :param max_colour_percentile: Same.
    :param top_output_dir_name: Same.
    """

    # Validate arguments before creating any output directories.
    error_checking.assert_is_geq(max_colour_percentile, 0.)
    error_checking.assert_is_leq(max_colour_percentile, 100.)

    # `pyplot.get_cmap` works on all Matplotlib versions, while
    # `pyplot.cm.get_cmap` (= `matplotlib.cm.get_cmap`) was removed in 3.9.
    diff_colour_map_object = pyplot.get_cmap(diff_colour_map_name)

    before_optimization_dir_name = '{0:s}/before_optimization'.format(
        top_output_dir_name)
    after_optimization_dir_name = '{0:s}/after_optimization'.format(
        top_output_dir_name)
    difference_dir_name = '{0:s}/difference'.format(top_output_dir_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=before_optimization_dir_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=after_optimization_dir_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=difference_dir_name)

    print('Reading data from: "{0:s}"...'.format(input_file_name))

    try:
        backwards_opt_dict = backwards_opt.read_standard_file(input_file_name)
        list_of_optimized_matrices = backwards_opt_dict[
            backwards_opt.OPTIMIZED_MATRICES_KEY]
        list_of_input_matrices = backwards_opt_dict[
            backwards_opt.INIT_FUNCTION_KEY]

        full_storm_id_strings = backwards_opt_dict[backwards_opt.FULL_IDS_KEY]
        storm_times_unix_sec = backwards_opt_dict[
            backwards_opt.STORM_TIMES_KEY]

        storm_time_strings = [
            time_conversion.unix_sec_to_string(t,
                                               plot_input_examples.TIME_FORMAT)
            for t in storm_times_unix_sec
        ]

    except ValueError:
        # Not a standard file, so read it as a PMM (composite) file.
        backwards_opt_dict = backwards_opt.read_pmm_file(input_file_name)
        list_of_input_matrices = backwards_opt_dict[
            backwards_opt.MEAN_INPUT_MATRICES_KEY]
        list_of_optimized_matrices = backwards_opt_dict[
            backwards_opt.MEAN_OPTIMIZED_MATRICES_KEY]

        # Add a leading example axis of length 1 so that composites can be
        # plotted like a single example.
        for i in range(len(list_of_input_matrices)):
            list_of_input_matrices[i] = numpy.expand_dims(
                list_of_input_matrices[i], axis=0)
            list_of_optimized_matrices[i] = numpy.expand_dims(
                list_of_optimized_matrices[i], axis=0)

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]
        storm_time_strings = [None]

    pmm_flag = (full_storm_id_strings[0] is None
                and storm_time_strings[0] is None)

    model_file_name = backwards_opt_dict[backwards_opt.MODEL_FILE_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    training_option_dict = model_metadata_dict[cnn.TRAINING_OPTION_DICT_KEY]
    include_soundings = (training_option_dict[trainval_io.SOUNDING_FIELDS_KEY]
                         is not None)

    # Soundings, when used, are the last predictor matrix.
    if include_soundings:
        _plot_bwo_for_soundings(
            input_sounding_matrix=list_of_input_matrices[-1],
            optimized_sounding_matrix=list_of_optimized_matrices[-1],
            training_option_dict=training_option_dict,
            pmm_flag=pmm_flag,
            backwards_opt_dict=backwards_opt_dict,
            top_output_dir_name=top_output_dir_name)

        print(SEPARATOR_STRING)

    # TODO(thunderhoser): Make sure to not plot soundings here.
    plot_input_examples.plot_examples(
        list_of_predictor_matrices=list_of_input_matrices,
        model_metadata_dict=model_metadata_dict,
        output_dir_name=before_optimization_dir_name,
        allow_whitespace=True,
        pmm_flag=pmm_flag,
        full_storm_id_strings=full_storm_id_strings,
        storm_times_unix_sec=storm_times_unix_sec)
    print(SEPARATOR_STRING)

    plot_input_examples.plot_examples(
        list_of_predictor_matrices=list_of_optimized_matrices,
        model_metadata_dict=model_metadata_dict,
        output_dir_name=after_optimization_dir_name,
        allow_whitespace=True,
        pmm_flag=pmm_flag,
        full_storm_id_strings=full_storm_id_strings,
        storm_times_unix_sec=storm_times_unix_sec)
    print(SEPARATOR_STRING)

    monte_carlo_dict = (
        backwards_opt_dict[backwards_opt.MONTE_CARLO_DICT_KEY]
        if plot_significance
        and backwards_opt.MONTE_CARLO_DICT_KEY in backwards_opt_dict else None)

    num_examples = list_of_optimized_matrices[0].shape[0]
    num_radar_matrices = (len(list_of_optimized_matrices) -
                          int(include_soundings))

    for i in range(num_examples):
        # TODO(thunderhoser): Make BWO file always store initial matrices, even
        # if they are created by a function.

        for j in range(num_radar_matrices):
            if monte_carlo_dict is None:
                this_significance_matrix = None
            else:
                this_significance_matrix = _get_bwo_significance_matrix(
                    trial_pmm_matrix=monte_carlo_dict[
                        monte_carlo.TRIAL_PMM_MATRICES_KEY][j],
                    min_matrix=monte_carlo_dict[
                        monte_carlo.MIN_MATRICES_KEY][j],
                    max_matrix=monte_carlo_dict[
                        monte_carlo.MAX_MATRICES_KEY][j],
                    example_index=i,
                    num_dimensions_with_examples=len(
                        list_of_optimized_matrices[j].shape))

            this_difference_matrix = (list_of_optimized_matrices[j][i, ...] -
                                      list_of_input_matrices[j][i, ...])

            # Dimensions beyond the example and channel axes are spatial.
            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_difference(
                    difference_matrix=this_difference_matrix,
                    colour_map_object=diff_colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    model_metadata_dict=model_metadata_dict,
                    backwards_opt_dict=backwards_opt_dict,
                    output_dir_name=difference_dir_name,
                    example_index=i,
                    significance_matrix=this_significance_matrix)
            else:
                _plot_2d_radar_difference(
                    difference_matrix=this_difference_matrix,
                    colour_map_object=diff_colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    model_metadata_dict=model_metadata_dict,
                    backwards_opt_dict=backwards_opt_dict,
                    output_dir_name=difference_dir_name,
                    example_index=i,
                    significance_matrix=this_significance_matrix)