Beispiel #1
0
    iteration = None
    run_name_prefix = ''
    if args.run_name_prefix:
        run_name_prefix = args.run_name_prefix.lower().replace(' ', '_') + '.'
    time = datetime.datetime.now()
    time = time.replace(microsecond=0)
    experiment_run_name = run_name_prefix + '{}.{}.{}'.format(
        args.dataset, platform.node(), time.isoformat())
    best_accu = 0.
    current_push_best_accu = 0.
    print('Starting new experiment: {}'.format(experiment_run_name))
    print('Saving code snapshot with git-experiments')
    snapshot_code(experiment_run_name)

# Create this run's model-output directory and an 'img' subdirectory inside it.
model_dir = os.path.join(SAVED_MODELS_PATH, experiment_run_name)
makedir(model_dir)
img_dir = os.path.join(model_dir, 'img')
makedir(img_dir)

# Event writer for this run, stored under LOGS_DIR/<experiment_run_name>.
# NOTE(review): presumably torch.utils.tensorboard.SummaryWriter, where
# purge_step discards previously logged events with global_step >= step + 1
# (useful when resuming); `step` is defined outside this view — confirm.
log_writer = SummaryWriter(os.path.join(LOGS_DIR, experiment_run_name),
                           purge_step=step + 1)

# Filename stems for saved prototype artifacts.
# NOTE(review): presumably passed to the push / prototype-saving code — confirm.
weight_matrix_filename = 'outputL_weights'
prototype_img_filename_prefix = 'prototype-img'
prototype_self_act_filename_prefix = 'prototype-self-act'
proto_bound_boxes_filename_prefix = 'bb'

# noinspection PyTypeChecker
log_writer.add_text(
    'dataset_stats',
    'training set size: {}, push set size: {}, valid set size: {}, test set size: {}'
Beispiel #2
0
def push_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1])
                    prototype_network_parallel, # pytorch network with prototype_vectors
                    class_specific=True,
                    preprocess_input_function=None, # normalize if needed
                    prototype_layer_stride=1,
                    root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here
                    epoch_number=None, # if not provided, prototypes saved previously will be overwritten
                    prototype_img_filename_prefix=None,
                    prototype_self_act_filename_prefix=None,
                    proto_bound_boxes_filename_prefix=None,
                    save_prototype_class_identity=True, # which class the prototype image comes from
                    log=print,
                    prototype_activation_function_in_numpy=None):
    """Replace every prototype with the closest feature patch found in ``dataloader``.

    For each batch, ``update_prototypes_on_batch`` is handed the running
    per-prototype minimum distances, the patches that achieved them and the
    box arrays (presumably updating them in place). After the full pass the
    collected patches are copied into
    ``prototype_network_parallel.module.prototype_vectors``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches exposing
            ``batch_size``; images are expected unnormalized in [0, 1].
        prototype_network_parallel: wrapper whose ``.module`` exposes
            ``prototype_shape``, ``num_prototypes``, ``num_classes`` and
            ``prototype_vectors``.
        class_specific, preprocess_input_function, prototype_layer_stride,
        prototype_img_filename_prefix, prototype_self_act_filename_prefix,
        prototype_activation_function_in_numpy: forwarded unchanged to
            ``update_prototypes_on_batch``.
        root_dir_for_saving_prototypes: if not None, artifacts are written
            here (into an ``epoch-<epoch_number>`` subdirectory when
            ``epoch_number`` is given).
        epoch_number: used for the subdirectory name and the ``.npy`` names.
        proto_bound_boxes_filename_prefix: if not None (and a save directory
            is set), the receptive-field and bounding-box arrays are saved.
        save_prototype_class_identity: when True the box arrays get a sixth
            column for the class identity.
        log: callable used for progress messages.

    Returns:
        None; ``prototype_vectors`` is overwritten in place.
    """
    prototype_network_parallel.eval()
    log('\tpush')

    start = time.time()
    module = prototype_network_parallel.module
    prototype_shape = module.prototype_shape
    n_prototypes = module.num_prototypes
    # closest distance seen so far, per prototype
    global_min_proto_dist = np.full(n_prototypes, np.inf)
    # patch representation that gives the current smallest distance
    global_min_fmap_patches = np.zeros(
        [n_prototypes,
         prototype_shape[1],
         prototype_shape[2],
         prototype_shape[3]])

    # proto_rf_boxes / proto_bound_boxes columns:
    #   0: image index in the entire dataset
    #   1: height start index
    #   2: height end index
    #   3: width start index
    #   4: width end index
    #   5: (optional) class identity
    n_box_columns = 6 if save_prototype_class_identity else 5
    # two separate arrays: they are filled independently
    proto_rf_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)
    proto_bound_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)

    if root_dir_for_saving_prototypes is None:
        proto_epoch_dir = None
    elif epoch_number is None:
        proto_epoch_dir = root_dir_for_saving_prototypes
    else:
        proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes,
                                       'epoch-' + str(epoch_number))
        makedir(proto_epoch_dir)

    search_batch_size = dataloader.batch_size

    num_classes = module.num_classes

    for push_iter, (search_batch_input, search_y) in enumerate(dataloader):
        # dataset-wide index of this batch's first image, so the box arrays
        # can record which image each prototype came from
        start_index_of_search_batch = push_iter * search_batch_size

        update_prototypes_on_batch(search_batch_input,
                                   start_index_of_search_batch,
                                   prototype_network_parallel,
                                   global_min_proto_dist,
                                   global_min_fmap_patches,
                                   proto_rf_boxes,
                                   proto_bound_boxes,
                                   class_specific=class_specific,
                                   search_y=search_y,
                                   num_classes=num_classes,
                                   preprocess_input_function=preprocess_input_function,
                                   prototype_layer_stride=prototype_layer_stride,
                                   dir_for_saving_prototypes=proto_epoch_dir,
                                   prototype_img_filename_prefix=prototype_img_filename_prefix,
                                   prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
                                   prototype_activation_function_in_numpy=prototype_activation_function_in_numpy)

    if proto_epoch_dir is not None and proto_bound_boxes_filename_prefix is not None:
        np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + '-receptive_field' + str(epoch_number) + '.npy'),
                proto_rf_boxes)
        np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + str(epoch_number) + '.npy'),
                proto_bound_boxes)

    log('\tExecuting push ...')
    prototype_update = np.reshape(global_min_fmap_patches,
                                  tuple(prototype_shape))
    prototype_vectors = module.prototype_vectors
    # Build the tensor directly on the parameter's device instead of forcing
    # `.cuda()`, so the push also works when the model lives on the CPU.
    prototype_vectors.data.copy_(
        torch.as_tensor(prototype_update, dtype=torch.float32,
                        device=prototype_vectors.device))
    end = time.time()
    log('\tpush time: \t{0}'.format(end -  start))
