Example #1
0
def _run(input_saliency_file_name, input_gradcam_file_name,
         input_bwo_file_name, input_novelty_file_name, max_percentile_level,
         radar_channel_idx_for_thres, threshold_value, threshold_type_string,
         output_file_name):
    """Runs probability-matched means (PMM).

    This is effectively the main method.

    Exactly one input file is used, chosen in priority order: saliency,
    Grad-CAM, backwards optimization, novelty detection.  A file name counts
    as absent when it is in `NONE_STRINGS`.  PMM is applied to the
    denormalized predictor matrices and to the interpretation-specific
    matrices of the chosen file, and the composites are written to
    `output_file_name`.

    :param input_saliency_file_name: See documentation at top of file.
    :param input_gradcam_file_name: Same.
    :param input_bwo_file_name: Same.
    :param input_novelty_file_name: Same.
    :param max_percentile_level: Same.
    :param radar_channel_idx_for_thres: Same.  A negative value disables
        thresholding.
    :param threshold_value: Same.
    :param threshold_type_string: Same.
    :param output_file_name: Same.
    """

    # Choose the single input file (first one that is not a "none" string).
    if input_saliency_file_name not in NONE_STRINGS:
        input_type = 'saliency'
        input_file_name = input_saliency_file_name
        read_function = saliency_maps.read_standard_file
    elif input_gradcam_file_name not in NONE_STRINGS:
        input_type = 'gradcam'
        input_file_name = input_gradcam_file_name
        read_function = gradcam.read_standard_file
    elif input_bwo_file_name not in NONE_STRINGS:
        input_type = 'bwo'
        input_file_name = input_bwo_file_name
        read_function = backwards_opt.read_standard_file
    else:
        input_type = 'novelty'
        input_file_name = input_novelty_file_name
        read_function = novelty_detection.read_standard_file

    # Negative channel index => no thresholding at all.
    if radar_channel_idx_for_thres < 0:
        radar_channel_idx_for_thres = None
        threshold_value = None
        threshold_type_string = None

    print('Reading data from: "{0:s}"...'.format(input_file_name))
    data_dict = read_function(input_file_name)

    if input_type == 'saliency':
        list_of_input_matrices = data_dict[saliency_maps.INPUT_MATRICES_KEY]
    elif input_type == 'gradcam':
        list_of_input_matrices = data_dict[gradcam.INPUT_MATRICES_KEY]
    elif input_type == 'bwo':
        list_of_input_matrices = data_dict[backwards_opt.INIT_FUNCTION_KEY]
    else:
        # Keep only the examples listed in the novel-indices array.
        all_trial_matrices = data_dict[novelty_detection.TRIAL_INPUTS_KEY]
        novel_indices = data_dict[novelty_detection.NOVEL_INDICES_KEY]
        list_of_input_matrices = [
            this_matrix[novel_indices, ...]
            for this_matrix in all_trial_matrices
        ]

    print('Running PMM on denormalized predictor matrices...')

    # Only the first predictor matrix gets the threshold arguments; it also
    # supplies the threshold counts and the PMM metadata.
    list_of_mean_input_matrices = []
    threshold_count_matrix = None
    pmm_metadata_dict = None

    for matrix_index, this_input_matrix in enumerate(list_of_input_matrices):
        if matrix_index > 0:
            list_of_mean_input_matrices.append(
                pmm.run_pmm_many_variables(
                    input_matrix=this_input_matrix,
                    max_percentile_level=max_percentile_level
                )[0]
            )
            continue

        this_mean_matrix, threshold_count_matrix = pmm.run_pmm_many_variables(
            input_matrix=this_input_matrix,
            max_percentile_level=max_percentile_level,
            threshold_var_index=radar_channel_idx_for_thres,
            threshold_value=threshold_value,
            threshold_type_string=threshold_type_string)
        list_of_mean_input_matrices.append(this_mean_matrix)

        pmm_metadata_dict = pmm.check_input_args(
            input_matrix=this_input_matrix,
            max_percentile_level=max_percentile_level,
            threshold_var_index=radar_channel_idx_for_thres,
            threshold_value=threshold_value,
            threshold_type_string=threshold_type_string)

    num_input_matrices = len(list_of_input_matrices)

    def _pmm_of(matrix):
        """Returns the PMM composite of `matrix`, without thresholding."""
        return pmm.run_pmm_many_variables(
            input_matrix=matrix, max_percentile_level=max_percentile_level
        )[0]

    if input_type == 'saliency':
        print('Running PMM on saliency matrices...')
        list_of_saliency_matrices = data_dict[
            saliency_maps.SALIENCY_MATRICES_KEY]

        list_of_mean_saliency_matrices = [
            _pmm_of(list_of_saliency_matrices[k])
            for k in range(num_input_matrices)
        ]

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        saliency_maps.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_saliency_matrices=list_of_mean_saliency_matrices,
            threshold_count_matrix=threshold_count_matrix,
            model_file_name=data_dict[saliency_maps.MODEL_FILE_KEY],
            standard_saliency_file_name=input_file_name,
            pmm_metadata_dict=pmm_metadata_dict)
        return

    if input_type == 'gradcam':
        print('Running PMM on class-activation matrices...')
        list_of_cam_matrices = data_dict[gradcam.CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = data_dict[
            gradcam.GUIDED_CAM_MATRICES_KEY]

        list_of_mean_cam_matrices = [None] * num_input_matrices
        list_of_mean_guided_cam_matrices = [None] * num_input_matrices

        for k in range(num_input_matrices):
            # Matrices without a CAM keep None for both composites.
            if list_of_cam_matrices[k] is None:
                continue

            # Add a trailing length-1 axis for PMM, then strip it again.
            list_of_mean_cam_matrices[k] = _pmm_of(
                numpy.expand_dims(list_of_cam_matrices[k], axis=-1)
            )[..., 0]
            list_of_mean_guided_cam_matrices[k] = _pmm_of(
                list_of_guided_cam_matrices[k])

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        gradcam.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_cam_matrices=list_of_mean_cam_matrices,
            list_of_mean_guided_cam_matrices=list_of_mean_guided_cam_matrices,
            model_file_name=data_dict[gradcam.MODEL_FILE_KEY],
            standard_gradcam_file_name=input_file_name,
            pmm_metadata_dict=pmm_metadata_dict)
        return

    if input_type == 'bwo':
        print('Running PMM on backwards-optimization output...')
        list_of_optimized_matrices = data_dict[
            backwards_opt.OPTIMIZED_MATRICES_KEY]

        list_of_mean_optimized_matrices = [
            _pmm_of(list_of_optimized_matrices[k])
            for k in range(num_input_matrices)
        ]

        mean_initial_activation = numpy.mean(
            data_dict[backwards_opt.INITIAL_ACTIVATIONS_KEY])
        mean_final_activation = numpy.mean(
            data_dict[backwards_opt.FINAL_ACTIVATIONS_KEY])

        print('Writing output to: "{0:s}"...'.format(output_file_name))
        backwards_opt.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_optimized_matrices=list_of_mean_optimized_matrices,
            mean_initial_activation=mean_initial_activation,
            mean_final_activation=mean_final_activation,
            threshold_count_matrix=threshold_count_matrix,
            model_file_name=data_dict[backwards_opt.MODEL_FILE_KEY],
            standard_bwo_file_name=input_file_name,
            pmm_metadata_dict=pmm_metadata_dict)
        return

    # Remaining case: novelty detection.
    print('Running PMM on novelty-detection output...')

    mean_novel_image_matrix_upconv = _pmm_of(
        data_dict[novelty_detection.NOVEL_IMAGES_UPCONV_KEY])
    mean_novel_image_matrix_upconv_svd = _pmm_of(
        data_dict[novelty_detection.NOVEL_IMAGES_UPCONV_SVD_KEY])

    print('Writing output to: "{0:s}"...'.format(output_file_name))
    novelty_detection.write_pmm_file(
        pickle_file_name=output_file_name,
        mean_novel_image_matrix=list_of_mean_input_matrices[0],
        mean_novel_image_matrix_upconv=mean_novel_image_matrix_upconv,
        mean_novel_image_matrix_upconv_svd=mean_novel_image_matrix_upconv_svd,
        threshold_count_matrix=threshold_count_matrix,
        standard_novelty_file_name=input_file_name,
        pmm_metadata_dict=pmm_metadata_dict)
