示例#1
0
def prune_prototypes(
    dataloader,
    prototype_network_parallel,
    k,
    prune_threshold,
    preprocess_input_function,
    original_model_dir,
    epoch_number,
    log=print,
    copy_prototype_imgs=True,
):
    """Prune prototypes whose nearest training patches rarely match their class.

    For every prototype, the class labels of its ``k`` nearest patches are
    obtained from ``find_nearest.find_k_nearest_patches_to_prototypes``.  A
    prototype is pruned when fewer than ``prune_threshold`` of those labels
    equal the prototype's own class (the argmax of its row in
    ``prototype_class_identity``).

    Side effects:
        * creates ``<original_model_dir>/pruned_prototypes_epoch{E}_k{K}_pt{T}``
          and stores ``prune_info.npy`` in it;
        * calls ``prototype_network_parallel.module.prune_prototypes`` with the
          list of pruned indices;
        * if ``copy_prototype_imgs`` is true, copies the image / self-activation
          files of the surviving prototypes from ``img/epoch-<E>`` into the
          pruned directory, renumbering them ``0..n_kept-1`` in ascending
          order of their original index, and writes the bounding-box arrays
          restricted to the surviving rows.

    Args:
        dataloader: forwarded to the nearest-patch search.
        prototype_network_parallel: wrapper exposing the model as ``.module``.
        k: number of nearest patches inspected per prototype.
        prune_threshold: minimum number of same-class patches needed to keep
            a prototype.
        preprocess_input_function: forwarded to the nearest-patch search.
        original_model_dir: directory that holds ``img/epoch-<E>`` and that
            receives the pruned output directory.
        epoch_number: epoch used in directory and file names.
        log: callable used for progress messages.
        copy_prototype_imgs: whether to copy the per-prototype files.

    Returns:
        ``prune_info``: array with one row per pruned prototype, holding
        ``[prototype index, prototype class]``.
    """
    ppnet = prototype_network_parallel.module

    ### run global analysis
    nearest_train_patch_class_ids = find_nearest.find_k_nearest_patches_to_prototypes(
        dataloader=dataloader,
        prototype_network_parallel=prototype_network_parallel,
        k=k,
        preprocess_input_function=preprocess_input_function,
        full_save=False,
        log=log,
    )

    ### find prototypes to prune
    # Remember the count before pruning; the model changes it below.
    original_num_prototypes = ppnet.num_prototypes

    prototypes_to_prune = []
    for j in range(original_num_prototypes):
        class_j = torch.argmax(ppnet.prototype_class_identity[j]).item()
        nearest_class_counts_j = Counter(nearest_train_patch_class_ids[j])
        # A Counter returns 0 for a class that never occurs.
        if nearest_class_counts_j[class_j] < prune_threshold:
            prototypes_to_prune.append(j)

    log("k = {}, prune_threshold = {}".format(k, prune_threshold))
    log("{} prototypes will be pruned".format(len(prototypes_to_prune)))

    ### bookkeeping of prototypes to be pruned
    # .cpu() keeps this working when the identity matrix lives on the GPU.
    class_of_prototypes_to_prune = (
        torch.argmax(ppnet.prototype_class_identity[prototypes_to_prune], dim=1)
        .cpu()
        .numpy()
        .reshape(-1, 1)
    )
    # Explicit integer dtype so an empty prune list does not yield floats.
    prototypes_to_prune_np = np.array(
        prototypes_to_prune, dtype=np.int64).reshape(-1, 1)
    prune_info = np.hstack(
        (prototypes_to_prune_np, class_of_prototypes_to_prune))

    pruned_dir = os.path.join(
        original_model_dir,
        "pruned_prototypes_epoch{}_k{}_pt{}".format(
            epoch_number, k, prune_threshold),
    )
    makedir(pruned_dir)
    np.save(os.path.join(pruned_dir, "prune_info.npy"), prune_info)

    ### prune prototypes
    ppnet.prune_prototypes(prototypes_to_prune)

    if copy_prototype_imgs:
        epoch_subdir = "epoch-%d" % epoch_number
        original_img_dir = os.path.join(original_model_dir, "img",
                                        epoch_subdir)
        dst_img_dir = os.path.join(pruned_dir, "img", epoch_subdir)
        makedir(dst_img_dir)

        # Set iteration order is unspecified, so sort explicitly: the new
        # index of a surviving prototype must be its rank among the kept
        # original indices.
        prototypes_to_keep = sorted(
            set(range(original_num_prototypes)) - set(prototypes_to_prune))

        per_prototype_files = (
            "prototype-img%d.png",
            "prototype-img-original%d.png",
            "prototype-img-original_with_self_act%d.png",
            "prototype-self-act%d.npy",
        )
        for new_idx, old_idx in enumerate(prototypes_to_keep):
            for name_pattern in per_prototype_files:
                shutil.copyfile(
                    src=os.path.join(original_img_dir,
                                     name_pattern % old_idx),
                    dst=os.path.join(dst_img_dir, name_pattern % new_idx),
                )

        # The box arrays do not depend on the loop variable, so they are
        # filtered and written once instead of once per kept prototype.
        for box_pattern in ("bb%d.npy", "bb-receptive_field%d.npy"):
            box_file = box_pattern % epoch_number
            boxes = np.load(os.path.join(original_img_dir, box_file))
            np.save(os.path.join(dst_img_dir, box_file),
                    boxes[prototypes_to_keep])

    return prune_info