Beispiel #3
0
def find_k_nearest_patches_to_prototypes(dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
                                         ppnet,  # pytorch network with prototype_vectors
                                         k=5,
                                         preprocess_input_function=None,  # normalize if needed
                                         full_save=False,  # save all the images
                                         root_dir_for_saving_images='./nearest',
                                         log=print,
                                         prototype_activation_function_in_numpy=None,
                                         only_n_most_activated=None):
    """For every prototype, keep the k closest instances seen in ``dataloader``.

    Each batch is expected to yield ``(raw, batch, labels)``. Only ``batch``
    is fed to ``ppnet.forward``; the network's ``distances`` and ``A``
    attributes are read afterwards. Every instance in the batch is scored
    with the label ``labels[0]``.

    full_save=False only returns the class identity of the closest patches
    and saves nothing. full_save=True additionally writes, per prototype
    ``j``, the activation maps, heatmap overlays, receptive-field crops and
    high-activation crops under ``root_dir_for_saving_images/<j>/``.

    Args:
        dataloader: iterable of ``(raw, batch, labels)`` with a ``dataset``
            attribute (used only for the progress-bar length).
        ppnet: network exposing ``num_prototypes``, ``prototype_shape``,
            ``proto_layer_rf_info``, ``prototype_activation_function`` and
            ``epsilon``.
        k: number of nearest patches kept per prototype.
        preprocess_input_function: accepted for interface compatibility but
            not used here. NOTE(review): presumably the loader already
            yields a preprocessed batch — confirm.
        full_save: see above.
        root_dir_for_saving_images: output root for ``full_save``.
        log: callable used for progress messages.
        prototype_activation_function_in_numpy: forwarded to the lazy patch
            objects.
        only_n_most_activated: if truthy, only the instances whose indices
            are in the top-N of ``ppnet.A`` along its last dim are considered.

    Returns:
        numpy array built from, per prototype, the labels of its kept patches.
    """
    ppnet.eval()
    log('        find nearest patches')
    start = time.time()
    n_prototypes = ppnet.num_prototypes

    prototype_shape = ppnet.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = ppnet.proto_layer_rf_info

    # one heap (a plain list maintained by heapq) per prototype
    heaps = [[] for _ in range(n_prototypes)]

    with tqdm(total=len(dataloader.dataset), unit='bag') as pbar:
        for idx, (_search_batch_raw, search_batch, search_y) in enumerate(dataloader):
            with torch.no_grad():
                search_batch = search_batch.cuda()
                ppnet.forward(search_batch)
                proto_dist_torch = ppnet.distances
                attention = ppnet.A

            proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

            # indices of the most attended instances; a set gives O(1) membership
            most_activated = None
            if only_n_most_activated:
                n_top = min(only_n_most_activated, attention.shape[1])
                top_indices = torch.topk(attention, n_top, dim=-1, largest=True)[1]
                most_activated = set(top_indices.detach().cpu().numpy()[0].tolist())

            for img_idx, distance_map in enumerate(proto_dist_):
                if most_activated is not None and img_idx not in most_activated:
                    continue

                for j in range(n_prototypes):
                    # closest patch in this instance to prototype j
                    closest_patch_distance_to_prototype_j = np.amin(distance_map[j])

                    if full_save:
                        closest_patch = ImagePatchLazy(
                            label=search_y[0],
                            distance=closest_patch_distance_to_prototype_j,
                            batch_idx=idx,
                            image_idx=img_idx,
                            dataset=dataloader.dataset,
                            protoL_rf_info=protoL_rf_info,
                            distance_map_j=distance_map[j],
                            prototype_activation_function_in_numpy=prototype_activation_function_in_numpy,
                            prototype_activation_function=ppnet.prototype_activation_function,
                            max_dist=max_dist,
                            epsilon=ppnet.epsilon
                        )
                    else:
                        closest_patch = ImagePatchInfo(label=search_y[0],
                                                       distance=closest_patch_distance_to_prototype_j)

                    # add to the j-th heap, keeping at most k entries
                    if len(heaps[j]) < k:
                        heapq.heappush(heaps[j], closest_patch)
                    else:
                        # heappushpop runs more efficiently than heappush
                        # followed by heappop
                        heapq.heappushpop(heaps[j], closest_patch)
            pbar.update(1)

    # after looping through the dataset every heap holds its k kept patches
    for j in range(n_prototypes):
        # a heap is only partially ordered: sort it, then reverse
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        if not full_save:
            continue

        dir_for_saving_images = os.path.join(root_dir_for_saving_images, str(j))
        makedir(dir_for_saving_images)

        for rank, patch in enumerate(heaps[j], start=1):
            # every artifact for this neighbour shares the 'nearest-<rank>' stem
            stem = os.path.join(dir_for_saving_images, 'nearest-' + str(rank))

            # activation pattern and original image the patch comes from
            np.save(stem + '_act.npy', patch.act_pattern)
            plt.imsave(fname=stem + '_original.png',
                       arr=patch.original_img,
                       vmin=0.0,
                       vmax=1.0)

            # overlay (upsampled) activation on original image and save the result
            img_size = patch.original_img.shape[0]
            upsampled_act_pattern = cv2.resize(patch.act_pattern,
                                               dsize=(img_size, img_size),
                                               interpolation=cv2.INTER_CUBIC)
            rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
            act_peak = np.amax(rescaled_act_pattern)
            # a constant activation map would otherwise divide 0 by 0 -> NaN heatmap
            if act_peak > 0:
                rescaled_act_pattern = rescaled_act_pattern / act_peak
            heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
            heatmap = np.float32(heatmap) / 255
            heatmap = heatmap[..., ::-1]  # BGR (OpenCV) -> RGB
            overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
            plt.imsave(fname=stem + '_original_with_heatmap.png',
                       arr=overlayed_original_img,
                       vmin=0.0,
                       vmax=1.0)

            # if different from original image, save the patch (i.e. receptive field)
            if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size:
                np.save(stem + '_receptive_field_indices.npy', patch.patch_indices)
                plt.imsave(fname=stem + '_receptive_field.png',
                           arr=patch.patch,
                           vmin=0.0,
                           vmax=1.0)
                # receptive field patch with heatmap
                overlayed_patch = overlayed_original_img[patch.patch_indices[0]:patch.patch_indices[1],
                                                         patch.patch_indices[2]:patch.patch_indices[3], :]
                plt.imsave(fname=stem + '_receptive_field_with_heatmap.png',
                           arr=overlayed_patch,
                           vmin=0.0,
                           vmax=1.0)

            # highly activated patch, plus the original image with its bounding box
            high_act_patch_indices = find_high_activation_crop(upsampled_act_pattern)
            high_act_patch = patch.original_img[high_act_patch_indices[0]:high_act_patch_indices[1],
                                                high_act_patch_indices[2]:high_act_patch_indices[3], :]
            np.save(stem + '_high_act_patch_indices.npy', high_act_patch_indices)
            plt.imsave(fname=stem + '_high_act_patch.png',
                       arr=high_act_patch,
                       vmin=0.0,
                       vmax=1.0)
            imsave_with_bbox(fname=stem + '_high_act_patch_in_original_img.png',
                             img_rgb=patch.original_img,
                             bbox_height_start=high_act_patch_indices[0],
                             bbox_height_end=high_act_patch_indices[1],
                             bbox_width_start=high_act_patch_indices[2],
                             bbox_width_end=high_act_patch_indices[3], color=(0, 255, 255))

        labels = np.array([patch.label for patch in heaps[j]])
        np.save(os.path.join(dir_for_saving_images, 'class_id.npy'),
                labels)

    labels_all_prototype = np.array([[patch.label for patch in heaps[j]] for j in range(n_prototypes)])

    if full_save:
        np.save(os.path.join(root_dir_for_saving_images, 'full_class_id.npy'),
                labels_all_prototype)

    end = time.time()
    log('        find nearest patches time: \t{0}'.format(end - start))

    return labels_all_prototype
