Beispiel #1
0
 def closest_patch_indices_in_img(self):
     """Return the image-space receptive field of the best-matching patch.

     Finds the smallest entry of ``self.distance_map_j`` (``np.argmin`` over
     the flattened map, so ties go to the first position in row-major order),
     prepends a batch index of 0 and converts those feature-map coordinates
     to image coordinates with ``compute_rf_prototype`` and
     ``self.protoL_rf_info``.

     The image size handed to ``compute_rf_prototype`` is
     ``self.original_raw.size(1)``; presumably ``original_raw`` is a
     (C, H, W) tensor so this is the height -- TODO confirm against callers.

     Returns:
         The value returned by ``compute_rf_prototype`` (callers in this file
         index it as [img_index, h_start, h_end, w_start, w_end]).
     """
     dist_map = self.distance_map_j
     flat_position = np.argmin(dist_map, axis=None)
     # Leading 0 is the batch index; the rest is the position of the minimum.
     feature_map_coords = [0, *np.unravel_index(flat_position, dist_map.shape)]
     image_size = self.original_raw.size(1)
     return compute_rf_prototype(image_size, feature_map_coords, self.protoL_rf_info)
Beispiel #2
0
def update_prototypes_on_batch(search_batch_input,
                               start_index_of_search_batch,
                               prototype_network_parallel,
                               global_min_proto_dist, # this will be updated
                               global_min_fmap_patches, # this will be updated
                               proto_rf_boxes, # this will be updated
                               proto_bound_boxes, # this will be updated
                               class_specific=True,
                               search_y=None, # required if class_specific == True
                               num_classes=None, # required if class_specific == True
                               preprocess_input_function=None,
                               prototype_layer_stride=1,
                               dir_for_saving_prototypes=None,
                               prototype_img_filename_prefix=None,
                               prototype_self_act_filename_prefix=None,
                               prototype_activation_function_in_numpy=None):
    """Update each prototype's best-matching patch using one search batch.

    Runs ``push_forward`` on the (optionally preprocessed) batch. For every
    prototype ``j`` whose smallest distance in this batch is below
    ``global_min_proto_dist[j]``, this records:

    - the new minimum distance,
    - the matching feature-map patch,
    - its receptive-field box,
    - the bounding box of the high-activation region.

    Optionally it also saves visualisations to ``dir_for_saving_prototypes``.

    Args:
        search_batch_input: image batch indexed as (N, C, H, W) and converted
            with ``.numpy()`` (so presumably a CPU tensor). Images are saved
            with vmin=0, vmax=1, presumably values in [0, 1] -- confirm.
        start_index_of_search_batch: dataset index of the first image in the
            batch; added to the image index stored in ``proto_rf_boxes``.
        prototype_network_parallel: object whose ``.module`` provides
            ``push_forward``, ``prototype_shape``, ``prototype_class_identity``,
            ``proto_layer_rf_info``, ``prototype_activation_function`` and
            ``epsilon``.
        global_min_proto_dist: best distance per prototype so far (mutated).
        global_min_fmap_patches: best feature-map patch per prototype
            (mutated).
        proto_rf_boxes: receptive-field rows per prototype,
            [img_index, h_start, h_end, w_start, w_end(, label)] (mutated).
        proto_bound_boxes: high-activation rows per prototype, same layout as
            ``proto_rf_boxes`` (mutated).
        class_specific: if True, prototype ``j`` is only matched against
            images whose label equals ``argmax(prototype_class_identity[j])``.
        search_y: labels of the batch; required if ``class_specific``.
        num_classes: number of classes; required if ``class_specific``.
        preprocess_input_function: optional transform applied before the
            forward pass. Saved images always come from the unprocessed
            ``search_batch_input``.
        prototype_layer_stride: multiplies argmin positions to obtain
            feature-map coordinates.
        dir_for_saving_prototypes: output directory; nothing is saved if None.
        prototype_img_filename_prefix: prefix for the saved PNG files.
        prototype_self_act_filename_prefix: prefix for the saved ``.npy``
            activation maps.
        prototype_activation_function_in_numpy: distance-to-activation
            function used when the network's activation function is neither
            'log' nor 'linear'.
    """
    prototype_network_parallel.eval()

    if preprocess_input_function is not None:
        search_batch = preprocess_input_function(search_batch_input)
    else:
        search_batch = search_batch_input

    with torch.no_grad():
        search_batch = search_batch.cuda()
        # this computation currently is not parallelized
        protoL_input_torch, proto_dist_torch = prototype_network_parallel.module.push_forward(search_batch)

    protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy())
    proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

    del protoL_input_torch, proto_dist_torch

    if class_specific:
        # group batch positions by their integer label
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        for img_index, img_y in enumerate(search_y):
            img_label = img_y.item()
            class_to_img_index_dict[img_label].append(img_index)

    prototype_shape = prototype_network_parallel.module.prototype_shape
    n_prototypes = prototype_shape[0]
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]
    # offset used to turn distances into activations in the 'linear' case
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    for j in range(n_prototypes):
        if class_specific:
            # target_class is the class of the class_specific prototype
            target_class = torch.argmax(prototype_network_parallel.module.prototype_class_identity[j]).item()
            # no image of the target class in this batch: go to the next prototype
            if len(class_to_img_index_dict[target_class]) == 0:
                continue
            # Select the images and prototype j in a single indexing step so
            # only prototype j's maps are gathered; shape (n_imgs, h, w).
            proto_dist_j = proto_dist_[class_to_img_index_dict[target_class], j, :, :]
        else:
            # not class specific: search through every example
            proto_dist_j = proto_dist_[:, j, :, :]

        batch_min_proto_dist_j = np.amin(proto_dist_j)
        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = \
                list(np.unravel_index(np.argmin(proto_dist_j, axis=None),
                                      proto_dist_j.shape))
            if class_specific:
                # change the argmin index from the index among images of the
                # target class to the index in the entire search batch
                batch_argmin_proto_dist_j[0] = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]]

            # retrieve the corresponding feature map patch
            img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w

            batch_min_fmap_patch_j = protoL_input_[img_index_in_batch,
                                                   :,
                                                   fmap_height_start_index:fmap_height_end_index,
                                                   fmap_width_start_index:fmap_width_end_index]

            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = batch_min_fmap_patch_j

            # get the receptive field boundary of the image patch
            # that generates the representation
            protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info
            rf_prototype_j = compute_rf_prototype(search_batch.size(2), batch_argmin_proto_dist_j, protoL_rf_info)

            # get the whole (unpreprocessed) image as H x W x C
            original_img_j = search_batch_input[rf_prototype_j[0]]
            original_img_j = original_img_j.numpy()
            original_img_j = np.transpose(original_img_j, (1, 2, 0))
            original_img_size = original_img_j.shape[0]

            # crop out the receptive field
            rf_img_j = original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                      rf_prototype_j[3]:rf_prototype_j[4], :]

            # save the prototype receptive field information
            proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_rf_boxes[j, 1] = rf_prototype_j[1]
            proto_rf_boxes[j, 2] = rf_prototype_j[2]
            proto_rf_boxes[j, 3] = rf_prototype_j[3]
            proto_rf_boxes[j, 4] = rf_prototype_j[4]
            if proto_rf_boxes.shape[1] == 6 and search_y is not None:
                proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            # find the highly activated region of the original image
            proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :]
            if prototype_network_parallel.module.prototype_activation_function == 'log':
                proto_act_img_j = np.log((proto_dist_img_j + 1) / (proto_dist_img_j + prototype_network_parallel.module.epsilon))
            elif prototype_network_parallel.module.prototype_activation_function == 'linear':
                proto_act_img_j = max_dist - proto_dist_img_j
            else:
                proto_act_img_j = prototype_activation_function_in_numpy(proto_dist_img_j)
            # NOTE(review): dsize uses the image height for both axes, so this
            # assumes square images.
            upsampled_act_img_j = cv2.resize(proto_act_img_j, dsize=(original_img_size, original_img_size),
                                             interpolation=cv2.INTER_CUBIC)
            proto_bound_j = find_high_activation_crop(upsampled_act_img_j)
            # crop out the image patch with high activation as prototype image
            proto_img_j = original_img_j[proto_bound_j[0]:proto_bound_j[1],
                                         proto_bound_j[2]:proto_bound_j[3], :]

            # save the prototype boundary (rectangular boundary of highly activated region)
            proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0]
            proto_bound_boxes[j, 1] = proto_bound_j[0]
            proto_bound_boxes[j, 2] = proto_bound_j[1]
            proto_bound_boxes[j, 3] = proto_bound_j[2]
            proto_bound_boxes[j, 4] = proto_bound_j[3]
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            if dir_for_saving_prototypes is not None:
                if prototype_self_act_filename_prefix is not None:
                    # save the numpy array of the prototype self activation
                    np.save(os.path.join(dir_for_saving_prototypes,
                                         prototype_self_act_filename_prefix + str(j) + '.npy'),
                            proto_act_img_j)
                if prototype_img_filename_prefix is not None:
                    # save the whole image containing the prototype as png
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original' + str(j) + '.png'),
                               original_img_j,
                               vmin=0.0,
                               vmax=1.0)
                    # overlay (upsampled) self activation on original image and save the result
                    rescaled_act_img_j = upsampled_act_img_j - np.amin(upsampled_act_img_j)
                    act_range_j = np.amax(rescaled_act_img_j)
                    # A constant activation map (e.g. from a 1x1 feature map)
                    # has zero range; dividing by it would fill the heatmap
                    # with NaNs. Leave such a map at all zeros instead.
                    if act_range_j > 0:
                        rescaled_act_img_j = rescaled_act_img_j / act_range_j
                    heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_img_j), cv2.COLORMAP_JET)
                    heatmap = np.float32(heatmap) / 255
                    # reverse channel order (OpenCV colour maps are BGR)
                    heatmap = heatmap[...,::-1]
                    overlayed_original_img_j = 0.5 * original_img_j + 0.3 * heatmap
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original_with_self_act' + str(j) + '.png'),
                               overlayed_original_img_j,
                               vmin=0.0,
                               vmax=1.0)

                    # if different from the original (whole) image, save the prototype receptive field as png
                    if rf_img_j.shape[0] != original_img_size or rf_img_j.shape[1] != original_img_size:
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field' + str(j) + '.png'),
                                   rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)
                        overlayed_rf_img_j = overlayed_original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                                                      rf_prototype_j[3]:rf_prototype_j[4]]
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field_with_self_act' + str(j) + '.png'),
                                   overlayed_rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)

                    # save the prototype image (highly activated region of the whole image)
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + str(j) + '.png'),
                               proto_img_j,
                               vmin=0.0,
                               vmax=1.0)

    if class_specific:
        del class_to_img_index_dict