def push_prototypes(
    dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
    prototype_network_parallel,  # pytorch network with prototype_vectors
    class_specific=True,
    preprocess_input_function=None,  # normalize if needed
    prototype_layer_stride=1,
    root_dir_for_saving_prototypes=None,  # if not None, prototypes will be saved here
    epoch_number=None,  # if not provided, prototypes saved previously will be overwritten
    prototype_img_filename_prefix=None,
    prototype_self_act_filename_prefix=None,
    proto_bound_boxes_filename_prefix=None,
    save_prototype_class_identity=True,  # which class the prototype image comes from
    log=print,
    prototype_activation_function_in_numpy=None,
):
    """Replace every prototype vector with the closest patch found in the data.

    The network is switched to eval mode, then ``update_prototypes_on_batch``
    is called once per batch of ``dataloader``.  It receives the running
    per-prototype minimum distances, the corresponding feature-map patches and
    the two box arrays, which it is expected to update in place (TODO confirm
    against ``update_prototypes_on_batch``).  Afterwards the collected patches
    are reshaped to ``prototype_shape`` and copied into
    ``prototype_network_parallel.module.prototype_vectors``.

    When ``root_dir_for_saving_prototypes`` is given, an ``epoch-<N>``
    sub-directory is created if ``epoch_number`` is also given; otherwise the
    root directory itself is used.  If additionally
    ``proto_bound_boxes_filename_prefix`` is given, the receptive-field boxes
    and the bounding boxes are saved there as ``.npy`` files.

    proto_rf_boxes and proto_bound_boxes columns:
        0: image index in the entire dataset
        1: height start index
        2: height end index
        3: width start index
        4: width end index
        5: (optional) class identity

    Returns:
        None.  The model is modified in place.
    """
    prototype_network_parallel.eval()
    log("\tpush")

    start = time.time()
    prototype_shape = prototype_network_parallel.module.prototype_shape
    n_prototypes = prototype_network_parallel.module.num_prototypes
    # saves the closest distance seen so far
    global_min_proto_dist = np.full(n_prototypes, np.inf)
    # saves the patch representation that gives the current smallest distance
    global_min_fmap_patches = np.zeros([
        n_prototypes, prototype_shape[1], prototype_shape[2],
        prototype_shape[3]
    ])

    # One extra column when the class identity is recorded with each box.
    n_box_columns = 6 if save_prototype_class_identity else 5
    proto_rf_boxes = np.full(shape=[n_prototypes, n_box_columns],
                             fill_value=-1)
    proto_bound_boxes = np.full(shape=[n_prototypes, n_box_columns],
                                fill_value=-1)

    if root_dir_for_saving_prototypes is not None:
        if epoch_number is not None:
            proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes,
                                           "epoch-" + str(epoch_number))
            makedir(proto_epoch_dir)
        else:
            proto_epoch_dir = root_dir_for_saving_prototypes
    else:
        proto_epoch_dir = None

    search_batch_size = dataloader.batch_size

    num_classes = prototype_network_parallel.module.num_classes

    for push_iter, (search_batch_input, search_y) in enumerate(dataloader):
        # start_index_of_search_batch keeps track of the dataset index of the
        # first image in this batch, so the image chosen as a prototype can
        # be identified later.
        start_index_of_search_batch = push_iter * search_batch_size

        update_prototypes_on_batch(
            search_batch_input,
            start_index_of_search_batch,
            prototype_network_parallel,
            global_min_proto_dist,
            global_min_fmap_patches,
            proto_rf_boxes,
            proto_bound_boxes,
            class_specific=class_specific,
            search_y=search_y,
            num_classes=num_classes,
            preprocess_input_function=preprocess_input_function,
            prototype_layer_stride=prototype_layer_stride,
            dir_for_saving_prototypes=proto_epoch_dir,
            prototype_img_filename_prefix=prototype_img_filename_prefix,
            prototype_self_act_filename_prefix=
            prototype_self_act_filename_prefix,
            prototype_activation_function_in_numpy=
            prototype_activation_function_in_numpy,
        )

    if (proto_epoch_dir is not None
            and proto_bound_boxes_filename_prefix is not None):
        # NOTE: str(epoch_number) is used as-is, so a missing epoch number
        # ends up as the literal "None" in the file names.
        np.save(
            os.path.join(
                proto_epoch_dir,
                proto_bound_boxes_filename_prefix + "-receptive_field" +
                str(epoch_number) + ".npy",
            ),
            proto_rf_boxes,
        )
        np.save(
            os.path.join(
                proto_epoch_dir,
                proto_bound_boxes_filename_prefix + str(epoch_number) + ".npy",
            ),
            proto_bound_boxes,
        )

    log("\tExecuting push ...")
    prototype_update = np.reshape(global_min_fmap_patches,
                                  tuple(prototype_shape))
    prototype_vectors = prototype_network_parallel.module.prototype_vectors
    # Build the update on whatever device the prototypes live on instead of
    # forcing CUDA, so the push also works for a model kept on the CPU.
    prototype_vectors.data.copy_(
        torch.tensor(prototype_update,
                     dtype=torch.float32,
                     device=prototype_vectors.device))
    end = time.time()
    log("\tpush time: \t{0}".format(end - start))