Beispiel #4
0
    def push_orig(self, epoch_number):
        """Replace every prototype with the closest feature patch from ``self.dataloader``.

        For each batch, ``self.update_prototypes_on_batch`` receives the
        running per-prototype minimum distances, the patches achieving them
        and the box arrays (presumably updating them in place). After the
        pass the collected patches are copied into
        ``self.prototype_network_parallel.module.prototype_vectors``.

        Args:
            epoch_number: when not None, artifacts go into an
                ``epoch-<epoch_number>`` subdirectory of
                ``self.dir_for_saving_prototypes``; also used in the ``.npy``
                filenames.

        Returns:
            None; ``prototype_vectors`` is overwritten in place.
        """
        self.prototype_network_parallel.eval()
        self.log('\tpush')

        start = time.time()
        module = self.prototype_network_parallel.module
        prototype_shape = module.prototype_shape
        n_prototypes = module.num_prototypes
        # saves the closest distance seen so far
        global_min_proto_dist = np.full(n_prototypes, np.inf)
        # saves the patch representation that gives the current smallest distance
        global_min_fmap_patches = np.zeros([
            n_prototypes, prototype_shape[1], prototype_shape[2],
            prototype_shape[3]
        ])
        # proto_rf_boxes / proto_bound_boxes columns:
        #   0: image index in the entire dataset
        #   1: height start index
        #   2: height end index
        #   3: width start index
        #   4: width end index
        #   5: (optional) class identity
        n_box_columns = 6 if self.save_prototype_class_identity else 5
        # two separate arrays: they are filled independently
        proto_rf_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)
        proto_bound_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)

        if self.dir_for_saving_prototypes is None:
            proto_epoch_dir = None
        elif epoch_number is None:
            # XXX I think dir_for_saving_proto and root_dir are actually
            # different variables and it wasnt a misnaming. Oh well
            # I'll come back to this later
            proto_epoch_dir = self.dir_for_saving_prototypes
        else:
            proto_epoch_dir = os.path.join(self.dir_for_saving_prototypes,
                                           'epoch-' + str(epoch_number))
            makedir(proto_epoch_dir)

        search_batch_size = self.dataloader.batch_size

        num_classes = module.num_classes

        for push_iter, (search_batch_input,
                        search_y) in enumerate(self.dataloader):
            # dataset-wide index of this batch's first image
            start_index_of_search_batch = push_iter * search_batch_size

            self.update_prototypes_on_batch(search_batch_input,
                                            start_index_of_search_batch,
                                            global_min_proto_dist,
                                            global_min_fmap_patches,
                                            proto_rf_boxes,
                                            proto_bound_boxes,
                                            search_y=search_y,
                                            num_classes=num_classes)

        if proto_epoch_dir is not None and self.proto_bound_boxes_filename_prefix is not None:
            np.save(
                os.path.join(
                    proto_epoch_dir, self.proto_bound_boxes_filename_prefix +
                    '-receptive_field' + str(epoch_number) + '.npy'),
                proto_rf_boxes)
            np.save(
                os.path.join(
                    proto_epoch_dir, self.proto_bound_boxes_filename_prefix +
                    str(epoch_number) + '.npy'), proto_bound_boxes)

        # XXX push here is different because we're choosing top K vectors.
        self.log('\tExecuting push ...')
        prototype_update = np.reshape(global_min_fmap_patches,
                                      tuple(prototype_shape))
        prototype_vectors = module.prototype_vectors
        # Build the tensor on the parameter's own device rather than forcing
        # `.cuda()`, so the push also works for CPU-resident models.
        prototype_vectors.data.copy_(
            torch.as_tensor(prototype_update, dtype=torch.float32,
                            device=prototype_vectors.device))
        end = time.time()
        self.log('\tpush time: \t{0}'.format(end - start))
# validation loader (unshuffled); valid_dataset is defined outside this view
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=batch_size,
                                           shuffle=False)