def find_k_nearest_patches_to_prototypes(
        dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel,  # pytorch network with prototype_vectors
        k=5,
        preprocess_input_function=None,  # normalize if needed
        full_save=False,  # save all the images
        root_dir_for_saving_images='./nearest',
        log=print,
        prototype_activation_function_in_numpy=None):
    '''
    Find, for each prototype, the k images in ``dataloader`` whose closest
    patch is nearest to that prototype.

    full_save=False will only return the class identity of the closest
    patches, but it will not save anything.

    For every image and every prototype j, the smallest entry of that image's
    distance map for j is offered to a heap of at most k entries for j, so
    each image contributes one candidate per prototype. With
    ``full_save=True`` each candidate also carries the receptive-field crop,
    the whole image and the activation map. After the pass over the data,
    the visualisations are written to ``root_dir_for_saving_images/<j>/``.

    Which candidates survive, and the final order of each list, depend on
    the comparison defined by ``ImagePatch`` / ``ImagePatchInfo`` (not
    visible here). Presumably the order is nearest first -- confirm.

    Args:
        dataloader: iterable of ``(images, labels)`` batches; images are
            indexed as (N, C, H, W) and converted with ``.numpy()``.
        prototype_network_parallel: object whose ``.module`` provides
            ``push_forward``, ``num_prototypes``, ``prototype_shape``,
            ``proto_layer_rf_info``, ``prototype_activation_function`` and
            ``epsilon``.
        k: maximum number of entries kept per prototype.
        preprocess_input_function: optional transform applied before the
            forward pass. Saved images come from the unprocessed batch.
        full_save: also build and save the per-patch visualisations.
        root_dir_for_saving_images: output root used when ``full_save``.
        log: callable used for the start and timing messages.
        prototype_activation_function_in_numpy: distance-to-activation
            function used when the network's activation function is neither
            'log' nor 'linear'.

    Returns:
        ``np.ndarray`` with one row per prototype, holding the ``label`` of
        every kept entry. When ``full_save`` is set it is also saved as
        ``full_class_id.npy``.
    '''
    prototype_network_parallel.eval()
    log('find nearest patches')
    start = time.time()
    n_prototypes = prototype_network_parallel.module.num_prototypes

    prototype_shape = prototype_network_parallel.module.prototype_shape
    # offset used to turn distances into activations in the 'linear' case
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info

    # one heap per prototype; a heap in python is just a list maintained by heapq
    heaps = [[] for _ in range(n_prototypes)]

    for idx, (search_batch_input, search_y) in enumerate(dataloader):
        print('batch {}'.format(idx))
        if preprocess_input_function is not None:
            search_batch = preprocess_input_function(search_batch_input)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            # only the distances are needed here
            _, proto_dist_torch = \
                prototype_network_parallel.module.push_forward(search_batch)

        proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

        for img_idx, distance_map in enumerate(proto_dist_):
            for j in range(n_prototypes):
                # distance of this image's closest patch to prototype j
                closest_patch_distance_to_prototype_j = np.amin(distance_map[j])

                if full_save:
                    closest_patch_indices_in_distance_map_j = \
                        list(np.unravel_index(np.argmin(distance_map[j], axis=None),
                                              distance_map[j].shape))
                    # prepend a placeholder image index; only entries 1-4 of
                    # the converted result are used below
                    closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j
                    closest_patch_indices_in_img = \
                        compute_rf_prototype(search_batch.size(2),
                                             closest_patch_indices_in_distance_map_j,
                                             protoL_rf_info)
                    closest_patch = \
                        search_batch_input[img_idx, :,
                                           closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                                           closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                    closest_patch = closest_patch.numpy()
                    closest_patch = np.transpose(closest_patch, (1, 2, 0))

                    original_img = search_batch_input[img_idx].numpy()
                    original_img = np.transpose(original_img, (1, 2, 0))

                    if prototype_network_parallel.module.prototype_activation_function == 'log':
                        act_pattern = np.log(
                            (distance_map[j] + 1) /
                            (distance_map[j] +
                             prototype_network_parallel.module.epsilon))
                    elif prototype_network_parallel.module.prototype_activation_function == 'linear':
                        act_pattern = max_dist - distance_map[j]
                    else:
                        act_pattern = prototype_activation_function_in_numpy(
                            distance_map[j])

                    # 4 numbers: height_start, height_end, width_start, width_end
                    patch_indices = closest_patch_indices_in_img[1:5]

                    # construct the closest patch object
                    closest_patch = ImagePatch(
                        patch=closest_patch,
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                        original_img=original_img,
                        act_pattern=act_pattern,
                        patch_indices=patch_indices)
                else:
                    closest_patch = ImagePatchInfo(
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j)

                # add to the j-th heap, keeping at most k entries
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush
                    # followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)

    # after looping through the dataset every heap holds its k survivors
    for j in range(n_prototypes):
        # the heap is not ranked yet: sort it, then reverse
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        if full_save:
            dir_for_saving_images = os.path.join(root_dir_for_saving_images,
                                                 str(j))
            makedir(dir_for_saving_images)

            for i, patch in enumerate(heaps[j]):
                # save the activation pattern of the original image where the patch comes from
                np.save(
                    os.path.join(dir_for_saving_images,
                                 'nearest-' + str(i + 1) + '_act.npy'),
                    patch.act_pattern)

                # save the original image where the patch comes from
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i + 1) + '_original.png'),
                           arr=patch.original_img,
                           vmin=0.0,
                           vmax=1.0)

                # overlay (upsampled) activation on original image and save the result
                # NOTE(review): dsize uses the image height for both axes,
                # so this assumes square images.
                img_size = patch.original_img.shape[0]
                upsampled_act_pattern = cv2.resize(
                    patch.act_pattern,
                    dsize=(img_size, img_size),
                    interpolation=cv2.INTER_CUBIC)
                rescaled_act_pattern = upsampled_act_pattern - np.amin(
                    upsampled_act_pattern)
                act_range = np.amax(rescaled_act_pattern)
                # A constant activation map has zero range; dividing by it
                # would fill the heatmap with NaNs. Keep it at all zeros.
                if act_range > 0:
                    rescaled_act_pattern = rescaled_act_pattern / act_range
                heatmap = cv2.applyColorMap(
                    np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
                heatmap = np.float32(heatmap) / 255
                # reverse channel order (OpenCV colour maps are BGR)
                heatmap = heatmap[..., ::-1]
                overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i + 1) + '_original_with_heatmap.png'),
                           arr=overlayed_original_img,
                           vmin=0.0,
                           vmax=1.0)

                # if different from original image, save the patch (i.e. receptive field)
                if patch.patch.shape[0] != img_size or patch.patch.shape[
                        1] != img_size:
                    np.save(
                        os.path.join(
                            dir_for_saving_images, 'nearest-' + str(i + 1) +
                            '_receptive_field_indices.npy'),
                        patch.patch_indices)
                    plt.imsave(fname=os.path.join(
                        dir_for_saving_images,
                        'nearest-' + str(i + 1) + '_receptive_field.png'),
                               arr=patch.patch,
                               vmin=0.0,
                               vmax=1.0)
                    # save the receptive field patch with heatmap
                    overlayed_patch = overlayed_original_img[
                        patch.patch_indices[0]:patch.patch_indices[1],
                        patch.patch_indices[2]:patch.patch_indices[3], :]
                    plt.imsave(fname=os.path.join(
                        dir_for_saving_images, 'nearest-' + str(i + 1) +
                        '_receptive_field_with_heatmap.png'),
                               arr=overlayed_patch,
                               vmin=0.0,
                               vmax=1.0)

                # save the highly activated patch
                high_act_patch_indices = find_high_activation_crop(
                    upsampled_act_pattern)
                high_act_patch = patch.original_img[
                    high_act_patch_indices[0]:high_act_patch_indices[1],
                    high_act_patch_indices[2]:high_act_patch_indices[3], :]
                np.save(
                    os.path.join(
                        dir_for_saving_images, 'nearest-' + str(i + 1) +
                        '_high_act_patch_indices.npy'), high_act_patch_indices)
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i + 1) + '_high_act_patch.png'),
                           arr=high_act_patch,
                           vmin=0.0,
                           vmax=1.0)
                # save the original image with bounding box showing high activation patch
                imsave_with_bbox(fname=os.path.join(
                    dir_for_saving_images, 'nearest-' + str(i + 1) +
                    '_high_act_patch_in_original_img.png'),
                                 img_rgb=patch.original_img,
                                 bbox_height_start=high_act_patch_indices[0],
                                 bbox_height_end=high_act_patch_indices[1],
                                 bbox_width_start=high_act_patch_indices[2],
                                 bbox_width_end=high_act_patch_indices[3],
                                 color=(0, 255, 255))

            labels = np.array([patch.label for patch in heaps[j]])
            np.save(os.path.join(dir_for_saving_images, 'class_id.npy'),
                    labels)

    labels_all_prototype = np.array([[patch.label for patch in heaps[j]]
                                     for j in range(n_prototypes)])

    if full_save:
        np.save(os.path.join(root_dir_for_saving_images, 'full_class_id.npy'),
                labels_all_prototype)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))

    return labels_all_prototype