Example #2
0
def _run(input_file_name, allow_whitespace, plot_significance,
         plot_regions_of_interest, colour_map_name, max_colour_percentile,
         top_output_dir_name):
    """Plots Grad-CAM output (class-activation maps).

    This is effectively the main method.

    The input may be a standard Grad-CAM file (one map per example) or a PMM
    file (one composite map, with no storm ID or time).  Unguided CAMs (or
    regions of interest) are saved under "<top_output_dir_name>/main_gradcam"
    and guided CAMs under "<top_output_dir_name>/guided_gradcam".

    :param input_file_name: See documentation at top of file.
    :param allow_whitespace: Same.
    :param plot_significance: Same.  If True, `plot_regions_of_interest` is
        forced to False.
    :param plot_regions_of_interest: Same.
    :param colour_map_name: Same.
    :param max_colour_percentile: Same.
    :param top_output_dir_name: Same.
    """

    # Significance and regions of interest are not overlaid together;
    # significance wins.
    if plot_significance:
        plot_regions_of_interest = False

    # Check input args before creating any output directories.
    error_checking.assert_is_geq(max_colour_percentile, 0.)
    error_checking.assert_is_leq(max_colour_percentile, 100.)

    # `pyplot.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was removed in
    # Matplotlib 3.9; `pyplot.get_cmap` works in both old and new versions.
    colour_map_object = pyplot.get_cmap(colour_map_name)

    unguided_cam_dir_name = '{0:s}/main_gradcam'.format(top_output_dir_name)
    guided_cam_dir_name = '{0:s}/guided_gradcam'.format(top_output_dir_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=unguided_cam_dir_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=guided_cam_dir_name)

    print('Reading data from: "{0:s}"...'.format(input_file_name))

    try:
        gradcam_dict = gradcam.read_standard_file(input_file_name)
        list_of_input_matrices = gradcam_dict[gradcam.INPUT_MATRICES_KEY]
        list_of_cam_matrices = gradcam_dict[gradcam.CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = gradcam_dict[
            gradcam.GUIDED_CAM_MATRICES_KEY]

        full_storm_id_strings = gradcam_dict[gradcam.FULL_IDS_KEY]
        storm_times_unix_sec = gradcam_dict[gradcam.STORM_TIMES_KEY]

    except ValueError:
        # Presumably raised when the file is not a standard Grad-CAM file;
        # fall back to reading it as a PMM file.
        gradcam_dict = gradcam.read_pmm_file(input_file_name)
        list_of_input_matrices = gradcam_dict[gradcam.MEAN_INPUT_MATRICES_KEY]
        list_of_cam_matrices = gradcam_dict[gradcam.MEAN_CAM_MATRICES_KEY]
        list_of_guided_cam_matrices = gradcam_dict[
            gradcam.MEAN_GUIDED_CAM_MATRICES_KEY]

        # Add a length-1 example axis so PMM output can be indexed like
        # standard output below.
        for i in range(len(list_of_input_matrices)):
            list_of_input_matrices[i] = numpy.expand_dims(
                list_of_input_matrices[i], axis=0)

            if list_of_cam_matrices[i] is None:
                continue

            list_of_cam_matrices[i] = numpy.expand_dims(
                list_of_cam_matrices[i], axis=0)
            list_of_guided_cam_matrices[i] = numpy.expand_dims(
                list_of_guided_cam_matrices[i], axis=0)

        full_storm_id_strings = [None]
        storm_times_unix_sec = [None]

    # PMM input has neither storm ID nor storm time.
    pmm_flag = (full_storm_id_strings[0] is None
                and storm_times_unix_sec[0] is None)

    # Read metadata for CNN.
    model_file_name = gradcam_dict[gradcam.MODEL_FILE_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print(
        'Reading model metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    cam_monte_carlo_dict = (gradcam_dict[gradcam.CAM_MONTE_CARLO_KEY]
                            if plot_significance else None)

    guided_cam_monte_carlo_dict = (
        gradcam_dict[gradcam.GUIDED_CAM_MONTE_CARLO_KEY]
        if plot_significance else None)

    region_dict = (gradcam_dict[gradcam.REGION_DICT_KEY]
                   if plot_regions_of_interest else None)

    def _get_significance_matrix(monte_carlo_dict, matrix_index,
                                 example_index):
        """Returns Boolean mask of points outside the Monte Carlo bounds.

        Returns None when `monte_carlo_dict` is None (significance not
        plotted).
        """

        if monte_carlo_dict is None:
            return None

        # NOTE(review): for PMM input the CAM matrices get an example axis
        # added above, but these Monte Carlo matrices are indexed with
        # [example_index, ...] as read -- confirm they already carry an
        # example axis.
        trial_matrix = monte_carlo_dict[
            monte_carlo.TRIAL_PMM_MATRICES_KEY][matrix_index][
                example_index, ...]
        min_matrix = monte_carlo_dict[
            monte_carlo.MIN_MATRICES_KEY][matrix_index][example_index, ...]
        max_matrix = monte_carlo_dict[
            monte_carlo.MAX_MATRICES_KEY][matrix_index][example_index, ...]

        return numpy.logical_or(trial_matrix < min_matrix,
                                trial_matrix > max_matrix)

    num_examples = list_of_input_matrices[0].shape[0]
    num_input_matrices = len(list_of_input_matrices)

    for i in range(num_examples):
        # Fresh figures for unguided CAMs (or regions of interest).
        these_figure_objects, these_axes_object_matrices = (
            plot_input_examples.plot_one_example(
                list_of_predictor_matrices=list_of_input_matrices,
                model_metadata_dict=model_metadata_dict,
                plot_sounding=False,
                allow_whitespace=allow_whitespace,
                pmm_flag=pmm_flag,
                example_index=i,
                full_storm_id_string=full_storm_id_strings[i],
                storm_time_unix_sec=storm_times_unix_sec[i]))

        for j in range(num_input_matrices):
            if list_of_cam_matrices[j] is None:
                continue

            this_significance_matrix = _get_significance_matrix(
                cam_monte_carlo_dict, j, i)

            # Input matrices have an example axis and a channel axis.
            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=unguided_cam_dir_name,
                    cam_matrix=list_of_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i])
            elif region_dict is None:
                _plot_2d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=unguided_cam_dir_name,
                    cam_matrix=list_of_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i])
            else:
                _plot_2d_regions(
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    list_of_polygon_objects=region_dict[
                        gradcam.POLYGON_OBJECTS_KEY][j][i],
                    output_dir_name=unguided_cam_dir_name,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i])

        # Fresh figures for guided CAMs.
        these_figure_objects, these_axes_object_matrices = (
            plot_input_examples.plot_one_example(
                list_of_predictor_matrices=list_of_input_matrices,
                model_metadata_dict=model_metadata_dict,
                plot_sounding=False,
                allow_whitespace=allow_whitespace,
                pmm_flag=pmm_flag,
                example_index=i,
                full_storm_id_string=full_storm_id_strings[i],
                storm_time_unix_sec=storm_times_unix_sec[i]))

        for j in range(num_input_matrices):
            if list_of_guided_cam_matrices[j] is None:
                continue

            this_significance_matrix = _get_significance_matrix(
                guided_cam_monte_carlo_dict, j, i)

            this_num_spatial_dim = len(list_of_input_matrices[j].shape) - 2

            if this_num_spatial_dim == 3:
                _plot_3d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=guided_cam_dir_name,
                    guided_cam_matrix=list_of_guided_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i])
            else:
                _plot_2d_radar_cam(
                    colour_map_object=colour_map_object,
                    max_colour_percentile=max_colour_percentile,
                    figure_objects=these_figure_objects,
                    axes_object_matrices=these_axes_object_matrices,
                    model_metadata_dict=model_metadata_dict,
                    output_dir_name=guided_cam_dir_name,
                    guided_cam_matrix=list_of_guided_cam_matrices[j][i, ...],
                    significance_matrix=this_significance_matrix,
                    full_storm_id_string=full_storm_id_strings[i],
                    storm_time_unix_sec=storm_times_unix_sec[i])