# test set: an ImageFolder over test_dir, also iterated in a fixed order
test_dataset = datasets.ImageFolder(test_dir, transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# set up directories for the nearest-patch outputs, under the resume path
root_dir_for_saving_train_images = os.path.join(args.resume_path,
                                                'nearest_train')
root_dir_for_saving_test_images = os.path.join(args.resume_path,
                                               'nearest_test')
makedir(root_dir_for_saving_train_images)
makedir(root_dir_for_saving_test_images)

# Build a class hierarchy rooted at "root" by attaching named children level by level.
# NOTE(review): Node's semantics (including what assign_all_descendents stores)
# are defined outside this view — confirm.
root = Node("root")
root.add_children(
    ['animal', 'vehicle', 'everyday_object', 'weapon', 'scuba_diver'])
root.add_children_to('animal', ['non_primate', 'primate'])
root.add_children_to('non_primate',
                     ['African_elephant', 'giant_panda', 'lion'])
root.add_children_to('primate', ['capuchin', 'gibbon', 'orangutan'])
root.add_children_to('vehicle', ['ambulance', 'pickup', 'sports_car'])
root.add_children_to('everyday_object', ['laptop', 'sandal', 'wine_bottle'])
root.add_children_to('weapon', ['assault_rifle', 'rifle'])
root.assign_all_descendents()

# second tree; presumably populated as a flat hierarchy in code past this view
flat_root = Node("root")
Beispiel #6
0
# Pruning is only meaningful on a pushed model; model names containing
# 'nopush' are rejected. An explicit raise (rather than `assert`) keeps the
# check active under `python -O`; the exception type is unchanged.
need_push = ('nopush' in original_model_name)
if need_push:
    raise AssertionError('pruning must happen after push')

# the epoch is the text before 'push', e.g. '10push...' or '10_18push...'
epoch = original_model_name.split('push')[0]
if '_' in epoch:
    epoch = int(epoch.split('_')[0])
else:
    epoch = int(epoch)

model_dir = os.path.join(
    original_model_dir,
    'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch, k, prune_threshold))
makedir(model_dir)
# keep a copy of this script next to the pruned model
shutil.copy(src=os.path.join(os.getcwd(), __file__), dst=model_dir)

log, logclose = create_logger(
    log_filename=os.path.join(model_dir, 'prune.log'))

# os.path.join (as for model_dir above) works whether or not
# original_model_dir ends with a path separator
ppnet = torch.load(os.path.join(original_model_dir, original_model_name))
ppnet = ppnet.cuda()
ppnet_multi = torch.nn.DataParallel(ppnet)
class_specific = True

# load the data
from settings import train_dir, test_dir, train_push_dir, train_inst_dir

train_batch_size = 80
test_batch_size = 100
def find_k_nearest_patches_to_prototypes(
        dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel,  # pytorch network with prototype_vectors
        k=5,
        preprocess_input_function=None,  # normalize if needed
        full_save=False,  # save all the images
        root_dir_for_saving_images='./nearest',
        log=print,
        prototype_activation_function_in_numpy=None):
    """For every prototype, keep the k closest image patches seen in ``dataloader``.

    Each batch ``(images, labels)`` is optionally passed through
    ``preprocess_input_function``, then through ``module.push_forward``. The
    returned distance maps are scanned per image and per prototype.

    full_save=False only returns the class identity of the closest patches
    and saves nothing. full_save=True additionally writes, per prototype
    ``j``, the activation maps, heatmap overlays, receptive-field crops and
    high-activation crops under ``root_dir_for_saving_images/<j>/``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches, images
            unnormalized in [0, 1].
        prototype_network_parallel: wrapper whose ``.module`` exposes
            ``num_prototypes``, ``prototype_shape``, ``proto_layer_rf_info``,
            ``prototype_activation_function``, ``epsilon`` and ``push_forward``.
        k: number of nearest patches kept per prototype.
        preprocess_input_function: optional callable applied to each batch
            before the forward pass.
        full_save: see above.
        root_dir_for_saving_images: output root for ``full_save``.
        log: callable used for progress messages.
        prototype_activation_function_in_numpy: used to compute the activation
            map when the module's activation function is neither 'log' nor
            'linear'.

    Returns:
        numpy array built from, per prototype, the labels of its kept patches.
    """
    prototype_network_parallel.eval()
    log('find nearest patches')
    start = time.time()
    module = prototype_network_parallel.module
    n_prototypes = module.num_prototypes

    prototype_shape = module.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = module.proto_layer_rf_info

    # one heap (a plain list maintained by heapq) per prototype
    heaps = [[] for _ in range(n_prototypes)]

    for idx, (search_batch_input, search_y) in enumerate(dataloader):
        log('batch {}'.format(idx))
        if preprocess_input_function is not None:
            search_batch = preprocess_input_function(search_batch_input)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            _, proto_dist_torch = module.push_forward(search_batch)

        proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

        for img_idx, distance_map in enumerate(proto_dist_):
            for j in range(n_prototypes):
                distance_map_j = distance_map[j]
                # closest patch in this image to prototype j
                closest_patch_distance_to_prototype_j = np.amin(distance_map_j)

                if full_save:
                    # leading 0 stands in for the batch index expected by
                    # compute_rf_prototype
                    closest_patch_indices_in_distance_map_j = [0] + list(
                        np.unravel_index(np.argmin(distance_map_j, axis=None),
                                         distance_map_j.shape))
                    closest_patch_indices_in_img = compute_rf_prototype(
                        search_batch.size(2),
                        closest_patch_indices_in_distance_map_j,
                        protoL_rf_info)
                    closest_patch = search_batch_input[
                        img_idx, :,
                        closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                        closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                    # CHW -> HWC for image saving
                    closest_patch = np.transpose(closest_patch.numpy(), (1, 2, 0))

                    original_img = np.transpose(
                        search_batch_input[img_idx].numpy(), (1, 2, 0))

                    activation_fn = module.prototype_activation_function
                    if activation_fn == 'log':
                        act_pattern = np.log(
                            (distance_map_j + 1) /
                            (distance_map_j + module.epsilon))
                    elif activation_fn == 'linear':
                        act_pattern = max_dist - distance_map_j
                    else:
                        act_pattern = prototype_activation_function_in_numpy(
                            distance_map_j)

                    # 4 numbers: height_start, height_end, width_start, width_end
                    patch_indices = closest_patch_indices_in_img[1:5]

                    closest_patch = ImagePatch(
                        patch=closest_patch,
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                        original_img=original_img,
                        act_pattern=act_pattern,
                        patch_indices=patch_indices)
                else:
                    closest_patch = ImagePatchInfo(
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j)

                # add to the j-th heap, keeping at most k entries
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush
                    # followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)

    # after looping through the dataset every heap holds its k kept patches
    for j in range(n_prototypes):
        # a heap is only partially ordered: sort it, then reverse
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        if not full_save:
            continue

        dir_for_saving_images = os.path.join(root_dir_for_saving_images,
                                             str(j))
        makedir(dir_for_saving_images)

        for rank, patch in enumerate(heaps[j], start=1):
            # every artifact for this neighbour shares the 'nearest-<rank>' stem
            stem = os.path.join(dir_for_saving_images, 'nearest-' + str(rank))

            # activation pattern and original image the patch comes from
            np.save(stem + '_act.npy', patch.act_pattern)
            plt.imsave(fname=stem + '_original.png',
                       arr=patch.original_img,
                       vmin=0.0,
                       vmax=1.0)

            # overlay (upsampled) activation on original image and save the result
            img_size = patch.original_img.shape[0]
            upsampled_act_pattern = cv2.resize(
                patch.act_pattern,
                dsize=(img_size, img_size),
                interpolation=cv2.INTER_CUBIC)
            rescaled_act_pattern = upsampled_act_pattern - np.amin(
                upsampled_act_pattern)
            act_peak = np.amax(rescaled_act_pattern)
            # a constant activation map would otherwise divide 0 by 0 -> NaN heatmap
            if act_peak > 0:
                rescaled_act_pattern = rescaled_act_pattern / act_peak
            heatmap = cv2.applyColorMap(
                np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
            heatmap = np.float32(heatmap) / 255
            heatmap = heatmap[..., ::-1]  # BGR (OpenCV) -> RGB
            overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
            plt.imsave(fname=stem + '_original_with_heatmap.png',
                       arr=overlayed_original_img,
                       vmin=0.0,
                       vmax=1.0)

            # if different from original image, save the patch (i.e. receptive field)
            if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size:
                np.save(stem + '_receptive_field_indices.npy',
                        patch.patch_indices)
                plt.imsave(fname=stem + '_receptive_field.png',
                           arr=patch.patch,
                           vmin=0.0,
                           vmax=1.0)
                # receptive field patch with heatmap
                overlayed_patch = overlayed_original_img[
                    patch.patch_indices[0]:patch.patch_indices[1],
                    patch.patch_indices[2]:patch.patch_indices[3], :]
                plt.imsave(fname=stem + '_receptive_field_with_heatmap.png',
                           arr=overlayed_patch,
                           vmin=0.0,
                           vmax=1.0)

            # highly activated patch, plus the original image with its bounding box
            high_act_patch_indices = find_high_activation_crop(
                upsampled_act_pattern)
            high_act_patch = patch.original_img[
                high_act_patch_indices[0]:high_act_patch_indices[1],
                high_act_patch_indices[2]:high_act_patch_indices[3], :]
            np.save(stem + '_high_act_patch_indices.npy', high_act_patch_indices)
            plt.imsave(fname=stem + '_high_act_patch.png',
                       arr=high_act_patch,
                       vmin=0.0,
                       vmax=1.0)
            imsave_with_bbox(fname=stem + '_high_act_patch_in_original_img.png',
                             img_rgb=patch.original_img,
                             bbox_height_start=high_act_patch_indices[0],
                             bbox_height_end=high_act_patch_indices[1],
                             bbox_width_start=high_act_patch_indices[2],
                             bbox_width_end=high_act_patch_indices[3],
                             color=(0, 255, 255))

        labels = np.array([patch.label for patch in heaps[j]])
        np.save(os.path.join(dir_for_saving_images, 'class_id.npy'),
                labels)

    labels_all_prototype = np.array([[patch.label for patch in heaps[j]]
                                     for j in range(n_prototypes)])

    if full_save:
        np.save(os.path.join(root_dir_for_saving_images, 'full_class_id.npy'),
                labels_all_prototype)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))

    return labels_all_prototype
# Hard-coded checkpoint to analyse (the trailing comments show the CLI args
# these used to come from).
load_model_dir = './saved_models/vgg19/003/'  #args.modeldir[0]
load_model_name = '160push0.8167.pth'  #args.model[0] '10_18push0.7822.pth'

#if load_model_dir[-1] == '/':
#    model_base_architecture = load_model_dir.split('/')[-3]
#    experiment_run = load_model_dir.split('/')[-2]
#else:
#    model_base_architecture = load_model_dir.split('/')[-2]
#    experiment_run = load_model_dir.split('/')[-1]

# Positional parsing of load_model_dir: for './saved_models/vgg19/003/' this
# gives 'vgg19' and '003/'. It relies on the './saved_models/<arch>/...' layout.
model_base_architecture = load_model_dir.split('/')[2]
experiment_run = '/'.join(load_model_dir.split('/')[3:])

# output folder: <test_image_dir>/<arch>/<run>/<model file name>
# NOTE(review): test_image_dir is defined outside this view — confirm.
save_analysis_path = os.path.join(test_image_dir, model_base_architecture,
                                  experiment_run, load_model_name)
makedir(save_analysis_path)

log, logclose = create_logger(
    log_filename=os.path.join(save_analysis_path, 'local_analysis.log'))

load_model_path = os.path.join(load_model_dir, load_model_name)
# The epoch is the first run of digits in the file name (e.g. '160'); this
# raises AttributeError if the name contains no digits.
epoch_number_str = re.search(r'\d+', load_model_name).group(0)
start_epoch_number = int(epoch_number_str)

log('load model from ' + load_model_path)
log('model base architecture: ' + model_base_architecture)
log('experiment run: ' + experiment_run)

# load the full pickled model object, move it to the GPU and wrap it for multi-GPU use
ppnet = torch.load(load_model_path)
ppnet = ppnet.cuda()
ppnet_multi = torch.nn.DataParallel(ppnet)
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ]))
# Evaluation loader; shuffle=False keeps the iteration order fixed between runs.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=False,
)