Beispiel #4
0
def update_prototypes_on_batch(
        search_batch_input,
        start_index_of_search_batch,
        prototype_network_parallel,
        conv_features,
        global_min_proto_dist,  # this will be updated
        global_min_fmap_patches,  # this will be updated
        proto_bound_boxes,  # this will be updated
        name,
        prototype_shape,
        n_prototypes_per_class=None,  # default: no restriction on number of prototypes per class
        search_y=None,  # required if n_prototypes_per_class != None
        num_classes=None,  # required if n_prototypes_per_class != None
        preprocess_input_function=None,
        prototype_layer_stride=1,
        dir_for_saving_prototypes=None,
        prototype_img_filename_prefix=None,
        prototype_original_img_filename_prefix=None):
    """Update each prototype's closest patch using precomputed conv features.

    Distances are obtained from
    ``prototype_network_parallel.module.prototype_distances(conv_features, name)``.
    For every prototype ``j`` whose smallest distance in this batch is below
    ``global_min_proto_dist[j]``, this writes into the output arrays:

    - the new minimum distance,
    - the flattened feature-map patch,
    - the receptive-field box.

    Optionally the receptive-field crop and the whole source image are saved
    as ``.npy`` and ``.png``.

    Args:
        search_batch_input: image batch indexed as (N, C, H, W) and converted
            with ``.numpy()`` (presumably a CPU tensor with values in [0, 1]).
        start_index_of_search_batch: dataset index of the first image; added
            to the image index stored in ``proto_bound_boxes[j, 0]``.
        prototype_network_parallel: object whose ``.module`` provides
            ``prototype_distances`` and ``proto_layer_rf_info``.
        conv_features: feature maps of the batch, indexed as (N, C, H, W).
        global_min_proto_dist: best distance per prototype so far (mutated).
        global_min_fmap_patches: best flattened patch per prototype (mutated).
        proto_bound_boxes: rows [img_index, h_start, h_end, w_start, w_end
            (, label)] per prototype (mutated).
        name: forwarded to ``prototype_distances``; presumably selects which
            set of prototypes is compared -- confirm.
        prototype_shape: indexed as (n_prototypes, _, proto_h, proto_w).
        n_prototypes_per_class: if given, prototype ``j`` belongs to class
            ``j // n_prototypes_per_class`` and is only matched against images
            of that class.
        search_y: labels of the batch; required if n_prototypes_per_class is set.
        num_classes: number of classes; required if n_prototypes_per_class is set.
        preprocess_input_function: accepted for interface compatibility; not
            used by this function.
        prototype_layer_stride: multiplies argmin positions to obtain
            feature-map coordinates.
        dir_for_saving_prototypes: output directory; nothing is saved if None.
        prototype_img_filename_prefix: prefix for the receptive-field crop files.
        prototype_original_img_filename_prefix: prefix for the whole-image files.
    """
    # Take an independent host copy of the features before computing
    # distances. copy.deepcopy on a tensor that is not a graph leaf (e.g.
    # features computed with autograd enabled) raises RuntimeError, and it
    # would also duplicate the features in device memory.
    protoL_input_ = np.copy(conv_features.detach().cpu().numpy())

    with torch.no_grad():
        proto_dist_torch = prototype_network_parallel.module.prototype_distances(
            conv_features, name)

    proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())
    del proto_dist_torch

    if n_prototypes_per_class is not None:
        # group batch positions by their integer label
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        for img_index, img_y in enumerate(search_y):
            img_label = img_y.item()
            class_to_img_index_dict[img_label].append(img_index)

    n_prototypes = prototype_shape[0]
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]

    for j in range(n_prototypes):
        if n_prototypes_per_class is not None:
            target_class = j // n_prototypes_per_class
            # no image of the target class in this batch
            if len(class_to_img_index_dict[target_class]) == 0:
                continue
            # Select the images and prototype j in a single indexing step so
            # only prototype j's maps are gathered; shape (n_imgs, h, w).
            proto_dist_j = proto_dist_[
                class_to_img_index_dict[target_class], j, :, :]
        else:
            proto_dist_j = proto_dist_[:, j, :, :]
        batch_min_proto_dist_j = np.amin(proto_dist_j)

        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = \
                list(np.unravel_index(np.argmin(proto_dist_j, axis=None),
                                      proto_dist_j.shape))
            if n_prototypes_per_class is not None:
                # map the index among target-class images back to the
                # index in the whole batch
                batch_argmin_proto_dist_j[0] = class_to_img_index_dict[
                    target_class][batch_argmin_proto_dist_j[0]]

            # retrieve the corresponding feature map patch
            img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[
                1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[
                2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w
            batch_min_fmap_patch_j = protoL_input_[
                img_index_in_batch, :,
                fmap_height_start_index:fmap_height_end_index,
                fmap_width_start_index:fmap_width_end_index]
            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = np.reshape(batch_min_fmap_patch_j, -1)

            # receptive field of the matched patch in image coordinates
            protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info
            rf_prototype_j = compute_rf_prototype(search_batch_input.size(2),
                                                  batch_argmin_proto_dist_j,
                                                  protoL_rf_info)
            img_j = search_batch_input[rf_prototype_j[0], :,
                                       rf_prototype_j[1]:rf_prototype_j[2],
                                       rf_prototype_j[3]:rf_prototype_j[4]]
            img_j = img_j.numpy()
            proto_bound_boxes[
                j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_bound_boxes[j, 1] = rf_prototype_j[1]
            proto_bound_boxes[j, 2] = rf_prototype_j[2]
            proto_bound_boxes[j, 3] = rf_prototype_j[3]
            proto_bound_boxes[j, 4] = rf_prototype_j[4]
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            if dir_for_saving_prototypes is not None:
                if prototype_img_filename_prefix is not None:
                    # receptive-field crop: raw (C, H, W) array, then as PNG
                    np.save(
                        os.path.join(
                            dir_for_saving_prototypes,
                            prototype_img_filename_prefix + str(j) + '.npy'),
                        img_j)
                    img_j = np.transpose(img_j, (1, 2, 0))
                    matplotlib.image.imsave(os.path.join(
                        dir_for_saving_prototypes,
                        prototype_img_filename_prefix + str(j) + '.png'),
                                            img_j,
                                            vmin=0.0,
                                            vmax=1.0)
                if prototype_original_img_filename_prefix is not None:
                    # whole source image: raw (C, H, W) array, then as PNG
                    original_img_j = search_batch_input[rf_prototype_j[0]]
                    original_img_j = original_img_j.numpy()
                    np.save(
                        os.path.join(
                            dir_for_saving_prototypes,
                            prototype_original_img_filename_prefix + str(j) +
                            '.npy'), original_img_j)
                    original_img_j = np.transpose(original_img_j, (1, 2, 0))
                    matplotlib.image.imsave(os.path.join(
                        dir_for_saving_prototypes,
                        prototype_original_img_filename_prefix + str(j) +
                        '.png'),
                                            original_img_j,
                                            vmin=0.0,
                                            vmax=1.0)

    if n_prototypes_per_class is not None:
        del class_to_img_index_dict
def _save_nearest_patch(dir_for_saving_images, rank, patch, cmap="jet"):
    """Write every artifact for the ``rank``-th nearest patch of one prototype.

    All files go into ``dir_for_saving_images`` with the prefix
    ``nearest-<rank>``:

    * ``.png``: the patch crop itself;
    * ``_original.png``: the image the patch was cropped from;
    * ``_act.npy``: the raw activation map;
    * ``_indices.npy``: the patch bounds;
    * ``_original_with_heatmap.png``: the upsampled activation blended onto a
      grayscale copy of the image;
    * ``_patch_with_heatmap.png``: the same blend, cropped to the patch.

    Args:
        dir_for_saving_images: existing output directory.
        rank: 0-based position of ``patch`` in the sorted nearest list.
        patch: an ``ImagePatch`` exposing ``patch``, ``original_img``
            (channel-last, produced by a (1, 2, 0) transpose), ``act_pattern``
            (2-D) and ``patch_indices`` as
            (row_start, row_end, col_start, col_end).
        cmap: matplotlib colormap used for the two heatmap images.
    """
    prefix = os.path.join(dir_for_saving_images, 'nearest-' + str(rank))

    # The patch itself and the image it comes from.
    plt.imsave(fname=prefix + '.png', arr=patch.patch, vmin=0.0, vmax=1.0)
    plt.imsave(fname=prefix + '_original.png', arr=patch.original_img,
               vmin=0.0, vmax=1.0)
    # Raw activation pattern and patch bounds, for later analysis.
    np.save(prefix + '_act.npy', patch.act_pattern)
    np.save(prefix + '_indices.npy', patch.patch_indices)

    # cv2.resize takes dsize as (width, height); use both dimensions so
    # non-square images are upsampled to their true shape.
    height, width = patch.original_img.shape[:2]
    # This HWC array is saved as RGB by plt.imsave above, so convert with RGB
    # luminance weights (BGR2GRAY would swap the R and B weights).
    original_img_gray = cv2.cvtColor(patch.original_img, cv2.COLOR_RGB2GRAY)
    upsampled_act_pattern = cv2.resize(patch.act_pattern,
                                       dsize=(width, height),
                                       interpolation=cv2.INTER_CUBIC)
    # Blend the heatmap onto the grayscale image. Values outside [0, 1] are
    # clipped by vmin/vmax when saving.
    overlayed_original_img = 0.7 * original_img_gray + 0.3 * upsampled_act_pattern
    plt.imsave(fname=prefix + '_original_with_heatmap.png',
               arr=overlayed_original_img,
               cmap=cmap,
               vmin=0.0,
               vmax=1.0)

    row_start, row_end, col_start, col_end = patch.patch_indices
    plt.imsave(fname=prefix + '_patch_with_heatmap.png',
               arr=overlayed_original_img[row_start:row_end, col_start:col_end],
               cmap=cmap,
               vmin=0.0,
               vmax=1.0)


def find_k_nearest_patches_to_prototypes(
        dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel,  # pytorch network with prototype_vectors
        k=5,
        preprocess_input_function=None,  # normalize if needed
        prototype_layer_stride=1,
        root_dir_for_saving_images='./nearest_',
        save_image_class_identity=True,
        log=print):
    """Collect the ``k`` nearest image patches for every prototype and save them.

    For every node yielded by ``root.nodes_with_children()`` and every one of
    that node's prototypes, the closest patch of each image in ``dataloader``
    is pushed into a heap bounded to ``k`` entries. The surviving patches are
    then written to ``<root_dir_for_saving_images>/<node.name>/<j>/`` (see
    ``_save_nearest_patch``). When ``save_image_class_identity`` is set, their
    labels are also written to ``class_id.npy`` in that directory.

    Args:
        dataloader: yields ``(images, labels)`` batches. Images must be
            unnormalized in [0, 1] because patches are cropped and saved from
            them directly.
        prototype_network_parallel: wrapper whose ``.module`` provides
            ``proto_layer_rf_info``, ``root``, ``conv_features``,
            ``prototype_distances`` and ``num_<node name>_prototypes``.
        k: number of nearest patches kept per prototype.
        preprocess_input_function: optional normalization, applied to a copy
            of each batch before the forward pass.
        prototype_layer_stride: accepted for interface compatibility; not used
            by this function.
        root_dir_for_saving_images: root output directory.
        save_image_class_identity: also save the labels of the kept patches.
        log: logging callable.
    """
    log('find nearest patches')
    start = time.time()

    model = prototype_network_parallel.module
    protoL_rf_info = model.proto_layer_rf_info
    cmap = "jet"

    # node name -> one heap (a plain list managed with heapq) per prototype.
    node2heaps = {
        node.name: [[] for _ in range(
            getattr(model, "num_" + node.name + "_prototypes"))]
        for node in model.root.nodes_with_children()
    }

    for search_batch_input, search_y in dataloader:
        if preprocess_input_function is not None:
            # Preprocess a copy so the unnormalized batch, which patches are
            # cropped from below, stays untouched.
            search_batch = preprocess_input_function(
                copy.deepcopy(search_batch_input))
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            conv_features = model.conv_features(search_batch)
        img_size = search_batch.size(2)

        for node in model.root.nodes_with_children():
            num_prototypes = getattr(model,
                                     "num_" + node.name + "_prototypes")

            with torch.no_grad():
                proto_dist_torch = model.prototype_distances(
                    conv_features, node.name)
                proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

            heaps = node2heaps[node.name]

            for img_idx, distance_map in enumerate(proto_dist_):
                # Identical for every prototype, so build it once per image.
                original_img = np.transpose(
                    search_batch_input[img_idx].numpy(), (1, 2, 0))

                for j in range(num_prototypes):
                    distance_map_j = distance_map[j]
                    # Locate the closest patch once; its value is the minimum.
                    min_location = np.unravel_index(
                        np.argmin(distance_map_j, axis=None),
                        distance_map_j.shape)
                    closest_patch_distance_to_prototype_j = \
                        distance_map_j[min_location]

                    # Leading 0 fills the slot compute_rf_prototype expects
                    # in front of the (row, col) location.
                    closest_patch_indices_in_img = compute_rf_prototype(
                        img_size, [0] + list(min_location), protoL_rf_info)
                    patch_indices = closest_patch_indices_in_img[1:5]
                    row_start, row_end, col_start, col_end = patch_indices

                    closest_patch = np.transpose(
                        search_batch_input[img_idx, :,
                                           row_start:row_end,
                                           col_start:col_end].numpy(),
                        (1, 2, 0))

                    act_pattern = np.log(1 + (1 / (distance_map_j + 1e-4)))

                    image_patch = ImagePatch(
                        patch=closest_patch,
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                        original_img=original_img,
                        act_pattern=act_pattern,
                        patch_indices=patch_indices)

                    # Bounded heap of size k. Which entry is evicted depends
                    # on ImagePatch ordering (defined elsewhere); presumably
                    # the farthest patch compares smallest -- TODO confirm.
                    if len(heaps[j]) < k:
                        heapq.heappush(heaps[j], image_patch)
                    else:
                        heapq.heappushpop(heaps[j], image_patch)

        del conv_features, search_batch

    for node in model.root.nodes_with_children():
        heaps = node2heaps[node.name]
        num_prototypes = getattr(model, "num_" + node.name + "_prototypes")

        for j in range(num_prototypes):
            # Sort by ImagePatch ordering, then reverse.
            heaps[j].sort()
            heaps[j] = heaps[j][::-1]

            dir_for_saving_images = os.path.join(root_dir_for_saving_images,
                                                 node.name, str(j))
            makedir(dir_for_saving_images)

            for i, patch in enumerate(heaps[j]):
                _save_nearest_patch(dir_for_saving_images, i, patch, cmap)

            if save_image_class_identity:
                labels = np.array([patch.label for patch in heaps[j]])
                np.save(os.path.join(dir_for_saving_images, 'class_id.npy'),
                        labels)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))