def find_k_nearest_patches_to_prototypes(
    dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
    prototype_network_parallel,  # pytorch network with prototype_vectors
    k=5,
    preprocess_input_function=None,  # normalize if needed
    full_save=False,  # save all the images
    root_dir_for_saving_images="./nearest",
    log=print,
    prototype_activation_function_in_numpy=None,
):
    """Collect, for every prototype, the k nearest image patches in the data.

    Each image contributes one candidate per prototype: the minimum of that
    prototype's distance map.  Candidates are kept in one heap per prototype,
    capped at ``k`` entries.  Which entry ``heappushpop`` evicts depends on the
    ordering defined by ``ImagePatch`` / ``ImagePatchInfo``; presumably that
    ordering makes the farthest patch the heap minimum - confirm against
    those classes.

    With ``full_save=False`` nothing is written to disk and only the labels
    are returned.  With ``full_save=True`` each candidate also carries its
    patch, source image and activation pattern, and for every prototype ``j``
    the directory ``<root_dir_for_saving_images>/<j>`` receives the
    visualisations of its patches plus ``class_id.npy``; the combined label
    array is stored as ``full_class_id.npy`` in the root directory.

    Args:
        dataloader: yields ``(images, labels)`` batches.
        prototype_network_parallel: wrapper exposing the model as ``.module``.
        k: number of patches kept per prototype.
        preprocess_input_function: optional transform applied to each batch
            before the forward pass.
        full_save: whether to build and save the visualisations.
        root_dir_for_saving_images: output root used when ``full_save``.
        log: callable used for progress messages.
        prototype_activation_function_in_numpy: activation used when the
            model's ``prototype_activation_function`` is neither ``"log"``
            nor ``"linear"``.

    Returns:
        Array holding, for each prototype, the labels of its retained
        patches in the final (sorted, then reversed) heap order.
    """
    prototype_network_parallel.eval()
    log("find nearest patches")
    start = time.time()
    ppnet = prototype_network_parallel.module
    n_prototypes = ppnet.num_prototypes

    prototype_shape = ppnet.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = ppnet.proto_layer_rf_info

    # one heap per prototype; a heap in python is just a maintained list
    heaps = [[] for _ in range(n_prototypes)]

    for idx, (search_batch_input, search_y) in enumerate(dataloader):
        # Progress goes through the caller-supplied logger like every other
        # message of this function.
        log("batch {}".format(idx))
        if preprocess_input_function is not None:
            search_batch = preprocess_input_function(search_batch_input)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            _, proto_dist_torch = ppnet.push_forward(search_batch)

        proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

        for img_idx, distance_map in enumerate(proto_dist_):
            for j in range(n_prototypes):
                # find the closest patch of this image to prototype j
                closest_patch_distance_to_prototype_j = np.amin(distance_map[j])

                if full_save:
                    # Position of the minimum in the distance map, with a
                    # leading 0 prepended before it is handed to
                    # compute_rf_prototype.
                    closest_patch_indices_in_distance_map_j = [0] + list(
                        np.unravel_index(
                            np.argmin(distance_map[j], axis=None),
                            distance_map[j].shape,
                        )
                    )
                    closest_patch_indices_in_img = compute_rf_prototype(
                        search_batch.size(2),
                        closest_patch_indices_in_distance_map_j,
                        protoL_rf_info,
                    )
                    # Crop from the un-preprocessed input; move channels last.
                    closest_patch = search_batch_input[
                        img_idx,
                        :,
                        closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                        closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4],
                    ]
                    closest_patch = np.transpose(
                        closest_patch.numpy(), (1, 2, 0))

                    original_img = np.transpose(
                        search_batch_input[img_idx].numpy(), (1, 2, 0))

                    activation_name = ppnet.prototype_activation_function
                    if activation_name == "log":
                        act_pattern = np.log(
                            (distance_map[j] + 1)
                            / (distance_map[j] + ppnet.epsilon)
                        )
                    elif activation_name == "linear":
                        act_pattern = max_dist - distance_map[j]
                    else:
                        act_pattern = prototype_activation_function_in_numpy(
                            distance_map[j]
                        )

                    # 4 numbers: height_start, height_end, width_start, width_end
                    patch_indices = closest_patch_indices_in_img[1:5]

                    # construct the closest patch object
                    closest_patch = ImagePatch(
                        patch=closest_patch,
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                        original_img=original_img,
                        act_pattern=act_pattern,
                        patch_indices=patch_indices,
                    )
                else:
                    closest_patch = ImagePatchInfo(
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                    )

                # add to the j-th heap
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush
                    # followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)

    for j in range(n_prototypes):
        # The heaps only bound the number of entries; rank them now.
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        if full_save:
            dir_for_saving_images = os.path.join(
                root_dir_for_saving_images, str(j))
            makedir(dir_for_saving_images)

            for i, patch in enumerate(heaps[j]):
                _save_nearest_patch_visualizations(
                    patch, i + 1, dir_for_saving_images)

            labels = np.array([patch.label for patch in heaps[j]])
            np.save(os.path.join(dir_for_saving_images, "class_id.npy"), labels)

    labels_all_prototype = np.array(
        [[patch.label for patch in heaps[j]] for j in range(n_prototypes)]
    )

    if full_save:
        np.save(
            os.path.join(root_dir_for_saving_images, "full_class_id.npy"),
            labels_all_prototype,
        )

    end = time.time()
    log("\tfind nearest patches time: \t{0}".format(end - start))

    return labels_all_prototype