# Per-checkpoint output folders, e.g. '160push0.8167.pth' -> '160push0.8167_nearest_train'.
_checkpoint_stem = load_model_name.split('.pth')[0]
root_dir_for_saving_train_images = os.path.join(load_model_dir, f'{_checkpoint_stem}_nearest_train')
root_dir_for_saving_test_images = os.path.join(load_model_dir, f'{_checkpoint_stem}_nearest_test')
makedir(root_dir_for_saving_train_images)
makedir(root_dir_for_saving_test_images)

# Bounding-box info saved for the prototypes of the loaded epoch:
# <load_model_dir>/img/epoch-<E>/bb<E>.npy
load_img_dir = os.path.join(load_model_dir, 'img')
_epoch_tag = str(start_epoch_number)
prototype_info = np.load(os.path.join(load_img_dir, f'epoch-{_epoch_tag}', f'bb{_epoch_tag}.npy'))


def save_prototype_original_img_with_bbox(fname,
                                          epoch,
                                          index,
                                          bbox_height_start,
                                          bbox_height_end,
                                          bbox_width_start,
Beispiel #10
0
def prune_prototypes(dataloader,
                     ppnet,
                     k,
                     prune_threshold,
                     preprocess_input_function,
                     original_model_dir,
                     epoch_number,
                     log=print,
                     copy_prototype_imgs=True,
                     find_threshold_prune_n_patches=None,
                     only_n_most_activated=None):
    """Select prototypes to prune from a nearest-patch analysis and prune them.

    Runs ``find_nearest.find_k_nearest_patches_to_prototypes`` to get the class
    ids of each prototype's ``k`` nearest training patches, chooses prototypes
    with ``find_prototypes_to_prune``, saves ``prune_info.npy`` under
    ``<original_model_dir>/pruned_prototypes_epoch{E}_k{k}_pt{T}``, calls
    ``ppnet.prune_prototypes`` and optionally copies the saved artefacts of the
    surviving prototypes into that directory, re-indexed from 0.

    Args:
        dataloader: loader forwarded to the nearest-patch search.
        ppnet: network exposing ``num_prototypes``, ``prototype_class_identity``
            and ``prune_prototypes``.
        k: number of nearest patches examined per prototype.
        prune_threshold: threshold forwarded to ``find_prototypes_to_prune``;
            replaced by the searched value when
            ``find_threshold_prune_n_patches`` is given.
        preprocess_input_function: optional input normalisation.
        original_model_dir: directory containing ``img/epoch-<epoch_number>``.
        epoch_number: epoch whose saved prototype files are copied.
        log: logging callable taking a single string.
        copy_prototype_imgs: whether to copy the kept prototypes' files.
        find_threshold_prune_n_patches: if given, bisect the threshold over
            (0.0001, 10) and keep the largest tested value that prunes at most
            this many prototypes (assumes the pruned count grows with the
            threshold).
        only_n_most_activated: forwarded to the nearest-patch search.

    Returns:
        ``np.ndarray`` of shape ``(n_pruned, 2)``: each row is
        ``[pruned prototype index, its class]``.
    """
    ### run global analysis
    nearest_train_patch_class_ids = \
        find_nearest.find_k_nearest_patches_to_prototypes(
            dataloader=dataloader,
            ppnet=ppnet,
            k=k,
            preprocess_input_function=preprocess_input_function,
            full_save=False,
            log=log,
            only_n_most_activated=only_n_most_activated)

    ### find prototypes to prune
    original_num_prototypes = ppnet.num_prototypes

    if find_threshold_prune_n_patches is None:
        prototypes_to_prune = find_prototypes_to_prune(
            ppnet, nearest_train_patch_class_ids, prune_threshold)
    else:
        low, high = 0.0001, 10
        prototypes_to_prune = []
        for _ in range(30):
            m = (low + high) / 2
            current_prototypes_to_prune = find_prototypes_to_prune(
                ppnet, nearest_train_patch_class_ids, m)
            log('{} {}'.format(m, current_prototypes_to_prune))
            if len(current_prototypes_to_prune) > find_threshold_prune_n_patches:
                high = m
            else:
                # Acceptable threshold: remember it and try a larger one.
                low = m
                prototypes_to_prune = current_prototypes_to_prune
                prune_threshold = m

    log('k = {}, prune_threshold = {}'.format(k, prune_threshold))
    log('{} prototypes will be pruned'.format(len(prototypes_to_prune)))

    ### bookkeeping of prototypes to be pruned
    # .cpu() so .numpy() also works when the identity tensor lives on the GPU.
    class_of_prototypes_to_prune = torch.argmax(
        ppnet.prototype_class_identity[prototypes_to_prune],
        dim=1).cpu().numpy().reshape(-1, 1)
    prototypes_to_prune_np = np.array(prototypes_to_prune).reshape(-1, 1)
    prune_info = np.hstack(
        (prototypes_to_prune_np, class_of_prototypes_to_prune))

    pruned_dir = os.path.join(
        original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(
            epoch_number, k, prune_threshold))
    makedir(pruned_dir)
    np.save(os.path.join(pruned_dir, 'prune_info.npy'), prune_info)

    ### prune prototypes
    log('Prototypes to prune {}'.format(prototypes_to_prune))
    ppnet.prune_prototypes(prototypes_to_prune)

    if copy_prototype_imgs:
        original_img_dir = os.path.join(original_model_dir, 'img',
                                        'epoch-%d' % epoch_number)
        dst_img_dir = os.path.join(pruned_dir, 'img', 'epoch-%d' % epoch_number)
        makedir(dst_img_dir)

        # Surviving prototypes in ascending original order, so that new index i
        # maps to the i-th survivor (set iteration order is not guaranteed).
        pruned_set = set(prototypes_to_prune)
        prototypes_to_keep = [j for j in range(original_num_prototypes)
                              if j not in pruned_set]

        per_prototype_files = (
            'prototype-img%d.png',
            'prototype-img-original%d.png',
            'prototype-img-original_with_self_act%d.png',
            'prototype-self-act%d.npy',
        )
        for new_idx, old_idx in enumerate(prototypes_to_keep):
            for pattern in per_prototype_files:
                shutil.copyfile(src=os.path.join(original_img_dir, pattern % old_idx),
                                dst=os.path.join(dst_img_dir, pattern % new_idx))

        # Bounding-box arrays are indexed by prototype along axis 0; filter each
        # file once, outside the per-prototype loop.
        for bb_name in ('bb%d.npy' % epoch_number,
                        'bb-receptive_field%d.npy' % epoch_number):
            bb = np.load(os.path.join(original_img_dir, bb_name))
            np.save(os.path.join(dst_img_dir, bb_name), bb[prototypes_to_keep])

    return prune_info
Beispiel #11
0
def prune_prototypes(
        dataloader,
        prototype_network_parallel,
        k,
        prune_threshold,
        preprocess_input_function,
        original_model_dir,
        epoch_number,
        log=print,
        copy_prototype_imgs=True):
    """Prune prototypes whose k nearest training patches rarely share their class.

    A prototype ``j`` is pruned when fewer than ``prune_threshold`` of its ``k``
    nearest training patches (found by
    ``find_nearest.find_k_nearest_patches_to_prototypes``) carry the class given
    by ``argmax(prototype_class_identity[j])``. The pruned indices and classes
    are saved to ``prune_info.npy`` under
    ``<original_model_dir>/pruned_prototypes_epoch{E}_k{k}_pt{T}``, then
    ``prototype_network_parallel.module.prune_prototypes`` is called, and
    optionally the saved artefacts of the surviving prototypes are copied into
    that directory, re-indexed from 0.

    Args:
        dataloader: loader forwarded to the nearest-patch search.
        prototype_network_parallel: wrapper whose ``.module`` exposes
            ``num_prototypes``, ``prototype_class_identity`` and
            ``prune_prototypes``.
        k: number of nearest patches examined per prototype.
        prune_threshold: minimum count of same-class neighbours to survive.
        preprocess_input_function: optional input normalisation.
        original_model_dir: directory containing ``img/epoch-<epoch_number>``.
        epoch_number: epoch whose saved prototype files are copied.
        log: logging callable taking a single string.
        copy_prototype_imgs: whether to copy the kept prototypes' files.

    Returns:
        ``np.ndarray`` of shape ``(n_pruned, 2)``: each row is
        ``[pruned prototype index, its class]``.
    """
    module = prototype_network_parallel.module

    ### run global analysis
    nearest_train_patch_class_ids = \
        find_nearest.find_k_nearest_patches_to_prototypes(
            dataloader=dataloader,
            prototype_network_parallel=prototype_network_parallel,
            k=k,
            preprocess_input_function=preprocess_input_function,
            full_save=False,
            log=log)

    ### find prototypes to prune
    original_num_prototypes = module.num_prototypes

    prototypes_to_prune = []
    for j in range(original_num_prototypes):
        class_j = torch.argmax(module.prototype_class_identity[j]).item()
        # Counter returns 0 for a class absent from the nearest patches.
        same_class_count = Counter(nearest_train_patch_class_ids[j])[class_j]
        if same_class_count < prune_threshold:
            prototypes_to_prune.append(j)

    log('k = {}, prune_threshold = {}'.format(k, prune_threshold))
    log('{} prototypes will be pruned'.format(len(prototypes_to_prune)))

    ### bookkeeping of prototypes to be pruned
    # .cpu() so .numpy() also works when the identity tensor lives on the GPU.
    class_of_prototypes_to_prune = torch.argmax(
        module.prototype_class_identity[prototypes_to_prune],
        dim=1).cpu().numpy().reshape(-1, 1)
    prototypes_to_prune_np = np.array(prototypes_to_prune).reshape(-1, 1)
    prune_info = np.hstack(
        (prototypes_to_prune_np, class_of_prototypes_to_prune))

    pruned_dir = os.path.join(
        original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(
            epoch_number, k, prune_threshold))
    makedir(pruned_dir)
    np.save(os.path.join(pruned_dir, 'prune_info.npy'), prune_info)

    ### prune prototypes
    module.prune_prototypes(prototypes_to_prune)

    if copy_prototype_imgs:
        original_img_dir = os.path.join(original_model_dir, 'img',
                                        'epoch-%d' % epoch_number)
        dst_img_dir = os.path.join(pruned_dir, 'img', 'epoch-%d' % epoch_number)
        makedir(dst_img_dir)

        # Surviving prototypes in ascending original order, so that new index i
        # maps to the i-th survivor (set iteration order is not guaranteed).
        pruned_set = set(prototypes_to_prune)
        prototypes_to_keep = [j for j in range(original_num_prototypes)
                              if j not in pruned_set]

        per_prototype_files = (
            'prototype-img%d.png',
            'prototype-img-original%d.png',
            'prototype-img-original_with_self_act%d.png',
            'prototype-self-act%d.npy',
        )
        for new_idx, old_idx in enumerate(prototypes_to_keep):
            for pattern in per_prototype_files:
                shutil.copyfile(src=os.path.join(original_img_dir, pattern % old_idx),
                                dst=os.path.join(dst_img_dir, pattern % new_idx))

        # Bounding-box arrays are indexed by prototype along axis 0; filter each
        # file once, outside the per-prototype loop.
        for bb_name in ('bb%d.npy' % epoch_number,
                        'bb-receptive_field%d.npy' % epoch_number):
            bb = np.load(os.path.join(original_img_dir, bb_name))
            np.save(os.path.join(dst_img_dir, bb_name), bb[prototypes_to_keep])

    return prune_info
Beispiel #12
0
def generate_prototype_activation_matrix(ppnet, test_push_dataloader, push_dataloader, epoch,
                                         model_dir, device, bag_class=0, N=10, do_nearest=True):
    """Build a figure relating one bag's most-attended patches to every prototype.

    Takes the first bag from ``test_push_dataloader`` whose maximum label equals
    ``bag_class`` and which has at least ``N`` entries, runs ``ppnet`` on it, and
    lays out a grid with one row per prototype containing: its k=5 nearest
    training patches (read from
    ``<model_dir>/<i>/nearest-<j>_original_with_heatmap.png``), the saved
    prototype image with self-activation, the prototype patch, and for each of
    the ``N`` most-attended bag patches that patch overlaid with the
    prototype's activation heatmap, next to its ``vector_scores`` value. The two
    top rows show the ``N`` patches titled with their attention weight.

    Args:
        ppnet: prototype network providing ``forward_``, ``push_forward``,
            ``distance_2_similarity``, ``prototype_shape`` and
            ``prototype_activation_function``.
        test_push_dataloader: yields ``(raw, bag, label)`` triples.
        push_dataloader: loader used to (re)compute nearest patches when
            ``do_nearest`` is true.
        epoch: epoch whose saved prototype images under
            ``<model_dir>/img/epoch-<epoch>`` are used.
        model_dir: model directory; nearest-patch images are read from here.
        device: device the bag is moved to before the forward pass.
        bag_class: label of the bag to visualise.
        N: number of most-attended patches to show.
        do_nearest: recompute and save the nearest patches before plotting.

    Returns:
        The matplotlib figure.

    Raises:
        StopIteration: if no bag in ``test_push_dataloader`` matches.
    """
    print('    analysis for class', bag_class)
    epoch_number_str = str(epoch)
    load_img_dir = os.path.join(model_dir, 'img')

    prototype_info = np.load(os.path.join(load_img_dir, 'epoch-' + epoch_number_str, 'bb' + epoch_number_str + '.npy'))
    # Last column of the bounding-box file; only its length (number of prototypes) is used below.
    prototype_img_identity = prototype_info[:, -1]
    prototype_shape = ppnet.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    # First bag whose maximum label equals bag_class and which has at least N entries.
    raw, bag, label = next(
        ((r, b, l) for r, b, l in iter(test_push_dataloader) if l.max().unsqueeze(0) == bag_class and len(b) >= N))

    count_positive_patches = sum(label)
    if len(label) > 1:
        label = label.max().unsqueeze(0)

    bag = bag.squeeze(0)
    img_size = raw[0].shape[1]

    with torch.no_grad():
        ppnet.eval()

        images_test = bag.to(device)

        # forward_ returns (logits, min_distances, attention, prototype scores);
        # the figure only needs the attention weights and per-patch scores.
        _, _, attention, vector_scores = ppnet.forward_(images_test)

        _, distances = ppnet.push_forward(images_test)
        prototype_activation_patterns = ppnet.distance_2_similarity(distances)
        if ppnet.prototype_activation_function == 'linear':
            prototype_activation_patterns = prototype_activation_patterns + max_dist

    # Indices of the N patches with the most attention, most attended first.
    at = attention.squeeze(0).detach().cpu().numpy()
    top_patches = at.argsort()[-N:][::-1]

    imgs = [raw[i].permute(1, 2, 0) for i in top_patches]

    # For each top patch, overlay every prototype's upsampled activation map.
    imgs_with_self_activation_by_prototype = []
    for img, idx in zip(imgs, top_patches):
        original_img = img
        height, width = original_img.shape[0], original_img.shape[1]

        self_activation_for_img = []
        for i in range(len(prototype_img_identity)):
            activation_pattern = prototype_activation_patterns[idx][i].detach().cpu().numpy()
            # cv2.resize expects dsize as (width, height).
            upsampled_activation_pattern = cv2.resize(activation_pattern,
                                                      dsize=(width, height),
                                                      interpolation=cv2.INTER_CUBIC)

            rescaled_activation_pattern = upsampled_activation_pattern - np.amin(upsampled_activation_pattern)
            peak = np.amax(rescaled_activation_pattern)
            if peak > 0:
                # A constant map would otherwise be divided 0/0 and turn into NaNs.
                rescaled_activation_pattern = rescaled_activation_pattern / peak
            heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET)
            heatmap = np.float32(heatmap) / 255
            heatmap = heatmap[..., ::-1]  # OpenCV colormaps are BGR; flip to RGB
            overlayed_img = 0.5 * original_img + 0.3 * heatmap

            self_activation_for_img.append(np.asarray(overlayed_img))

        imgs_with_self_activation_by_prototype.append(np.asarray(self_activation_for_img))

    ### Take prototypes

    prototype_dir = os.path.join(load_img_dir, 'epoch-' + epoch_number_str)

    prototypes = []
    for i in range(len(prototype_img_identity)):
        prototypes.append(plt.imread(f'{prototype_dir}/prototype-img{i}.png'))

    # Take prototypes img original with self activation
    prototypes_img_with_act = []
    for i in range(len(prototype_img_identity)):
        prototypes_img_with_act.append(plt.imread(f'{prototype_dir}/prototype-img-original_with_self_act{i}.png'))

    ### Find the k-nearest patches in the dataset to each prototype

    k = 5

    root_dir_for_saving_train_images = os.path.join(model_dir)
    makedir(root_dir_for_saving_train_images)

    if do_nearest:
        find_nearest.find_k_nearest_patches_to_prototypes(
            dataloader=push_dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
            ppnet=ppnet,  # pytorch network with prototype_vectors
            k=k,
            full_save=True,
            root_dir_for_saving_images=root_dir_for_saving_train_images,
            log=print)

    k_nearest_patches = []

    for i in range(len(prototype_img_identity)):
        tmp = []
        for j in range(1, k + 1):
            tmp.append(plt.imread(f'{root_dir_for_saving_train_images}/{i}/nearest-{j}_original_with_heatmap.png'))
        k_nearest_patches.append(np.asarray(tmp))

    ###  Vector score for top patches
    arr = vector_scores.detach().cpu().numpy()
    arr = np.array([arr[i] for i in top_patches])

    def get_colors(inp, colormap, vmin=None, vmax=None):
        norm = plt.Normalize(vmin, vmax)
        return colormap(norm(inp))

    grid_score = np.around(arr.T, 2)
    colors = get_colors(grid_score, plt.cm.magma)
    len_proto = len(prototype_img_identity)

    # Set up the axes with gridspec
    fig = plt.figure(figsize=(2 * N + 2 + k, len_proto + 2))
    fig.suptitle(f'patches in bag: {len(bag)}, positive patches: {count_positive_patches}, class label: {label.item()}',
                 fontsize=40)

    grid = plt.GridSpec(len_proto + 2, 2 * N + 2 + k, hspace=0.04, wspace=0.04)

    # build a rectangle in axes coords
    left, width = .25, .5
    bottom, height = .25, .5
    right = left + width
    top = bottom + height

    # histogram
    l = 0
    for j in range(3 + k, 2 * N + 3 + k, 2):
        for i in range(2, len_proto + 2):
            main_ax = fig.add_subplot(grid[i, j])
            main_ax.set_facecolor(colors[i - 2][l])
            main_ax.text(0.5 * (left + right), 0.5 * (bottom + top), grid_score[i - 2][l],
                         horizontalalignment='center',
                         verticalalignment='center',
                         fontsize=15,
                         color='white' if grid_score[i - 2][l] < grid_score.max() * 0.9 else 'black',
                         transform=main_ax.transAxes)

            main_ax.get_xaxis().set_visible(False)
            main_ax.get_yaxis().set_visible(False)
        l = l + 1

    # images with self activation by prototype
    l = 0
    for j in range(2 + k, 2 * N + 2 + k, 2):
        for i in range(2, len_proto + 2):
            main_ax = fig.add_subplot(grid[i, j])
            main_ax.set_xlim([0, img_size])
            main_ax.set_ylim([0, img_size])
            main_ax.invert_yaxis()
            main_ax.imshow(imgs_with_self_activation_by_prototype[l][i - 2], aspect='auto')

            main_ax.get_xaxis().set_visible(False)
            main_ax.get_yaxis().set_visible(False)
        l = l + 1

    # prototypes
    for i in range(2, len_proto + 2):
        main_ax = fig.add_subplot(grid[i, 1 + k])
        main_ax.invert_yaxis()
        main_ax.imshow(prototypes[i - 2])

        main_ax.get_xaxis().set_visible(False)
        main_ax.get_yaxis().set_visible(False)

    # prototypes images with activation; first half framed red, second half green
    for i in range(2, len_proto + 2):
        main_ax = fig.add_subplot(grid[i, 0 + k])
        main_ax.set_xlim([0, img_size])
        main_ax.set_ylim([0, img_size])
        main_ax.invert_yaxis()
        main_ax.imshow(prototypes_img_with_act[i - 2])

        if i - 2 < len_proto // 2:
            main_ax.patch.set_edgecolor('red')
        else:
            main_ax.patch.set_edgecolor('green')
        main_ax.patch.set_linewidth(5)

        main_ax.get_xaxis().set_visible(False)
        main_ax.get_yaxis().set_visible(False)

    # k nearest patches
    for j in range(k):
        for i in range(2, len_proto + 2):
            main_ax = fig.add_subplot(grid[i, j])
            main_ax.set_xlim([0, img_size])
            main_ax.set_ylim([0, img_size])
            main_ax.invert_yaxis()
            main_ax.imshow(k_nearest_patches[i - 2][j])

            main_ax.get_xaxis().set_visible(False)
            main_ax.get_yaxis().set_visible(False)

    # patches
    l = 0
    for i in range(2 + k, 2 * N + 2 + k, 2):
        main_ax = fig.add_subplot(grid[0:2, i:i + 2])
        main_ax.set_xlim([0, img_size])
        main_ax.set_ylim([0, img_size])
        main_ax.invert_yaxis()
        main_ax.set_title(f'{at[top_patches[l]]:.3f}', fontsize=25)
        main_ax.imshow(imgs[l], aspect='auto')

        main_ax.get_xaxis().set_visible(False)
        main_ax.get_yaxis().set_visible(False)
        l = l + 1

    plt.axis('off')
    return fig