def _run(interpretation_type_string, baseline_file_name, trial_file_name,
         max_pmm_percentile_level, num_iterations, confidence_level,
         output_file_name):
    """Runs Monte Carlo significance test for interpretation output.

    This is effectively the main method.

    The trial set of interpretation maps is compared against the baseline
    set.  The PMM of the trial set's input matrices, plus the Monte Carlo
    results, are written to a PMM file.

    :param interpretation_type_string: See documentation at top of file.
    :param baseline_file_name: Same.
    :param trial_file_name: Same.
    :param max_pmm_percentile_level: Same.
    :param num_iterations: Same.
    :param confidence_level: Same.
    :param output_file_name: Same.
    :raises: ValueError: if
        `interpretation_type_string not in VALID_INTERPRETATION_TYPE_STRINGS`.
    """

    if interpretation_type_string not in VALID_INTERPRETATION_TYPE_STRINGS:
        message = (
            '\n{0:s}\nValid interpretation types (listed above) do not include '
            '"{1:s}".'
        )
        raise ValueError(message.format(
            str(VALID_INTERPRETATION_TYPE_STRINGS), interpretation_type_string
        ))

    is_saliency = interpretation_type_string == SALIENCY_STRING
    is_gradcam = interpretation_type_string == GRADCAM_STRING

    # Anything valid that is neither saliency nor Grad-CAM is handled as
    # backwards optimization.
    if is_saliency:
        read_function = saliency_maps.read_standard_file
    elif is_gradcam:
        read_function = gradcam.read_standard_file
    else:
        read_function = backwards_opt.read_standard_file

    print('Reading baseline set from: "{0:s}"...'.format(baseline_file_name))
    baseline_dict = read_function(baseline_file_name)

    print('Reading trial set from: "{0:s}"...'.format(trial_file_name))
    monte_carlo_dict = None
    cam_monte_carlo_dict = None
    guided_cam_monte_carlo_dict = None
    trial_dict = read_function(trial_file_name)

    def _test_one_key(matrix_key):
        """Runs the Monte Carlo test on one matrix type; records baseline."""
        result_dict = monte_carlo.run_monte_carlo_test(
            list_of_baseline_matrices=baseline_dict[matrix_key],
            list_of_trial_matrices=trial_dict[matrix_key],
            max_pmm_percentile_level=max_pmm_percentile_level,
            num_iterations=num_iterations, confidence_level=confidence_level)
        result_dict[monte_carlo.BASELINE_FILE_KEY] = baseline_file_name
        return result_dict

    if is_saliency:
        monte_carlo_dict = _test_one_key(saliency_maps.SALIENCY_MATRICES_KEY)
        input_matrices_key = saliency_maps.INPUT_MATRICES_KEY
    elif is_gradcam:
        cam_monte_carlo_dict = _test_one_key(gradcam.CAM_MATRICES_KEY)
        guided_cam_monte_carlo_dict = _test_one_key(
            gradcam.GUIDED_CAM_MATRICES_KEY)
        input_matrices_key = gradcam.INPUT_MATRICES_KEY
    else:
        monte_carlo_dict = _test_one_key(backwards_opt.OPTIMIZED_MATRICES_KEY)
        input_matrices_key = backwards_opt.INIT_FUNCTION_KEY

    list_of_input_matrices = trial_dict[input_matrices_key]
    print(SEPARATOR_STRING)

    list_of_mean_input_matrices = [
        pmm.run_pmm_many_variables(
            input_matrix=this_input_matrix,
            max_percentile_level=max_pmm_percentile_level
        )[0]
        for this_input_matrix in list_of_input_matrices
    ]

    # Metadata describes the (unthresholded) PMM of the first input matrix.
    pmm_metadata_dict = pmm.check_input_args(
        input_matrix=list_of_input_matrices[0],
        max_percentile_level=max_pmm_percentile_level,
        threshold_var_index=None, threshold_value=None,
        threshold_type_string=None)

    print('Writing results to: "{0:s}"...'.format(output_file_name))

    if is_saliency:
        mean_saliency_matrices = copy.deepcopy(
            monte_carlo_dict[monte_carlo.TRIAL_PMM_MATRICES_KEY])

        saliency_maps.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_saliency_matrices=mean_saliency_matrices,
            threshold_count_matrix=None,
            model_file_name=trial_dict[saliency_maps.MODEL_FILE_KEY],
            standard_saliency_file_name=trial_file_name,
            pmm_metadata_dict=pmm_metadata_dict,
            monte_carlo_dict=monte_carlo_dict)
        return

    if is_gradcam:
        mean_cam_matrices = copy.deepcopy(
            cam_monte_carlo_dict[monte_carlo.TRIAL_PMM_MATRICES_KEY])
        mean_guided_cam_matrices = copy.deepcopy(
            guided_cam_monte_carlo_dict[monte_carlo.TRIAL_PMM_MATRICES_KEY])

        gradcam.write_pmm_file(
            pickle_file_name=output_file_name,
            list_of_mean_input_matrices=list_of_mean_input_matrices,
            list_of_mean_cam_matrices=mean_cam_matrices,
            list_of_mean_guided_cam_matrices=mean_guided_cam_matrices,
            model_file_name=trial_dict[gradcam.MODEL_FILE_KEY],
            standard_gradcam_file_name=trial_file_name,
            pmm_metadata_dict=pmm_metadata_dict,
            cam_monte_carlo_dict=cam_monte_carlo_dict,
            guided_cam_monte_carlo_dict=guided_cam_monte_carlo_dict)
        return

    mean_optimized_matrices = copy.deepcopy(
        monte_carlo_dict[monte_carlo.TRIAL_PMM_MATRICES_KEY])
    mean_initial_activation = numpy.mean(
        trial_dict[backwards_opt.INITIAL_ACTIVATIONS_KEY])
    mean_final_activation = numpy.mean(
        trial_dict[backwards_opt.FINAL_ACTIVATIONS_KEY])

    backwards_opt.write_pmm_file(
        pickle_file_name=output_file_name,
        list_of_mean_input_matrices=list_of_mean_input_matrices,
        list_of_mean_optimized_matrices=mean_optimized_matrices,
        mean_initial_activation=mean_initial_activation,
        mean_final_activation=mean_final_activation,
        threshold_count_matrix=None,
        model_file_name=trial_dict[backwards_opt.MODEL_FILE_KEY],
        standard_bwo_file_name=trial_file_name,
        pmm_metadata_dict=pmm_metadata_dict,
        monte_carlo_dict=monte_carlo_dict)