def _save_nearest_patch_visualizations(patch, rank, dir_for_saving_images):
    """Write every visualisation file of one retained patch.

    ``rank`` is the 1-based position of ``patch`` in its prototype's list and
    is embedded in every file name as ``nearest-<rank>_...``.
    """
    file_prefix = os.path.join(dir_for_saving_images, "nearest-" + str(rank))

    # save the activation pattern of the original image where the patch comes from
    np.save(file_prefix + "_act.npy", patch.act_pattern)

    # save the original image where the patch comes from
    plt.imsave(
        fname=file_prefix + "_original.png",
        arr=patch.original_img,
        vmin=0.0,
        vmax=1.0,
    )

    # overlay (upsampled) activation on original image and save the result
    img_size = patch.original_img.shape[0]
    upsampled_act_pattern = cv2.resize(
        patch.act_pattern,
        dsize=(img_size, img_size),
        interpolation=cv2.INTER_CUBIC,
    )
    rescaled_act_pattern = upsampled_act_pattern - np.amin(
        upsampled_act_pattern)
    act_range = np.amax(rescaled_act_pattern)
    # A constant activation map has zero range; dividing would fill the
    # heatmap with NaNs, so leave it at all zeros in that case.
    if act_range > 0:
        rescaled_act_pattern = rescaled_act_pattern / act_range
    heatmap = cv2.applyColorMap(
        np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    # reverse the channel order of the colour-mapped image
    heatmap = heatmap[..., ::-1]
    overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
    plt.imsave(
        fname=file_prefix + "_original_with_heatmap.png",
        arr=overlayed_original_img,
        vmin=0.0,
        vmax=1.0,
    )

    # if different from original image, save the patch (i.e. receptive field)
    if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size:
        np.save(
            file_prefix + "_receptive_field_indices.npy",
            patch.patch_indices,
        )
        plt.imsave(
            fname=file_prefix + "_receptive_field.png",
            arr=patch.patch,
            vmin=0.0,
            vmax=1.0,
        )
        # save the receptive field patch with heatmap
        overlayed_patch = overlayed_original_img[
            patch.patch_indices[0]:patch.patch_indices[1],
            patch.patch_indices[2]:patch.patch_indices[3],
            :,
        ]
        plt.imsave(
            fname=file_prefix + "_receptive_field_with_heatmap.png",
            arr=overlayed_patch,
            vmin=0.0,
            vmax=1.0,
        )

    # save the highly activated patch
    high_act_patch_indices = find_high_activation_crop(upsampled_act_pattern)
    high_act_patch = patch.original_img[
        high_act_patch_indices[0]:high_act_patch_indices[1],
        high_act_patch_indices[2]:high_act_patch_indices[3],
        :,
    ]
    np.save(
        file_prefix + "_high_act_patch_indices.npy",
        high_act_patch_indices,
    )
    plt.imsave(
        fname=file_prefix + "_high_act_patch.png",
        arr=high_act_patch,
        vmin=0.0,
        vmax=1.0,
    )
    # save the original image with bounding box showing high activation patch
    imsave_with_bbox(
        fname=file_prefix + "_high_act_patch_in_original_img.png",
        img_rgb=patch.original_img,
        bbox_height_start=high_act_patch_indices[0],
        bbox_height_end=high_act_patch_indices[1],
        bbox_width_start=high_act_patch_indices[2],
        bbox_width_end=high_act_patch_indices[3],
        color=(0, 255, 255),
    )
示例#4
0
# Choose where this run's artefacts go: the Colab checkout or the cluster
# scratch space of the current user.
if colab:
    saved_models_root = "/content/PPNet/saved_models/"
else:
    saved_models_root = "/cluster/scratch/" + username + "/PPNet/saved_models/"
model_dir = saved_models_root + base_architecture + "/" + experiment_run + "/"
makedir(model_dir)

# Snapshot the sources that define this run next to its outputs.
working_dir = os.getcwd()
snapshot_sources = [
    os.path.join(working_dir, __file__),
    os.path.join(working_dir, "settings.py"),
    os.path.join(
        working_dir, "src/models/", base_architecture_type + "_features.py"
    ),
    os.path.join(working_dir, "src/models/", "model.py"),
    os.path.join(working_dir, "src/training/", "train_and_test.py"),
]
for snapshot_source in snapshot_sources:
    shutil.copy(src=snapshot_source, dst=model_dir)

# Training log and prototype image directory both live inside the run folder.
train_log_path = os.path.join(model_dir, "train.log")
log, logclose = create_logger(log_filename=train_log_path)
img_dir = os.path.join(model_dir, "img")
    def jpeg_visualization2(
        self,
        pil_img,
        img_name,
        test_image_label,
        preprocess_clean,
        preprocess_compressed,
        show_images=False,
        idx=0,
        max_prototypes=5,
        top_n=None,
        save_histogram=False,
        save_name="histogram",
    ):
        """
        Perform detailed local analysis comparing compressed and uncompressed images.
        Args:
            pil_img: is the PIL test image which is to be inspected
            img_name: name of the image to save it
            test_image_label: label of the test image
            preprocess_clean: first torch vision transform pipeline (here, without compression)
            preprocess_compressed: second torch vision transform pipeline (here, with compression)
            show_images (default = False): Boolean value to show images
            idx (default = 0): Index Value, when retrieving results of the PPNet
            max_prototypes (int): number of most similar prototypes to display (default: 5)
            top_n (default = None): Visualize the top_n_th most activating prototype
        Returns:
            A list containing 8 images in the following order:
                1. Full picture of most activated prototype with bounding box
                2. Most activated prototype of compressed image
                3. Compressed image passed through, with activated patch in bounding box
                4. Corresponding activation map of the compressed image

                5. Full picture of most activated prototype with bounding box
                6. Most activated prototype of compressed image
                7. Uncompressed image passed through, with activated patch in bounding box
                8. Corresponding activation map of the uncompressed image
        """
        # How to save the images
        specific_folder = self.save_analysis_path + "/" + img_name
        makedir(specific_folder)

        # Preprocess Clean image
        img_tensor_clean = preprocess_clean(pil_img)
        img_variable_clean = Variable(img_tensor_clean.unsqueeze(0))

        # Preprocess compressed image
        img_tensor_compressed = preprocess_compressed(pil_img)
        img_variable_compressed = Variable(img_tensor_compressed.unsqueeze(0))

        img_variables = [img_variable_compressed, img_variable_clean]

        # Save activations
        dict_prototype_activations, dict_tables = {}, {}
        dict_prototype_activation_patterns = {}
        dict_array_act, dict_sorted_indices_act = {}, {}
        dict_original_img = {}

        for k, img_variable in enumerate(img_variables):

            # Forward the image variable through the network
            images_test = img_variable.cuda()
            labels_test = torch.tensor([test_image_label])

            logits, min_distances = self.ppnet_multi(images_test)
            conv_output, distances = self.ppnet.push_forward(images_test)
            prototype_activations = self.ppnet.distance_2_similarity(min_distances)
            prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
            if self.ppnet.prototype_activation_function == "linear":
                prototype_activations = prototype_activations + max_dist
                prototype_activation_patterns = prototype_activation_patterns + max_dist

            tables = []
            for i in range(logits.size(0)):
                tables.append(
                    (torch.argmax(logits, dim=1)[i].item(), labels_test[i].item())
                )

            idx = idx
            predicted_cls = tables[idx][0]
            correct_cls = tables[idx][1]
            self.log("Uncompressed image" if k == 1 else "JPEG compressed image")
            if predicted_cls == correct_cls:
                pred_text = "Prediction is correct."
            else:
                pred_text = "Prediction is wrong."
            self.log(
                "Predicted: "
                + str(predicted_cls)
                + "\t Actual: "
                + str(correct_cls)
                + "\t "
                + pred_text
            )
            self.log("------------------------------")

            dict_original_img[k] = save_preprocessed_img(
                os.path.join(specific_folder, "original_img.png"), images_test, idx
            )

            ##### MOST ACTIVATED (NEAREST) PROTOTYPES OF THIS IMAGE
            makedir(os.path.join(specific_folder, "most_activated_prototypes"))
            dict_prototype_activations[k], dict_tables[k] = (
                prototype_activations,
                tables,
            )
            dict_prototype_activation_patterns[k] = prototype_activation_patterns
            dict_array_act[k], dict_sorted_indices_act[k] = torch.sort(
                prototype_activations[idx]
            )

        # Initialize as none, and will be filled after examining the first image
        inspected_index = None
        inspected_min = None
        inspected_max = None

        logs = {0: [], 1: []}
        display_images = []
        for i in range(1, max_prototypes + 1):
            if top_n is not None and top_n != i:
                continue

            for k in [0, 1]:
                if top_n is not None:
                    if inspected_index is None:
                        inspected_index = dict_sorted_indices_act[k][-i].item()
                    else:
                        i = np.where(
                            dict_sorted_indices_act[k].cpu().numpy() == inspected_index
                        )[0][0]
                else:
                    inspected_index, inspected_min, inspected_max = None, None, None
                    inspected_index = dict_sorted_indices_act[k][-i].item()

                p_img = save_prototype(
                    self.load_img_dir,
                    os.path.join(
                        self.save_analysis_path,
                        "most_activated_prototypes",
                        "top-%d_activated_prototype.png" % i,
                    ),
                    self.start_epoch_number,
                    inspected_index,
                )

                p_oimg_with_bbox = save_prototype_original_img_with_bbox(
                    self.load_img_dir,
                    fname=os.path.join(
                        self.save_analysis_path,
                        "most_activated_prototypes",
                        "top-%d_activated_prototype_in_original_pimg.png" % i,
                    ),
                    epoch=self.start_epoch_number,
                    index=inspected_index,
                    bbox_height_start=self.prototype_info[inspected_index][1],
                    bbox_height_end=self.prototype_info[inspected_index][2],
                    bbox_width_start=self.prototype_info[inspected_index][3],
                    bbox_width_end=self.prototype_info[inspected_index][4],
                    color=(0, 255, 255),
                )
                p_img_with_self_actn = save_prototype_self_activation(
                    self.load_img_dir,
                    os.path.join(
                        self.save_analysis_path,
                        "most_activated_prototypes",
                        "top-%d_activated_prototype_self_act.png" % i,
                    ),
                    self.start_epoch_number,
                    inspected_index,
                )

                activation_pattern = (
                    dict_prototype_activation_patterns[k][idx][inspected_index]
                    .detach()
                    .cpu()
                    .numpy()
                )
                upsampled_activation_pattern = cv2.resize(
                    activation_pattern,
                    dsize=(self.img_size, self.img_size),
                    interpolation=cv2.INTER_CUBIC,
                )

                # show the most highly activated patch of the image by this prototype
                high_act_patch_indices = find_high_activation_crop(
                    upsampled_activation_pattern
                )
                high_act_patch = dict_original_img[k][
                    high_act_patch_indices[0] : high_act_patch_indices[1],
                    high_act_patch_indices[2] : high_act_patch_indices[3],
                    :,
                ]

                plt.imsave(
                    os.path.join(
                        specific_folder,
                        "most_activated_prototypes",
                        "most_highly_activated_patch_by_top-%d_prototype.png" % i,
                    ),
                    high_act_patch,
                )

                p_img_with_bbox = imsave_with_bbox(
                    fname=os.path.join(
                        specific_folder,
                        "most_activated_prototypes",
                        "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png"
                        % i,
                    ),
                    img_rgb=dict_original_img[k],
                    bbox_height_start=high_act_patch_indices[0],
                    bbox_height_end=high_act_patch_indices[1],
                    bbox_width_start=high_act_patch_indices[2],
                    bbox_width_end=high_act_patch_indices[3],
                    color=(0, 255, 255),
                )

                # show the image overlayed with prototype activation map and use normalization values of first run
                if inspected_min is None:
                    inspected_min = np.amin(upsampled_activation_pattern)

                rescaled_activation_pattern = (
                    upsampled_activation_pattern - inspected_min
                )

                if inspected_max is None:
                    inspected_max = np.amax(rescaled_activation_pattern)

                rescaled_activation_pattern = (
                    rescaled_activation_pattern / inspected_max
                )
                heatmap = cv2.applyColorMap(
                    np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
                )
                heatmap = np.float32(heatmap) / 255
                heatmap = heatmap[..., ::-1]
                overlayed_img = 0.5 * dict_original_img[k] + 0.3 * heatmap

                plt.imsave(
                    os.path.join(
                        specific_folder,
                        "most_activated_prototypes",
                        "prototype_activation_map_by_top-%d_prototype.png" % i,
                    ),
                    overlayed_img,
                )

                display_images += [
                    p_oimg_with_bbox,
                    p_img,
                    p_img_with_bbox,
                    overlayed_img,
                ]
                logs[k].append(
                    {
                        "Prototype Id": inspected_index,
                        "Rank": i,
                        "Prototype Class": self.prototype_img_identity[inspected_index],
                        "Similarity Score": dict_prototype_activations[k][idx][
                            inspected_index
                        ].item(),
                    }
                )

        # Visualize Logs and Images
        if show_images:
            logs_df = pd.concat(
                [pd.DataFrame(logs[0]), pd.DataFrame(logs[1])],
                axis=1,
                keys=["Compressed Image", "Clean Image"],
            )
            display(logs_df)

            display_titles = [
                "Training Image from which \nprototype is taken",
                "Prototype",
                "Test Image + BBox",
                "Test Image + Activation Map",
            ]
            display_titles += [x + "\n(uncompressed)" for x in display_titles]

            visualize_image_grid(images=display_images, titles=display_titles, ncols=8)
            plt.tight_layout()
            plt.show()

            # Display Histograms
            f = plt.figure(figsize=(12, 3.5))
            for k in range(0, 2):
                plt.subplot(1, 2, k + 1)
                _, pids = torch.topk(dict_prototype_activations[k][idx], 75)
                pids = pids.cpu().numpy()
                plt.bar(
                    np.arange(0, len(pids)) - 0.2,
                    dict_prototype_activations[1][idx][pids].cpu().numpy(),
                    width=0.4,
                    label="Clean",
                )
                plt.bar(
                    np.arange(0, len(pids)) + 0.2,
                    dict_prototype_activations[0][idx][pids].cpu().numpy(),
                    width=0.4,
                    label="Compressed",
                )
                plt.title(
                    "Top 75 for compressed image"
                    if k == 0
                    else "Top 75 for clean image"
                )
                plt.ylabel("Similarities")
                plt.xlabel("Prototype")
                plt.legend()
            plt.tight_layout()
            plt.show()
            if save_histogram:
                f.savefig("{}.pdf".format(save_name), bbox_inches="tight")

        return display_images
    def jpeg_visualization(
        self,
        pil_img,
        img_name,
        test_image_label,
        preprocess_clean,
        preprocess_compressed,
        show_images=False,
        idx=0,
        top_n=1,
    ):
        """
        Perform local analysis comparing compressed and uncompressed images.

        The compressed image is processed first. It selects the prototype to
        inspect (its ``top_n``-th most activated prototype) and fixes the
        min/max values used to normalize the activation map. The clean image is
        then analysed against that same prototype with the same normalization,
        so both activation maps are directly comparable.

        Args:
            pil_img: is the PIL test image which is to be inspected
            img_name: name of the image to save it
            test_image_label: label of the test image
            preprocess_clean: first torch vision transform pipeline (here, without compression)
            preprocess_compressed: second torch vision transform pipeline (here, with compression)
            show_images (default = False): Boolean value to show images
            idx (default = 0): Index Value, when retrieving results of the PPNet
            top_n (default = 1): Visualize the n_th most activating prototype

        Returns:
            A list containing 8 images in the following order:
                1. Full picture of most activated prototype with bounding box
                2. Most activated prototype of compressed image
                3. Compressed image passed through, with activated patch in bounding box
                4. Corresponding activation map of the compressed image

                5. Full picture of most activated prototype with bounding box
                6. Most activated prototype of compressed image
                7. Uncompressed image passed through, with activated patch in bounding box
                8. Corresponding activation map of the uncompressed image
        """
        # Per-image outputs go into their own folder below the analysis path.
        # NOTE: both passes write the same file names there, so the files left
        # on disk afterwards are those of the second (clean) pass.
        specific_folder = self.save_analysis_path + "/" + img_name
        makedir(specific_folder)
        makedir(os.path.join(specific_folder, "most_activated_prototypes"))
        # The prototype images themselves are written below save_analysis_path
        # (not specific_folder), so make sure that folder exists as well.
        makedir(os.path.join(self.save_analysis_path, "most_activated_prototypes"))

        # Build one single-image batch per preprocessing pipeline.
        img_variable_clean = Variable(preprocess_clean(pil_img).unsqueeze(0))
        img_variable_compressed = Variable(preprocess_compressed(pil_img).unsqueeze(0))

        # Filled while examining the first (compressed) image and reused for
        # the second (clean) one.
        inspected_index = None
        inspected_min = None
        inspected_max = None

        # Rank of the prototype to inspect (1 = most activated).
        i = top_n

        display_images = []
        for img_variable in (img_variable_compressed, img_variable_clean):

            # Forward the image variable through the network
            images_test = img_variable.cuda()

            logits, min_distances = self.ppnet_multi(images_test)
            _, distances = self.ppnet.push_forward(images_test)
            prototype_activations = self.ppnet.distance_2_similarity(min_distances)
            prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
            if self.ppnet.prototype_activation_function == "linear":
                # Same offset as in local_analysis; the bare name `max_dist`
                # used previously is not defined in this method.
                prototype_activations = prototype_activations + self.max_dist
                prototype_activation_patterns = (
                    prototype_activation_patterns + self.max_dist
                )

            predicted_cls = torch.argmax(logits, dim=1)[idx].item()
            # The batch holds exactly one image, so there is exactly one label.
            correct_cls = torch.tensor([test_image_label])[0].item()
            if predicted_cls == correct_cls:
                pred_text = "Prediction is correct."
            else:
                pred_text = "Prediction is wrong."
            self.log(
                "Predicted: "
                + str(predicted_cls)
                + "\t Actual: "
                + str(correct_cls)
                + "\t "
                + pred_text
            )

            original_img = save_preprocessed_img(
                os.path.join(specific_folder, "original_img.png"), images_test, idx
            )

            ##### TOP-N ACTIVATED (NEAREST) PROTOTYPE OF THIS IMAGE
            _, sorted_indices_act = torch.sort(prototype_activations[idx])

            # Only the first image chooses the prototype; the second image is
            # inspected with respect to that same prototype.
            if inspected_index is None:
                inspected_index = sorted_indices_act[-i].item()

            self.log("protoype index: " + str(inspected_index))

            p_img = save_prototype(
                self.load_img_dir,
                os.path.join(
                    self.save_analysis_path,
                    "most_activated_prototypes",
                    "top-%d_activated_prototype.png" % i,
                ),
                self.start_epoch_number,
                inspected_index,
            )

            bbox = self.prototype_info[inspected_index]
            p_oimg_with_bbox = save_prototype_original_img_with_bbox(
                self.load_img_dir,
                fname=os.path.join(
                    self.save_analysis_path,
                    "most_activated_prototypes",
                    "top-%d_activated_prototype_in_original_pimg.png" % i,
                ),
                epoch=self.start_epoch_number,
                index=inspected_index,
                bbox_height_start=bbox[1],
                bbox_height_end=bbox[2],
                bbox_width_start=bbox[3],
                bbox_width_end=bbox[4],
                color=(0, 255, 255),
            )
            # Result is not displayed; the call is kept for whatever it writes
            # to the given file name.
            save_prototype_self_activation(
                self.load_img_dir,
                os.path.join(
                    self.save_analysis_path,
                    "most_activated_prototypes",
                    "top-%d_activated_prototype_self_act.png" % i,
                ),
                self.start_epoch_number,
                inspected_index,
            )

            self.log(
                "prototype class identity: {0}".format(
                    self.prototype_img_identity[inspected_index]
                )
            )
            # Compare both sides for the inspected prototype. Indexing with the
            # current image's top-n prototype would mix two different
            # prototypes on the second pass.
            if (
                self.prototype_max_connection[inspected_index]
                != self.prototype_img_identity[inspected_index]
            ):
                self.log(
                    "prototype connection identity: {0}".format(
                        self.prototype_max_connection[inspected_index]
                    )
                )
            self.log(
                "activation value (similarity score): {0}".format(
                    prototype_activations[idx][inspected_index]
                )
            )

            activation_pattern = (
                prototype_activation_patterns[idx][inspected_index]
                .detach()
                .cpu()
                .numpy()
            )
            upsampled_activation_pattern = cv2.resize(
                activation_pattern,
                dsize=(self.img_size, self.img_size),
                interpolation=cv2.INTER_CUBIC,
            )

            # show the most highly activated patch of the image by this prototype
            high_act_patch_indices = find_high_activation_crop(
                upsampled_activation_pattern
            )
            h_start, h_end, w_start, w_end = (
                high_act_patch_indices[0],
                high_act_patch_indices[1],
                high_act_patch_indices[2],
                high_act_patch_indices[3],
            )
            high_act_patch = original_img[h_start:h_end, w_start:w_end, :]

            plt.imsave(
                os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "most_highly_activated_patch_by_top-%d_prototype.png" % i,
                ),
                high_act_patch,
            )

            p_img_with_bbox = imsave_with_bbox(
                fname=os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png"
                    % i,
                ),
                img_rgb=original_img,
                bbox_height_start=h_start,
                bbox_height_end=h_end,
                bbox_width_start=w_start,
                bbox_width_end=w_end,
                color=(0, 255, 255),
            )

            # show the image overlayed with prototype activation map and use
            # normalization values of first run
            if inspected_min is None:
                inspected_min = np.amin(upsampled_activation_pattern)
            rescaled_activation_pattern = upsampled_activation_pattern - inspected_min

            if inspected_max is None:
                inspected_max = np.amax(rescaled_activation_pattern)
            rescaled_activation_pattern = rescaled_activation_pattern / inspected_max

            heatmap = cv2.applyColorMap(
                np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
            )
            heatmap = np.float32(heatmap) / 255
            # Reverse the channel axis of the colour map before blending.
            heatmap = heatmap[..., ::-1]
            overlayed_img = 0.5 * original_img + 0.3 * heatmap

            plt.imsave(
                os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "prototype_activation_map_by_top-%d_prototype.png" % i,
                ),
                overlayed_img,
            )

            display_images += [p_oimg_with_bbox, p_img, p_img_with_bbox, overlayed_img]
            self.log("------------------------------")

        # Visualize Images
        if show_images:
            display_titles = [
                "Training Image from which \nprototype is taken",
                "Prototype",
                "Test Image + BBox",
                "Test Image + Activation Map",
            ]
            # Second half of the grid: same titles, marked as uncompressed.
            display_titles += [x + "\n(uncompressed)" for x in display_titles]

            visualize_image_grid(images=display_images, titles=display_titles, ncols=8)
            plt.tight_layout()
            plt.show()

        return display_images
    def local_analysis(
        self,
        img_variable,
        test_image_label,
        max_prototypes=10,
        idx=0,
        pid=None,
        verbose=False,
        show_images=True,
        normalize_sim_map=None,
    ):
        """
        Perform local analysis.

        Runs the image through the network, logs the prediction, and for each
        of the ``max_prototypes`` most activated prototypes saves the prototype
        images, the most activated patch of the test image, and the test image
        overlaid with the prototype activation map.

        Arguments:
            img_variable (torch.Tensor): input image to test on.
            test_image_label (int): true label of test image.
            max_prototypes (int): number of most similar prototypes to display (default: 10).
            idx (int): image id in the batch (default: 0).
            pid (int): prototype id. If given, only this prototype is processed,
                and only if it is among the ``max_prototypes`` most activated.
            verbose (bool): whether to print (default: False).
            show_images (bool): show images (default: True).
            normalize_sim_map (Tuple(2)): min and max values resp. to be used for normalizing the activation map.

        Returns:
            A 3-tuple of:
                1. prototype indices sorted by ascending activation for image ``idx``,
                2. the prototype activation patterns of the whole batch,
                3. ``[min, max]`` where ``min`` is the minimum of the upsampled
                   activation map and ``max`` the maximum of the rescaled map of
                   the last processed prototype, or ``None`` when no prototype
                   was processed.
        """
        img_size = self.ppnet_multi.module.img_size
        images_test = img_variable.cuda()

        logits, min_distances = self.ppnet_multi(images_test)
        _, distances = self.ppnet.push_forward(images_test)
        prototype_activations = self.ppnet.distance_2_similarity(min_distances)
        prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
        if self.ppnet.prototype_activation_function == "linear":
            prototype_activations = prototype_activations + self.max_dist
            prototype_activation_patterns = (
                prototype_activation_patterns + self.max_dist
            )

        predicted_cls = torch.argmax(logits, dim=1)[idx].item()
        # Only one label is passed in; it describes the inspected image.
        correct_cls = torch.tensor([test_image_label])[0].item()
        self.log("Predicted: " + str(predicted_cls))
        self.log("Actual: " + str(correct_cls))
        if predicted_cls == correct_cls:
            self.log("Prediction is correct.")
        else:
            self.log("Prediction is wrong.")

        original_img = save_preprocessed_img(
            os.path.join(self.save_analysis_path, "original_img.png"), images_test, idx
        )

        ##### MOST ACTIVATED (NEAREST) PROTOTYPES OF THIS IMAGE
        prototype_dir = os.path.join(self.save_analysis_path, "most_activated_prototypes")
        makedir(prototype_dir)

        self.log("Most activated {0} prototypes of this image:".format(max_prototypes))
        self.log("--------------------------------------------------------------")

        array_act, sorted_indices_act = torch.sort(prototype_activations[idx])

        # Stay None when the loop processes nothing (e.g. `pid` is not among
        # the top prototypes), so the return below cannot hit an unbound name.
        upsampled_activation_pattern = None
        rescaled_activation_pattern = None

        # Never ask for more prototypes than the model has.
        n_shown = min(max_prototypes, sorted_indices_act.size(0))
        for i in range(1, n_shown + 1):
            # sorted ascending, so the i-th most activated sits at position -i
            proto_idx = sorted_indices_act[-i].item()
            if pid is not None and pid != proto_idx:
                continue
            self.log("top {0} activated prototype for this image:".format(i))
            p_img = save_prototype(
                self.load_img_dir,
                os.path.join(prototype_dir, "top-%d_activated_prototype.png" % i),
                self.start_epoch_number,
                proto_idx,
            )
            bbox = self.prototype_info[proto_idx]
            p_oimg_with_bbox = save_prototype_original_img_with_bbox(
                self.load_img_dir,
                fname=os.path.join(
                    prototype_dir,
                    "top-%d_activated_prototype_in_original_pimg.png" % i,
                ),
                epoch=self.start_epoch_number,
                index=proto_idx,
                bbox_height_start=bbox[1],
                bbox_height_end=bbox[2],
                bbox_width_start=bbox[3],
                bbox_width_end=bbox[4],
                color=(0, 255, 255),
            )
            # Result is not displayed; the call is kept for whatever it writes
            # to the given file name.
            save_prototype_self_activation(
                self.load_img_dir,
                os.path.join(
                    prototype_dir, "top-%d_activated_prototype_self_act.png" % i
                ),
                self.start_epoch_number,
                proto_idx,
            )

            self.log("prototype index: {0}".format(proto_idx))
            self.log(
                "prototype class identity: {0}".format(
                    self.prototype_img_identity[proto_idx]
                )
            )
            if (
                self.prototype_max_connection[proto_idx]
                != self.prototype_img_identity[proto_idx]
            ):
                self.log(
                    "prototype connection identity: {0}".format(
                        self.prototype_max_connection[proto_idx]
                    )
                )
            self.log("activation value (similarity score): {0}".format(array_act[-i]))
            self.log(
                "last layer connection with predicted class: {0}".format(
                    self.ppnet.last_layer.weight[predicted_cls][proto_idx]
                )
            )

            activation_pattern = (
                prototype_activation_patterns[idx][proto_idx].detach().cpu().numpy()
            )
            upsampled_activation_pattern = cv2.resize(
                activation_pattern,
                dsize=(img_size, img_size),
                interpolation=cv2.INTER_CUBIC,
            )

            # show the most highly activated patch of the image by this prototype
            high_act_patch_indices = find_high_activation_crop(
                upsampled_activation_pattern
            )
            h_start, h_end, w_start, w_end = (
                high_act_patch_indices[0],
                high_act_patch_indices[1],
                high_act_patch_indices[2],
                high_act_patch_indices[3],
            )
            high_act_patch = original_img[h_start:h_end, w_start:w_end, :]
            if verbose:
                self.log(
                    "most highly activated patch of the chosen image by this prototype:"
                )
            plt.imsave(
                os.path.join(
                    prototype_dir,
                    "most_highly_activated_patch_by_top-%d_prototype.png" % i,
                ),
                high_act_patch,
            )
            if verbose:
                # was a bare `log(...)`, which is not defined in this method
                self.log(
                    "most highly activated patch by this prototype shown in the original image:"
                )
            p_img_with_bbox = imsave_with_bbox(
                fname=os.path.join(
                    prototype_dir,
                    "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png"
                    % i,
                ),
                img_rgb=original_img,
                bbox_height_start=h_start,
                bbox_height_end=h_end,
                bbox_width_start=w_start,
                bbox_width_end=w_end,
                color=(0, 255, 255),
            )

            # show the image overlayed with prototype activation map
            if normalize_sim_map is not None:
                # caller-supplied normalization: subtract [0], divide by [1]
                rescaled_activation_pattern = (
                    upsampled_activation_pattern - normalize_sim_map[0]
                ) / normalize_sim_map[1]
            else:
                # per-map min/max normalization
                rescaled_activation_pattern = upsampled_activation_pattern - np.amin(
                    upsampled_activation_pattern
                )
                rescaled_activation_pattern = rescaled_activation_pattern / np.amax(
                    rescaled_activation_pattern
                )
            heatmap = cv2.applyColorMap(
                np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
            )
            heatmap = np.float32(heatmap) / 255
            # Reverse the channel axis of the colour map before blending.
            heatmap = heatmap[..., ::-1]
            overlayed_img = 0.5 * original_img + 0.3 * heatmap
            if verbose:
                self.log("prototype activation map of the chosen image:")
            plt.imsave(
                os.path.join(
                    prototype_dir,
                    "prototype_activation_map_by_top-%d_prototype.png" % i,
                ),
                overlayed_img,
            )

            if show_images:
                visualize_image_grid(
                    images=[p_oimg_with_bbox, p_img, p_img_with_bbox, overlayed_img],
                    titles=[
                        "Training Image from which \nprototype is taken",
                        "Prototype",
                        "Test Image + BBox",
                        "Test Image + Activation Map",
                    ],
                    ncols=4,
                )
                plt.tight_layout()
                plt.show()
            self.log("--------------------------------------------------------------")

        if upsampled_activation_pattern is None:
            # No prototype was processed, so there is nothing to derive
            # normalization values from.
            normalization = None
        else:
            # NOTE(review): the second entry is the max of the map *after*
            # rescaling (unchanged from before), not of the min-shifted map;
            # confirm this is what callers feeding it back as
            # `normalize_sim_map` expect.
            normalization = [
                np.amin(upsampled_activation_pattern),
                np.amax(rescaled_activation_pattern),
            ]

        return (
            sorted_indices_act,
            prototype_activation_patterns,
            normalization,
        )
    def __init__(
        self,
        load_model_dir,
        load_model_name,
        test_image_name,
        image_save_directory=None,
        attack=None,
    ):
        """
        Load a saved model and its prototype metadata and set up the output
        folder and logger used by the analysis methods.

        Arguments:
            load_model_dir (str): path to saved model directory.
            load_model_name (str): saved model name.
            test_image_name (str): test image file name.
            image_save_directory (str): directory to save images.
            attack (int): type of attack (1 or 3 or None).
        """
        # Architecture and run name are taken from the "/"-separated path:
        # third-from-last and second-from-last components respectively.
        # NOTE(review): this presumably expects load_model_dir to end with a
        # trailing "/" (so the last component is empty) — confirm with callers.
        model_base_architecture = load_model_dir.split("/")[-3]
        self.model_base_architecture = model_base_architecture
        experiment_run = load_model_dir.split("/")[-2]

        # Default output folder. `username` is not defined in this method; it
        # must come from module scope. The last 4 characters of the image name
        # are dropped (e.g. a 3-letter extension plus the dot).
        self.save_analysis_path = (
            "/cluster/scratch/{}/PPNet/local_analysis_attack{}/".format(
                username, attack
            )
            + model_base_architecture
            + "-"
            + experiment_run
            + "/"
            + test_image_name[:-4]
        )

        # An explicit save directory overrides the default. The two strings are
        # concatenated as-is: no separator is inserted and the full image name
        # (including extension) is kept.
        if image_save_directory is not None:
            self.save_analysis_path = image_save_directory + test_image_name

        makedir(self.save_analysis_path)

        self.log, self.logclose = create_logger(
            log_filename=os.path.join(self.save_analysis_path, "local_analysis.log")
        )

        load_model_path = os.path.join(load_model_dir, load_model_name)
        # The epoch number is the first run of digits in the model file name.
        epoch_number_str = re.search(r"\d+", load_model_name).group(0)
        self.start_epoch_number = int(epoch_number_str)

        self.log("load model from " + load_model_path)
        self.log("model base architecture: " + model_base_architecture)
        self.log("experiment run: " + experiment_run)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.ppnet = torch.load(load_model_path)
        self.ppnet = self.ppnet.cuda()
        self.ppnet_multi = torch.nn.DataParallel(self.ppnet)

        self.img_size = self.ppnet_multi.module.img_size
        prototype_shape = self.ppnet.prototype_shape
        # Product of prototype_shape[1:4]; added to the similarities by the
        # analysis methods when the activation function is "linear".
        self.max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

        self.class_specific = True
        # `mean` and `std` are not defined in this method; they must come from
        # module scope.
        self.normalize = transforms.Normalize(mean=mean, std=std)

        # confirm prototype class identity
        self.load_img_dir = os.path.join(load_model_dir, "img")

        # Per-prototype metadata saved for this epoch. The analysis methods read
        # columns 1-4 as bounding-box height/width start/end; the last column
        # is used as the prototype's class identity.
        self.prototype_info = np.load(
            os.path.join(
                self.load_img_dir,
                "epoch-" + epoch_number_str,
                "bb" + epoch_number_str + ".npy",
            )
        )
        self.prototype_img_identity = self.prototype_info[:, -1]

        self.log(
            "Prototypes are chosen from "
            + str(len(set(self.prototype_img_identity)))
            + " number of classes."
        )
        self.log("Their class identities are: " + str(self.prototype_img_identity))

        # confirm prototype connects most strongly to its own class
        # argmax over dim 0 of the last-layer weight gives, per prototype, the
        # class with the largest connection weight.
        prototype_max_connection = torch.argmax(self.ppnet.last_layer.weight, dim=0)
        self.prototype_max_connection = prototype_max_connection.cpu().numpy()
        if (
            np.sum(self.prototype_max_connection == self.prototype_img_identity)
            == self.ppnet.num_prototypes
        ):
            self.log(
                "All prototypes connect most strongly to their respective classes."
            )
        else:
            self.log(
                "WARNING: Not all prototypes connect most strongly to their respective classes."
            )