def find_k_nearest_patches_to_prototypes(
        dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel,  # pytorch network with prototype_vectors
        k=5,
        preprocess_input_function=None,  # normalize if needed
        prototype_layer_stride=1,
        root_dir_for_saving_images='./nearest_',
        save_image_class_identity=True,
        log=print):
    """Find and save the k closest training patches for every prototype of every tree node.

    Makes one pass over ``dataloader``. For each node returned by
    ``prototype_network_parallel.module.root.nodes_with_children()`` and each of
    its ``num_<node name>_prototypes`` prototypes, the closest patch of every
    image is wrapped in an ``ImagePatch`` and pushed into a heap capped at ``k``
    entries. Afterwards each heap is sorted, reversed and written to
    ``<root_dir_for_saving_images>/<node name>/<prototype index>/``:

    * ``nearest-<i>.png`` (the patch) and ``nearest-<i>_original.png``;
    * ``nearest-<i>_act.npy`` and ``nearest-<i>_indices.npy``;
    * ``nearest-<i>_original_with_heatmap.png`` and
      ``nearest-<i>_patch_with_heatmap.png``;
    * ``class_id.npy`` with the patch labels if ``save_image_class_identity``.

    Files are numbered from ``i = 0``.

    NOTE(review): ``prototype_layer_stride`` is not used in the lines shown.
    Which entries the heap keeps depends on ``ImagePatch`` ordering, defined
    elsewhere. Presumably it compares so that larger distances rank lower,
    making the heap retain the k smallest distances. TODO confirm.
    """
    log('find nearest patches')
    start = time.time()

    # Receptive-field description of the prototype layer; compute_rf_prototype
    # uses it to map a distance-map cell back to a pixel box in the input image.
    protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info
    parent_names = [
        node.name for node in
        prototype_network_parallel.module.root.nodes_with_children()
    ]

    cmap = "jet"

    # node name -> list with one (initially empty) heap per prototype of that node
    node2heaps = {name: [] for name in parent_names}
    # for each parent node, organize a heap for every prototype
    for node in prototype_network_parallel.module.root.nodes_with_children():
        name = node.name
        num_prototypes = getattr(prototype_network_parallel.module,
                                 "num_" + node.name + "_prototypes")
        for j in range(num_prototypes):
            node2heaps[name].append([])

    for (search_batch_input, search_y) in dataloader:
        # search_batch_input stays unnormalised and is used for cropping and
        # saving; search_batch (optionally preprocessed on a deep copy, so the
        # original is untouched) is what goes through the network.
        if preprocess_input_function != None:
            #print('preprocessing input for pushing ...')
            search_batch = copy.deepcopy(search_batch_input)
            search_batch = preprocess_input_function(search_batch)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            conv_features = prototype_network_parallel.module.conv_features(
                search_batch)

        for node in prototype_network_parallel.module.root.nodes_with_children(
        ):

            num_prototypes = getattr(prototype_network_parallel.module,
                                     "num_" + node.name + "_prototypes")

            # Distances between the shared conv features and this node's prototypes.
            with torch.no_grad():
                proto_dist_torch = prototype_network_parallel.module.prototype_distances(
                    conv_features, node.name)
                proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

            heaps = node2heaps[node.name]

            for img_idx, distance_map in enumerate(proto_dist_):
                for j in range(num_prototypes):
                    # find the closest patch to prototype j
                    closest_patch_distance_to_prototype_j = np.amin(
                        distance_map[j])
                    closest_patch_indices_in_distance_map_j = \
                        list(np.unravel_index(np.argmin(distance_map[j],axis=None),
                                              distance_map[j].shape))
                    # A leading 0 is prepended before calling compute_rf_prototype
                    # (presumably a batch-index slot it expects — confirm).
                    closest_patch_indices_in_distance_map_j = [
                        0
                    ] + closest_patch_indices_in_distance_map_j
                    # Elements [1]:[2] bound the height and [3]:[4] the width of
                    # the patch in the input image (see the crop below).
                    closest_patch_indices_in_img = \
                        compute_rf_prototype(search_batch.size(2),
                                             closest_patch_indices_in_distance_map_j,
                                             protoL_rf_info)
                    closest_patch = \
                        search_batch_input[img_idx, :,
                                           closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                                           closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                    closest_patch = closest_patch.numpy()
                    closest_patch = np.transpose(closest_patch, (1, 2, 0))  # CHW -> HWC

                    original_img = search_batch_input[img_idx].numpy()
                    original_img = np.transpose(original_img, (1, 2, 0))  # CHW -> HWC

                    # Similarity map: larger where the distance is smaller. It is
                    # not rescaled to [0, 1] here.
                    act_pattern = np.log(1 + (1 / (distance_map[j] + 1e-4)))

                    patch_indices = closest_patch_indices_in_img[1:5]

                    # construct the closest patch object
                    closest_patch = ImagePatch(
                        patch=closest_patch,
                        label=search_y[img_idx],
                        distance=closest_patch_distance_to_prototype_j,
                        original_img=original_img,
                        act_pattern=act_pattern,
                        patch_indices=patch_indices)
                    # add to the j-th heap; once it holds k items, push the new
                    # patch and pop the heap's smallest element to stay at size k
                    if len(heaps[j]) < k:
                        heapq.heappush(heaps[j], closest_patch)
                    else:
                        heapq.heappushpop(heaps[j], closest_patch)

        # Drop references to this batch's tensors before loading the next batch.
        del conv_features, search_batch

    for node in prototype_network_parallel.module.root.nodes_with_children():
        heaps = node2heaps[node.name]
        num_prototypes = getattr(prototype_network_parallel.module,
                                 "num_" + node.name + "_prototypes")

        for j in range(num_prototypes):
            # Sort ascending by ImagePatch ordering, then reverse so the
            # highest-ranked patch is saved as index 0.
            heaps[j].sort()
            heaps[j] = heaps[j][::-1]  # reverses

            dir_for_saving_images = os.path.join(root_dir_for_saving_images,
                                                 node.name, str(j))
            makedir(dir_for_saving_images)

            # NOTE(review): this initial value is never read; labels is only
            # reassigned below when save_image_class_identity is true.
            labels = []

            for i, patch in enumerate(heaps[j]):
                # save the patch itself
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i) + '.png'),
                           arr=patch.patch,
                           vmin=0.0,
                           vmax=1.0)
                # save the original image where the patch comes from
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i) + '_original.png'),
                           arr=patch.original_img,
                           vmin=0.0,
                           vmax=1.0)
                # save the activation pattern and the patch indices
                np.save(
                    os.path.join(dir_for_saving_images,
                                 'nearest-' + str(i) + '_act.npy'),
                    patch.act_pattern)
                np.save(
                    os.path.join(dir_for_saving_images,
                                 'nearest-' + str(i) + '_indices.npy'),
                    patch.patch_indices)
                # upsample the activation pattern to the (assumed square) image size
                img_size = patch.original_img.shape[0]
                # NOTE(review): the image comes from an RGB tensor, but the
                # BGR->gray conversion is used, so channel weights are swapped. Confirm intent.
                original_img_gray = cv2.cvtColor(patch.original_img,
                                                 cv2.COLOR_BGR2GRAY)
                upsampled_act_pattern = cv2.resize(
                    patch.act_pattern,
                    dsize=(img_size, img_size),
                    interpolation=cv2.INTER_CUBIC)
                # overlay heatmap on the original image and save the result
                # NOTE(review): act_pattern is not normalised, so this overlay can
                # exceed vmax=1.0 and be clipped by imsave.
                overlayed_original_img = 0.7 * original_img_gray + 0.3 * upsampled_act_pattern
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i) + '_original_with_heatmap.png'),
                           arr=overlayed_original_img,
                           cmap=cmap,
                           vmin=0.0,
                           vmax=1.0)
                # save the patch with heatmap (patch_indices = h_start, h_end, w_start, w_end)
                overlayed_patch = overlayed_original_img[
                    patch.patch_indices[0]:patch.patch_indices[1],
                    patch.patch_indices[2]:patch.patch_indices[3]]
                plt.imsave(fname=os.path.join(
                    dir_for_saving_images,
                    'nearest-' + str(i) + '_patch_with_heatmap.png'),
                           arr=overlayed_patch,
                           cmap=cmap,
                           vmin=0.0,
                           vmax=1.0)

            if save_image_class_identity:
                labels = np.array([patch.label for patch in heaps[j]])
                np.save(os.path.join(dir_for_saving_images, 'class_id.npy'),
                        labels)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))