def _run(input_gradcam_file_name, percentile_threshold, min_class_activation,
         output_file_name):
    """Thresholds Grad-CAM output to create regions of interest (polygons).

    This is effectively the main method.

    For each example and each class-activation matrix, grid points with
    activation >= max(`percentile_threshold`-th percentile of that map,
    `min_class_activation`) form a Boolean mask.  Each mask is converted to
    polygons.  Masks and polygons are added to the Grad-CAM file via
    `gradcam.add_regions_to_file`.

    :param input_gradcam_file_name: See documentation at top of file.
    :param percentile_threshold: Same.
    :param min_class_activation: Same.
    :param output_file_name: Same.  If '' or 'None', output goes to
        `input_gradcam_file_name`.
    :raises: TypeError: if any class-activation map contains not-2 spatial
        dimensions.
    :raises: ValueError: if every class-activation matrix is None.
    """

    error_checking.assert_is_geq(percentile_threshold, 50.)
    error_checking.assert_is_less_than(percentile_threshold, 100.)
    error_checking.assert_is_greater(min_class_activation, 0.)

    print('Reading data from: "{0:s}"...\n'.format(input_gradcam_file_name))
    pmm_flag = False

    try:
        gradcam_dict = gradcam.read_standard_file(input_gradcam_file_name)
        list_of_cam_matrices = gradcam_dict.pop(gradcam.CAM_MATRICES_KEY)
    except ValueError:
        gradcam_dict = gradcam.read_pmm_file(input_gradcam_file_name)
        list_of_cam_matrices = gradcam_dict.pop(gradcam.MEAN_CAM_MATRICES_KEY)

        # PMM matrices are indexed below with an example axis; add a
        # length-1 one so standard and PMM input are handled identically.
        list_of_cam_matrices = [
            None if this_matrix is None
            else numpy.expand_dims(this_matrix, axis=0)
            for this_matrix in list_of_cam_matrices
        ]
        pmm_flag = True

    num_matrices = len(list_of_cam_matrices)
    num_examples = None

    for j in range(num_matrices):
        if list_of_cam_matrices[j] is None:
            continue

        num_examples = list_of_cam_matrices[j].shape[0]
        this_num_spatial_dim = len(list_of_cam_matrices[j].shape) - 1
        if this_num_spatial_dim == 2:
            continue

        error_string = (
            'This script deals with only 2-D class-activation maps.  {0:d}th '
            'input matrix contains {1:d} spatial dimensions.'
        ).format(j + 1, this_num_spatial_dim)

        raise TypeError(error_string)

    if num_examples is None:
        error_string = (
            'All class-activation matrices in "{0:s}" are None, so there is '
            'nothing to threshold.'
        ).format(input_gradcam_file_name)

        raise ValueError(error_string)

    # Build independent inner lists.  `[x] * n` would make every entry refer
    # to the SAME list, so polygons written for one matrix would overwrite
    # those for all other matrices.
    list_of_polygon_objects = [
        [[] for _ in range(num_examples)] for _ in range(num_matrices)
    ]

    # Per-matrix list of 2-D masks, stacked once at the end.
    list_of_mask_lists = [[] for _ in range(num_matrices)]

    for i in range(num_examples):
        for j in range(num_matrices):
            if list_of_cam_matrices[j] is None:
                continue

            this_cam_matrix = list_of_cam_matrices[j][i, ...]

            # Threshold is the map-specific percentile, but never below the
            # absolute minimum.
            this_min_class_activation = max(
                numpy.percentile(this_cam_matrix, percentile_threshold),
                min_class_activation
            )

            print((
                'Creating mask for {0:d}th example and {1:d}th class-activation'
                ' matrix, with threshold = {2:.3e}...'
            ).format(
                i + 1, j + 1, this_min_class_activation
            ))

            this_mask_matrix = this_cam_matrix >= this_min_class_activation

            print('{0:d} of {1:d} grid points are inside mask.\n'.format(
                numpy.sum(this_mask_matrix.astype(int)), this_mask_matrix.size
            ))

            list_of_polygon_objects[j][i] = _mask_to_polygons(this_mask_matrix)
            list_of_mask_lists[j].append(this_mask_matrix)

    # One stack per matrix (example axis first) instead of repeated
    # concatenation inside the loop.
    list_of_mask_matrices = [
        numpy.stack(these_masks, axis=0) if these_masks else None
        for these_masks in list_of_mask_lists
    ]

    if pmm_flag:
        # Drop the length-1 example axis added for PMM input.
        list_of_mask_matrices = [
            None if this_matrix is None else this_matrix[0, ...]
            for this_matrix in list_of_mask_matrices
        ]

    region_dict = {
        gradcam.MASK_MATRICES_KEY: list_of_mask_matrices,
        gradcam.POLYGON_OBJECTS_KEY: list_of_polygon_objects,
        gradcam.PERCENTILE_THRESHOLD_KEY: percentile_threshold,
        gradcam.MIN_CLASS_ACTIVATION_KEY: min_class_activation
    }

    if output_file_name in ['', 'None']:
        output_file_name = input_gradcam_file_name

    print('Writing regions of interest to: "{0:s}"...'.format(output_file_name))
    gradcam.add_regions_to_file(
        input_file_name=input_gradcam_file_name,
        output_file_name=output_file_name, region_dict=region_dict)
Example #5
0
def _run(input_human_file_name, input_machine_file_name, guided_gradcam_flag,
         abs_percentile_threshold, output_dir_name):
    """Compares human-generated vs. machine-generated interpretation map.

    This is effectively the main method.  Human polygons (positive and,
    optionally, negative masks) are read from `input_human_file_name`.  The
    machine interpretation is read from `input_machine_file_name` as
    saliency if possible; if the saliency reader raises ValueError, it is read
    as Grad-CAM instead.  Each channel is then compared with
    `_do_comparison_one_channel`, and comparison figures are written to
    `output_dir_name`.

    :param input_human_file_name: See documentation at top of file.
    :param input_machine_file_name: Same.
    :param guided_gradcam_flag: Same.
    :param abs_percentile_threshold: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # A negative threshold means "no percentile threshold".
    if abs_percentile_threshold < 0:
        abs_percentile_threshold = None
    else:
        error_checking.assert_is_leq(abs_percentile_threshold, 100.)

    print('Reading data from: "{0:s}"...'.format(input_human_file_name))
    human_polygon_dict = human_polygons.read_polygons(input_human_file_name)

    human_pos_mask_4d = human_polygon_dict[
        human_polygons.POSITIVE_MASK_MATRIX_KEY]
    human_neg_mask_4d = human_polygon_dict[
        human_polygons.NEGATIVE_MASK_MATRIX_KEY]

    full_storm_id_string = human_polygon_dict[human_polygons.STORM_ID_KEY]
    storm_time_unix_sec = human_polygon_dict[human_polygons.STORM_TIME_KEY]

    # No storm ID and no storm time means the human drew on a PMM composite.
    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    print('Reading data from: "{0:s}"...'.format(input_machine_file_name))

    # TODO(thunderhoser): This is a HACK.
    channel_indices = numpy.array([2, 8], dtype=int)

    if pmm_flag:
        try:
            machine_dict = saliency_maps.read_pmm_file(input_machine_file_name)
            is_saliency = True
            model_file_name = machine_dict[saliency_maps.MODEL_FILE_KEY]
            radar_matrix = machine_dict.pop(
                saliency_maps.MEAN_INPUT_MATRICES_KEY
            )[0][..., channel_indices]
            machine_matrix_3d = machine_dict.pop(
                saliency_maps.MEAN_SALIENCY_MATRICES_KEY
            )[0][..., channel_indices]
        except ValueError:
            # Not readable as saliency, so treat it as Grad-CAM output.
            machine_dict = gradcam.read_pmm_file(input_machine_file_name)
            is_saliency = False
            model_file_name = machine_dict[gradcam.MODEL_FILE_KEY]
            radar_matrix = machine_dict.pop(
                gradcam.MEAN_INPUT_MATRICES_KEY
            )[0][..., channel_indices]

            if guided_gradcam_flag:
                machine_matrix_3d = machine_dict.pop(
                    gradcam.MEAN_GUIDED_GRADCAM_KEY
                )[..., channel_indices]
            else:
                machine_matrix_3d = machine_dict.pop(
                    gradcam.MEAN_CLASS_ACTIVATIONS_KEY)
    else:
        try:
            machine_dict = saliency_maps.read_standard_file(
                input_machine_file_name)
            is_saliency = True
            all_id_strings = machine_dict[saliency_maps.FULL_IDS_KEY]
            all_times_unix_sec = machine_dict[saliency_maps.STORM_TIMES_KEY]
            model_file_name = machine_dict[saliency_maps.MODEL_FILE_KEY]
            radar_matrix = machine_dict.pop(
                saliency_maps.INPUT_MATRICES_KEY
            )[0][..., channel_indices]
            machine_matrix_3d = machine_dict.pop(
                saliency_maps.SALIENCY_MATRICES_KEY
            )[0][..., channel_indices]
        except ValueError:
            # Not readable as saliency, so treat it as Grad-CAM output.
            machine_dict = gradcam.read_standard_file(input_machine_file_name)
            is_saliency = False
            all_id_strings = machine_dict[gradcam.FULL_IDS_KEY]
            all_times_unix_sec = machine_dict[gradcam.STORM_TIMES_KEY]
            model_file_name = machine_dict[gradcam.MODEL_FILE_KEY]
            radar_matrix = machine_dict.pop(
                gradcam.INPUT_MATRICES_KEY
            )[0][..., channel_indices]

            if guided_gradcam_flag:
                machine_matrix_3d = machine_dict.pop(
                    gradcam.GUIDED_GRADCAM_KEY
                )[..., channel_indices]
            else:
                machine_matrix_3d = machine_dict.pop(
                    gradcam.CLASS_ACTIVATIONS_KEY)

        # Keep only the storm object that the human annotated.
        storm_index = tracking_utils.find_storm_objects(
            all_id_strings=all_id_strings,
            all_times_unix_sec=all_times_unix_sec,
            id_strings_to_keep=[full_storm_id_string],
            times_to_keep_unix_sec=numpy.array(
                [storm_time_unix_sec], dtype=int),
            allow_missing=False
        )[0]

        radar_matrix = radar_matrix[storm_index, ...]
        machine_matrix_3d = machine_matrix_3d[storm_index, ...]

    # Plain (unguided) class-activation maps carry no channel axis here, so
    # copy them across every predictor channel; negative comparison is
    # skipped for them.
    if not (is_saliency or guided_gradcam_flag):
        machine_matrix_3d = numpy.repeat(
            a=numpy.expand_dims(machine_matrix_3d, axis=-1),
            repeats=radar_matrix.shape[-1], axis=-1)
        human_neg_mask_4d = None

    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0])

    print('Reading metadata from: "{0:s}"...'.format(model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)

    # Restrict layer operations to the same channels taken from the file.
    all_layer_operations = model_metadata_dict[cnn.LAYER_OPERATIONS_KEY]
    model_metadata_dict[cnn.LAYER_OPERATIONS_KEY] = [
        all_layer_operations[c] for c in channel_indices
    ]

    human_pos_mask_3d, human_neg_mask_3d = _reshape_human_maps(
        model_metadata_dict=model_metadata_dict,
        positive_mask_matrix_4d=human_pos_mask_4d,
        negative_mask_matrix_4d=human_neg_mask_4d)

    num_channels = human_pos_mask_3d.shape[-1]
    has_negative = human_neg_mask_3d is not None

    machine_pos_mask_3d = numpy.full(
        human_pos_mask_3d.shape, False, dtype=bool)
    positive_iou_by_channel = numpy.full(num_channels, numpy.nan)

    if has_negative:
        machine_neg_mask_3d = numpy.full(
            human_neg_mask_3d.shape, False, dtype=bool)
        negative_iou_by_channel = numpy.full(num_channels, numpy.nan)
    else:
        machine_neg_mask_3d = None
        negative_iou_by_channel = None

    for c in range(num_channels):
        comparison_dict = _do_comparison_one_channel(
            machine_interpretation_matrix_2d=machine_matrix_3d[..., c],
            abs_percentile_threshold=abs_percentile_threshold,
            human_positive_mask_matrix_2d=human_pos_mask_3d[..., c],
            human_negative_mask_matrix_2d=(
                human_neg_mask_3d[..., c] if has_negative else None
            ))

        machine_pos_mask_3d[..., c] = comparison_dict[
            MACHINE_POSITIVE_MASK_KEY]
        positive_iou_by_channel[c] = comparison_dict[POSITIVE_IOU_KEY]

        if has_negative:
            machine_neg_mask_3d[..., c] = comparison_dict[
                MACHINE_NEGATIVE_MASK_KEY]
            negative_iou_by_channel[c] = comparison_dict[NEGATIVE_IOU_KEY]

    _plot_comparison(
        predictor_matrix=radar_matrix,
        model_metadata_dict=model_metadata_dict,
        machine_mask_matrix_3d=machine_pos_mask_3d,
        human_mask_matrix_3d=human_pos_mask_3d,
        iou_by_channel=positive_iou_by_channel,
        positive_flag=True,
        output_file_name='{0:s}/positive_comparison.jpg'.format(
            output_dir_name))

    if not has_negative:
        return

    _plot_comparison(
        predictor_matrix=radar_matrix,
        model_metadata_dict=model_metadata_dict,
        machine_mask_matrix_3d=machine_neg_mask_3d,
        human_mask_matrix_3d=human_neg_mask_3d,
        iou_by_channel=negative_iou_by_channel,
        positive_flag=False,
        output_file_name='{0:s}/negative_comparison.jpg'.format(
            output_dir_name))
def _run(input_file_name, cam_colour_map_name, max_colour_prctile_for_cam,
         top_output_dir_name):
    """Plots Grad-CAM output (class-activation maps).

    This is effectively the main method.  The input file is first read as a
    standard Grad-CAM file.  If that raises ValueError, it is read as a PMM
    file instead; each PMM composite then gets a leading example axis of
    length 1, so both cases are plotted by the same code.  Class-activation
    maps are plotted to "<top_output_dir_name>/main_gradcam" and guided
    Grad-CAM output to "<top_output_dir_name>/guided_gradcam".

    :param input_file_name: See documentation at top of file.
    :param cam_colour_map_name: Same.
    :param max_colour_prctile_for_cam: Same.
    :param top_output_dir_name: Same.
    :raises: TypeError: if input file contains class-activation maps for
        soundings.
    """

    main_gradcam_dir_name = '{0:s}/main_gradcam'.format(top_output_dir_name)
    guided_gradcam_dir_name = '{0:s}/guided_gradcam'.format(top_output_dir_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=main_gradcam_dir_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=guided_gradcam_dir_name)

    # Check input args.
    error_checking.assert_is_geq(max_colour_prctile_for_cam, 0.)
    error_checking.assert_is_leq(max_colour_prctile_for_cam, 100.)

    # `pyplot.get_cmap` exists in every matplotlib version, whereas
    # `pyplot.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was removed in 3.9.
    cam_colour_map_object = pyplot.get_cmap(cam_colour_map_name)

    # Single-argument print(...) behaves the same under Python 2 and 3; the
    # bare `print x` statement is a SyntaxError under Python 3.
    print('Reading data from: "{0:s}"...'.format(input_file_name))

    try:
        gradcam_dict = gradcam.read_standard_file(input_file_name)
        list_of_input_matrices = gradcam_dict.pop(gradcam.INPUT_MATRICES_KEY)
        class_activation_matrix = gradcam_dict.pop(
            gradcam.CLASS_ACTIVATIONS_KEY)
        ggradcam_output_matrix = gradcam_dict.pop(gradcam.GUIDED_GRADCAM_KEY)

        # What remains after popping the big arrays is used as metadata.
        gradcam_metadata_dict = gradcam_dict
        storm_ids = gradcam_metadata_dict[gradcam.STORM_IDS_KEY]
        storm_times_unix_sec = gradcam_metadata_dict[gradcam.STORM_TIMES_KEY]

    except ValueError:
        gradcam_dict = gradcam.read_pmm_file(input_file_name)

        # Give every composite a leading example axis of length 1.
        list_of_input_matrices = [
            numpy.expand_dims(this_matrix, axis=0)
            for this_matrix in gradcam_dict[gradcam.MEAN_INPUT_MATRICES_KEY]
        ]
        class_activation_matrix = numpy.expand_dims(
            gradcam_dict[gradcam.MEAN_CLASS_ACTIVATIONS_KEY], axis=0)
        ggradcam_output_matrix = numpy.expand_dims(
            gradcam_dict[gradcam.MEAN_GUIDED_GRADCAM_KEY], axis=0)

        # Metadata come from the standard file referenced by the PMM file.
        orig_gradcam_file_name = gradcam_dict[gradcam.STANDARD_FILE_NAME_KEY]

        print('Reading metadata from: "{0:s}"...'.format(
            orig_gradcam_file_name))
        orig_gradcam_dict = gradcam.read_standard_file(orig_gradcam_file_name)

        orig_gradcam_dict.pop(gradcam.INPUT_MATRICES_KEY)
        orig_gradcam_dict.pop(gradcam.CLASS_ACTIVATIONS_KEY)
        orig_gradcam_dict.pop(gradcam.GUIDED_GRADCAM_KEY)
        gradcam_metadata_dict = orig_gradcam_dict

        # A composite is not tied to any single storm object.
        storm_ids = None
        storm_times_unix_sec = None

    num_spatial_dimensions = len(class_activation_matrix.shape) - 1
    if num_spatial_dimensions == 1:
        raise TypeError('This script is not yet equipped to plot '
                        'class-activation maps for soundings.')

    # Read metadata for CNN.
    model_file_name = gradcam_metadata_dict[gradcam.MODEL_FILE_NAME_KEY]
    model_metafile_name = '{0:s}/model_metadata.p'.format(
        os.path.split(model_file_name)[0]
    )

    print('Reading model metadata from: "{0:s}"...'.format(
        model_metafile_name))
    model_metadata_dict = cnn.read_model_metadata(model_metafile_name)
    print(SEPARATOR_STRING)

    # Do plotting.  Main and guided Grad-CAM use the same plotting routine and
    # the same background radar images; only the map and output dir differ.
    if num_spatial_dimensions == 3:
        plotting_function = _plot_3d_radar_cams
        radar_matrix = list_of_input_matrices[0]
    else:
        plotting_function = _plot_2d_radar_cams

        # NOTE(review): with three input tensors the radar images are taken
        # from index 1 -- presumably the 2-D radar input; confirm.
        if len(list_of_input_matrices) == 3:
            radar_matrix = list_of_input_matrices[1]
        else:
            radar_matrix = list_of_input_matrices[0]

    plotting_function(
        radar_matrix=radar_matrix,
        class_activation_matrix=class_activation_matrix,
        model_metadata_dict=model_metadata_dict,
        cam_colour_map_object=cam_colour_map_object,
        max_colour_prctile_for_cam=max_colour_prctile_for_cam,
        output_dir_name=main_gradcam_dir_name,
        storm_ids=storm_ids, storm_times_unix_sec=storm_times_unix_sec)
    print(SEPARATOR_STRING)

    plotting_function(
        radar_matrix=radar_matrix,
        ggradcam_output_matrix=ggradcam_output_matrix,
        model_metadata_dict=model_metadata_dict,
        cam_colour_map_object=cam_colour_map_object,
        max_colour_prctile_for_cam=max_colour_prctile_for_cam,
        output_dir_name=guided_gradcam_dir_name,
        storm_ids=storm_ids, storm_times_unix_sec=storm_times_unix_sec)
Example #7
0
def _run(input_gradcam_file_name, input_human_file_name,
         output_gradcam_file_name):
    """Runs human novelty detection on class-activation maps.
    
    This is effectively the main method.
    
    :param input_gradcam_file_name: See documentation at top of file.
    :param input_human_file_name: Same.
    :param output_gradcam_file_name: Same.
    :raises: TypeError: if `input_gradcam_file_name` was created by a net that
        does both 2-D and 3-D convolution.
    :raises: TypeError: if class-activation maps are not 2-D.
    """

    print('Reading data from: "{0:s}"...'.format(input_human_file_name))
    human_point_dict = human_polygons.read_points(input_human_file_name)

    # TODO(thunderhoser): Deal with no human points.
    human_grid_rows = human_point_dict[human_polygons.GRID_ROW_BY_POINT_KEY]
    human_grid_columns = human_point_dict[
        human_polygons.GRID_COLUMN_BY_POINT_KEY]
    human_panel_rows = human_point_dict[human_polygons.PANEL_ROW_BY_POINT_KEY]
    human_panel_columns = human_point_dict[
        human_polygons.PANEL_COLUMN_BY_POINT_KEY]

    full_storm_id_string = human_point_dict[human_polygons.STORM_ID_KEY]
    storm_time_unix_sec = human_point_dict[human_polygons.STORM_TIME_KEY]
    pmm_flag = full_storm_id_string is None and storm_time_unix_sec is None

    print('Reading data from: "{0:s}"...'.format(input_gradcam_file_name))

    if pmm_flag:
        gradcam_dict = gradcam.read_pmm_file(input_gradcam_file_name)
    else:
        gradcam_dict = gradcam.read_standard_file(input_gradcam_file_name)

    machine_region_dict = gradcam_dict[gradcam.REGION_DICT_KEY]
    list_of_mask_matrices = machine_region_dict[gradcam.MASK_MATRICES_KEY]

    if len(list_of_mask_matrices) == 3:
        raise TypeError('This script cannot handle nets that do both 2-D and '
                        '3-D convolution.')

    machine_mask_matrix = list_of_mask_matrices[0]

    if pmm_flag:
        storm_object_index = -1
    else:
        storm_object_index = tracking_utils.find_storm_objects(
            all_id_strings=gradcam_dict[gradcam.FULL_IDS_KEY],
            all_times_unix_sec=gradcam_dict[gradcam.STORM_TIMES_KEY],
            id_strings_to_keep=[full_storm_id_string],
            times_to_keep_unix_sec=numpy.array([storm_time_unix_sec],
                                               dtype=int),
            allow_missing=False)[0]

        machine_mask_matrix = machine_mask_matrix[storm_object_index, ...]

    num_spatial_dimensions = len(machine_mask_matrix.shape) - 1
    if num_spatial_dimensions != 2:
        raise TypeError(
            'This script can compare only with 2-D class-activation'
            ' maps.')

    if pmm_flag:
        machine_polygon_objects = machine_region_dict[
            gradcam.POLYGON_OBJECTS_KEY][0][0]
    else:
        machine_polygon_objects = machine_region_dict[
            gradcam.POLYGON_OBJECTS_KEY][storm_object_index][0]