def closest_patch_indices_in_img(self):
    """Map the argmin of ``self.distance_map_j`` back to input-image coordinates.

    The flat argmin of the distance map is unravelled into per-axis
    coordinates and prefixed with an image index of 0. The result is passed to
    ``compute_rf_prototype`` together with ``self.original_raw.size(1)`` and
    ``self.protoL_rf_info``, and its return value is passed through unchanged.
    Elsewhere in this source that value is indexed as
    ``[img_index, h_start, h_end, w_start, w_end]``.
    """
    dist_map = self.distance_map_j
    argmin_coords = np.unravel_index(np.argmin(dist_map, axis=None), dist_map.shape)
    # compute_rf_prototype takes [image_index, *spatial_index]; the leading
    # image index is fixed at 0 here.
    patch_index = [0, *argmin_coords]
    # NOTE(review): assumes original_raw is (C, H, W), so size(1) is the image
    # height — confirm.
    image_size = self.original_raw.size(1)
    return compute_rf_prototype(image_size, patch_index, self.protoL_rf_info)
def update_prototypes_on_batch(search_batch_input,
                               start_index_of_search_batch,
                               prototype_network_parallel,
                               global_min_proto_dist,  # this will be updated
                               global_min_fmap_patches,  # this will be updated
                               proto_rf_boxes,  # this will be updated
                               proto_bound_boxes,  # this will be updated
                               class_specific=True,
                               search_y=None,  # required if class_specific == True
                               num_classes=None,  # required if class_specific == True
                               preprocess_input_function=None,
                               prototype_layer_stride=1,
                               dir_for_saving_prototypes=None,
                               prototype_img_filename_prefix=None,
                               prototype_self_act_filename_prefix=None,
                               prototype_activation_function_in_numpy=None):
    """Update each prototype's closest-patch record using one search batch.

    For every prototype ``j``, the batch is searched for the latent patch with
    the smallest distance. When ``class_specific`` is set, only images whose
    label matches the prototype's class are searched. If that distance is
    below ``global_min_proto_dist[j]``:

    - entry ``j`` of ``global_min_proto_dist``, ``global_min_fmap_patches``,
      ``proto_rf_boxes`` and ``proto_bound_boxes`` is overwritten in place;
    - if ``dir_for_saving_prototypes`` is set, the visualizations are written
      to disk.

    Args:
        search_batch_input: image batch before preprocessing. Images are read
            back from it with ``.numpy()`` for cropping and saving, so it must
            be a CPU tensor. Pixel values are presumably in [0, 1], since
            images are saved with ``vmin=0.0, vmax=1.0``.
        start_index_of_search_batch: dataset index of the batch's first image.
            It is added to the in-batch index stored in column 0 of the boxes.
        prototype_network_parallel: wrapper whose ``.module`` provides
            ``push_forward``, ``prototype_shape``, ``prototype_class_identity``,
            ``proto_layer_rf_info``, ``prototype_activation_function`` and
            ``epsilon``.
        global_min_proto_dist: per-prototype best distance so far (in/out).
        global_min_fmap_patches: per-prototype best latent patch (in/out).
        proto_rf_boxes: per-prototype row
            ``[dataset_img_idx, row_start, row_end, col_start, col_end(, label)]``
            of the receptive field (in/out).
        proto_bound_boxes: same layout, for the high-activation crop (in/out).
        class_specific: only search images of each prototype's own class.
        search_y: per-image labels; required when ``class_specific``.
        num_classes: number of classes; required when ``class_specific``.
        preprocess_input_function: optional transform applied before the
            forward pass (e.g. normalization).
        prototype_layer_stride: stride that maps a distance-map index to a
            feature-map offset.
        dir_for_saving_prototypes: output directory, or None to skip saving.
        prototype_img_filename_prefix: prefix for saved PNGs, or None.
        prototype_self_act_filename_prefix: prefix for saved activation
            ``.npy`` files, or None.
        prototype_activation_function_in_numpy: callable that turns distances
            into activations when the module's activation function is neither
            ``'log'`` nor ``'linear'``.

    NOTE(review): a second ``update_prototypes_on_batch`` with a different
    signature appears later in this source. If both live in the same module,
    the later definition shadows this one — confirm which is intended.
    """
    prototype_network_parallel.eval()
    module = prototype_network_parallel.module

    if preprocess_input_function is not None:
        search_batch = preprocess_input_function(search_batch_input)
    else:
        search_batch = search_batch_input

    with torch.no_grad():
        search_batch = search_batch.cuda()
        # this computation currently is not parallelized
        protoL_input_torch, proto_dist_torch = module.push_forward(search_batch)

    protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy())
    proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

    del protoL_input_torch, proto_dist_torch

    if class_specific:
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        # img_y is the image's integer label
        for img_index, img_y in enumerate(search_y):
            img_label = img_y.item()
            class_to_img_index_dict[img_label].append(img_index)

    prototype_shape = module.prototype_shape
    n_prototypes = prototype_shape[0]
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]
    protoL_rf_info = module.proto_layer_rf_info

    for j in range(n_prototypes):
        if class_specific:
            # target_class is the class of the class_specific prototype
            target_class = torch.argmax(module.prototype_class_identity[j]).item()
            # if there are no images of the target_class in this batch,
            # go on to the next prototype
            if len(class_to_img_index_dict[target_class]) == 0:
                continue
            proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:, j, :, :]
        else:
            # if it is not class specific, search through every example
            proto_dist_j = proto_dist_[:, j, :, :]

        batch_min_proto_dist_j = np.amin(proto_dist_j)
        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = list(
                np.unravel_index(np.argmin(proto_dist_j, axis=None), proto_dist_j.shape))
            if class_specific:
                # change the argmin index from the index among images of the
                # target class to the index in the entire search batch
                batch_argmin_proto_dist_j[0] = \
                    class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]]

            # retrieve the corresponding feature map patch
            img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w

            batch_min_fmap_patch_j = protoL_input_[img_index_in_batch,
                                                   :,
                                                   fmap_height_start_index:fmap_height_end_index,
                                                   fmap_width_start_index:fmap_width_end_index]

            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = batch_min_fmap_patch_j

            # get the receptive field boundary of the image patch
            # that generates the representation
            rf_prototype_j = compute_rf_prototype(search_batch.size(2),
                                                  batch_argmin_proto_dist_j,
                                                  protoL_rf_info)

            # get the whole image (channels-first -> channels-last)
            original_img_j = search_batch_input[rf_prototype_j[0]]
            original_img_j = original_img_j.numpy()
            original_img_j = np.transpose(original_img_j, (1, 2, 0))
            original_img_size = original_img_j.shape[0]

            # crop out the receptive field
            rf_img_j = original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                      rf_prototype_j[3]:rf_prototype_j[4], :]

            # save the prototype receptive field information
            proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_rf_boxes[j, 1] = rf_prototype_j[1]
            proto_rf_boxes[j, 2] = rf_prototype_j[2]
            proto_rf_boxes[j, 3] = rf_prototype_j[3]
            proto_rf_boxes[j, 4] = rf_prototype_j[4]
            if proto_rf_boxes.shape[1] == 6 and search_y is not None:
                proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            # find the highly activated region of the original image
            proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :]
            if module.prototype_activation_function == 'log':
                proto_act_img_j = np.log(
                    (proto_dist_img_j + 1) / (proto_dist_img_j + module.epsilon))
            elif module.prototype_activation_function == 'linear':
                proto_act_img_j = max_dist - proto_dist_img_j
            else:
                proto_act_img_j = prototype_activation_function_in_numpy(proto_dist_img_j)
            upsampled_act_img_j = cv2.resize(proto_act_img_j,
                                             dsize=(original_img_size, original_img_size),
                                             interpolation=cv2.INTER_CUBIC)
            proto_bound_j = find_high_activation_crop(upsampled_act_img_j)
            # crop out the image patch with high activation as prototype image
            proto_img_j = original_img_j[proto_bound_j[0]:proto_bound_j[1],
                                         proto_bound_j[2]:proto_bound_j[3], :]

            # save the prototype boundary (rectangular boundary of highly activated region)
            proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0]
            proto_bound_boxes[j, 1] = proto_bound_j[0]
            proto_bound_boxes[j, 2] = proto_bound_j[1]
            proto_bound_boxes[j, 3] = proto_bound_j[2]
            proto_bound_boxes[j, 4] = proto_bound_j[3]
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            if dir_for_saving_prototypes is not None:
                if prototype_self_act_filename_prefix is not None:
                    # save the numpy array of the prototype self activation
                    np.save(os.path.join(dir_for_saving_prototypes,
                                         prototype_self_act_filename_prefix + str(j) + '.npy'),
                            proto_act_img_j)
                if prototype_img_filename_prefix is not None:
                    # save the whole image containing the prototype as png
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original' + str(j) + '.png'),
                               original_img_j,
                               vmin=0.0,
                               vmax=1.0)
                    # overlay (upsampled) self activation on original image and save the result
                    rescaled_act_img_j = upsampled_act_img_j - np.amin(upsampled_act_img_j)
                    act_range_j = np.amax(rescaled_act_img_j)
                    # A constant activation map has zero range; dividing by it
                    # would fill the heatmap with NaN (np.uint8(NaN) is undefined).
                    if act_range_j > 0:
                        rescaled_act_img_j = rescaled_act_img_j / act_range_j
                    heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_img_j), cv2.COLORMAP_JET)
                    heatmap = np.float32(heatmap) / 255
                    # cv2 produces BGR; flip to RGB to match the image
                    heatmap = heatmap[..., ::-1]
                    overlayed_original_img_j = 0.5 * original_img_j + 0.3 * heatmap
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + '-original_with_self_act' + str(j) + '.png'),
                               overlayed_original_img_j,
                               vmin=0.0,
                               vmax=1.0)

                    # if different from the original (whole) image, save the prototype receptive field as png
                    if rf_img_j.shape[0] != original_img_size or rf_img_j.shape[1] != original_img_size:
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field' + str(j) + '.png'),
                                   rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)
                        overlayed_rf_img_j = overlayed_original_img_j[rf_prototype_j[1]:rf_prototype_j[2],
                                                                      rf_prototype_j[3]:rf_prototype_j[4]]
                        plt.imsave(os.path.join(dir_for_saving_prototypes,
                                                prototype_img_filename_prefix + '-receptive_field_with_self_act' + str(j) + '.png'),
                                   overlayed_rf_img_j,
                                   vmin=0.0,
                                   vmax=1.0)

                    # save the prototype image (highly activated region of the whole image)
                    plt.imsave(os.path.join(dir_for_saving_prototypes,
                                            prototype_img_filename_prefix + str(j) + '.png'),
                               proto_img_j,
                               vmin=0.0,
                               vmax=1.0)

    if class_specific:
        del class_to_img_index_dict
def find_k_nearest_patches_to_prototypes(dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
                                         prototype_network_parallel,  # pytorch network with prototype_vectors
                                         k=5,
                                         preprocess_input_function=None,  # normalize if needed
                                         full_save=False,  # save all the images
                                         root_dir_for_saving_images='./nearest',
                                         log=print,
                                         prototype_activation_function_in_numpy=None):
    """Collect, for every prototype, the k dataset patches closest to it.

    Each image in ``dataloader`` contributes one candidate per prototype: its
    minimum-distance patch. A ``heapq`` heap of at most ``k`` candidates is
    kept per prototype. Which candidates survive depends on the ordering
    defined by ``ImagePatch`` / ``ImagePatchInfo``, which is not shown here;
    presumably they compare so that closer patches rank higher — verify.

    full_save=False will only return the class identity of the closest
    patches, but it will not save anything. With ``full_save=True``, the
    following are written under ``root_dir_for_saving_images/<j>/`` for every
    prototype ``j``:

    - the activation map and original image of each kept patch;
    - heatmap overlays;
    - receptive-field crops;
    - high-activation crops and their indices;
    - ``class_id.npy``.

    A ``full_class_id.npy`` is also saved at the root.

    Args:
        dataloader: yields ``(images, labels)``; images must be CPU tensors,
            since they are converted with ``.numpy()``.
        prototype_network_parallel: wrapper whose ``.module`` provides
            ``num_prototypes``, ``prototype_shape``, ``proto_layer_rf_info``,
            ``push_forward``, ``prototype_activation_function`` and
            ``epsilon``.
        k: number of nearest patches kept per prototype.
        preprocess_input_function: optional transform applied before the
            forward pass.
        full_save: whether to write the per-patch artifacts described above.
        root_dir_for_saving_images: output root directory.
        log: logging callable used for progress and timing messages.
        prototype_activation_function_in_numpy: distance->activation callable
            used when the activation function is neither ``'log'`` nor
            ``'linear'``.

    Returns:
        ``np.array`` of labels, shape ``(n_prototypes, k)`` when every heap is
        full, holding the labels of each prototype's kept patches.

    NOTE(review): a second ``find_k_nearest_patches_to_prototypes`` with a
    different signature appears later in this source. If both live in the
    same module, the later definition shadows this one — confirm.
    """
    prototype_network_parallel.eval()
    module = prototype_network_parallel.module
    log('find nearest patches')
    start = time.time()
    n_prototypes = module.num_prototypes

    prototype_shape = module.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = module.proto_layer_rf_info

    # one heap per prototype; a heap in python is just a list maintained by heapq
    heaps = [[] for _ in range(n_prototypes)]

    for idx, (search_batch_input, search_y) in enumerate(dataloader):
        # progress goes through the injected logger (defaults to print)
        log('batch {}'.format(idx))
        if preprocess_input_function is not None:
            search_batch = preprocess_input_function(search_batch_input)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            protoL_input_torch, proto_dist_torch = module.push_forward(search_batch)
            # the latent features are not needed here; release them early
            del protoL_input_torch

        proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())
        del proto_dist_torch

        for img_idx, distance_map in enumerate(proto_dist_):
            for j in range(n_prototypes):
                # find the closest patch in this image to prototype j
                closest_patch_distance_to_prototype_j = np.amin(distance_map[j])

                if full_save:
                    closest_patch_indices_in_distance_map_j = list(
                        np.unravel_index(np.argmin(distance_map[j], axis=None),
                                         distance_map[j].shape))
                    closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j
                    closest_patch_indices_in_img = compute_rf_prototype(
                        search_batch.size(2),
                        closest_patch_indices_in_distance_map_j,
                        protoL_rf_info)
                    closest_patch = search_batch_input[img_idx, :,
                                                       closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                                                       closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                    closest_patch = closest_patch.numpy()
                    closest_patch = np.transpose(closest_patch, (1, 2, 0))

                    original_img = search_batch_input[img_idx].numpy()
                    original_img = np.transpose(original_img, (1, 2, 0))

                    if module.prototype_activation_function == 'log':
                        act_pattern = np.log(
                            (distance_map[j] + 1) / (distance_map[j] + module.epsilon))
                    elif module.prototype_activation_function == 'linear':
                        act_pattern = max_dist - distance_map[j]
                    else:
                        act_pattern = prototype_activation_function_in_numpy(distance_map[j])

                    # 4 numbers: height_start, height_end, width_start, width_end
                    patch_indices = closest_patch_indices_in_img[1:5]

                    # construct the closest patch object
                    closest_patch = ImagePatch(patch=closest_patch,
                                               label=search_y[img_idx],
                                               distance=closest_patch_distance_to_prototype_j,
                                               original_img=original_img,
                                               act_pattern=act_pattern,
                                               patch_indices=patch_indices)
                else:
                    closest_patch = ImagePatchInfo(label=search_y[img_idx],
                                                   distance=closest_patch_distance_to_prototype_j)

                # add to the j-th heap
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush
                    # followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)

    # after looping through the dataset every heap holds the k kept patches,
    # but they are not ranked yet
    for j in range(n_prototypes):
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        if full_save:
            dir_for_saving_images = os.path.join(root_dir_for_saving_images, str(j))
            makedir(dir_for_saving_images)

            for i, patch in enumerate(heaps[j]):
                # save the activation pattern of the original image where the patch comes from
                np.save(os.path.join(dir_for_saving_images,
                                     'nearest-' + str(i + 1) + '_act.npy'),
                        patch.act_pattern)

                # save the original image where the patch comes from
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i + 1) + '_original.png'),
                           arr=patch.original_img,
                           vmin=0.0,
                           vmax=1.0)

                # overlay (upsampled) activation on original image and save the result
                img_size = patch.original_img.shape[0]
                upsampled_act_pattern = cv2.resize(patch.act_pattern,
                                                   dsize=(img_size, img_size),
                                                   interpolation=cv2.INTER_CUBIC)
                rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
                act_range = np.amax(rescaled_act_pattern)
                # A constant activation map has zero range; dividing by it
                # would fill the heatmap with NaN (np.uint8(NaN) is undefined).
                if act_range > 0:
                    rescaled_act_pattern = rescaled_act_pattern / act_range
                heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET)
                heatmap = np.float32(heatmap) / 255
                # cv2 produces BGR; flip to RGB to match the image
                heatmap = heatmap[..., ::-1]
                overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i + 1) + '_original_with_heatmap.png'),
                           arr=overlayed_original_img,
                           vmin=0.0,
                           vmax=1.0)

                # if different from original image, save the patch (i.e. receptive field)
                if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size:
                    np.save(os.path.join(dir_for_saving_images,
                                         'nearest-' + str(i + 1) + '_receptive_field_indices.npy'),
                            patch.patch_indices)
                    plt.imsave(fname=os.path.join(dir_for_saving_images,
                                                  'nearest-' + str(i + 1) + '_receptive_field.png'),
                               arr=patch.patch,
                               vmin=0.0,
                               vmax=1.0)
                    # save the receptive field patch with heatmap
                    overlayed_patch = overlayed_original_img[patch.patch_indices[0]:patch.patch_indices[1],
                                                             patch.patch_indices[2]:patch.patch_indices[3], :]
                    plt.imsave(fname=os.path.join(dir_for_saving_images,
                                                  'nearest-' + str(i + 1) + '_receptive_field_with_heatmap.png'),
                               arr=overlayed_patch,
                               vmin=0.0,
                               vmax=1.0)

                # save the highly activated patch
                high_act_patch_indices = find_high_activation_crop(upsampled_act_pattern)
                high_act_patch = patch.original_img[high_act_patch_indices[0]:high_act_patch_indices[1],
                                                    high_act_patch_indices[2]:high_act_patch_indices[3], :]
                np.save(os.path.join(dir_for_saving_images,
                                     'nearest-' + str(i + 1) + '_high_act_patch_indices.npy'),
                        high_act_patch_indices)
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i + 1) + '_high_act_patch.png'),
                           arr=high_act_patch,
                           vmin=0.0,
                           vmax=1.0)
                # save the original image with bounding box showing high activation patch
                imsave_with_bbox(fname=os.path.join(dir_for_saving_images,
                                                    'nearest-' + str(i + 1) + '_high_act_patch_in_original_img.png'),
                                 img_rgb=patch.original_img,
                                 bbox_height_start=high_act_patch_indices[0],
                                 bbox_height_end=high_act_patch_indices[1],
                                 bbox_width_start=high_act_patch_indices[2],
                                 bbox_width_end=high_act_patch_indices[3],
                                 color=(0, 255, 255))

            labels = np.array([patch.label for patch in heaps[j]])
            np.save(os.path.join(dir_for_saving_images, 'class_id.npy'), labels)

    labels_all_prototype = np.array([[patch.label for patch in heaps[j]] for j in range(n_prototypes)])

    if full_save:
        np.save(os.path.join(root_dir_for_saving_images, 'full_class_id.npy'), labels_all_prototype)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))

    return labels_all_prototype
def update_prototypes_on_batch(
        search_batch_input,
        start_index_of_search_batch,
        prototype_network_parallel,
        conv_features,
        global_min_proto_dist,  # this will be updated
        global_min_fmap_patches,  # this will be updated
        proto_bound_boxes,  # this will be updated
        name,
        prototype_shape,
        n_prototypes_per_class=None,  # default: no restriction on number of prototypes per class
        search_y=None,  # required if n_prototypes_per_class != None
        num_classes=None,  # required if n_prototypes_per_class != None
        preprocess_input_function=None,
        prototype_layer_stride=1,
        dir_for_saving_prototypes=None,
        prototype_img_filename_prefix=None,
        prototype_original_img_filename_prefix=None):
    """Update the closest-patch record of node ``name``'s prototypes with one batch.

    Distances are obtained from
    ``prototype_network_parallel.module.prototype_distances(conv_features, name)``.
    For each prototype ``j`` whose batch-minimum distance beats
    ``global_min_proto_dist[j]``, the following are overwritten in place:

    - ``global_min_proto_dist[j]``;
    - ``global_min_fmap_patches[j]``, stored flattened;
    - row ``j`` of ``proto_bound_boxes``, laid out as
      ``[dataset_img_idx, row_start, row_end, col_start, col_end(, label)]``.

    Optionally the receptive-field crop and the whole source image are saved
    as ``.npy`` and ``.png`` files.

    Args:
        search_batch_input: un-preprocessed image batch. It is sliced and
            converted with ``.numpy()``, so it must be a CPU tensor.
        start_index_of_search_batch: dataset index of the batch's first image.
        prototype_network_parallel: wrapper whose ``.module`` provides
            ``prototype_distances`` and ``proto_layer_rf_info``.
        conv_features: latent features of the batch computed by the caller.
        global_min_proto_dist: per-prototype best distance so far (in/out).
        global_min_fmap_patches: per-prototype best latent patch (in/out).
        proto_bound_boxes: per-prototype box rows (in/out).
        name: tree-node name selecting which prototype set to use.
        prototype_shape: ``(n_prototypes, channels, proto_h, proto_w)``, as
            indexed below.
        n_prototypes_per_class: if given, prototype ``j`` belongs to class
            ``j // n_prototypes_per_class``, and only images of that class are
            searched.
        search_y: per-image labels; required with ``n_prototypes_per_class``.
        num_classes: number of classes; required with ``n_prototypes_per_class``.
        preprocess_input_function: unused here (features come precomputed);
            kept for interface compatibility.
        prototype_layer_stride: stride mapping a distance-map index to a
            feature-map offset.
        dir_for_saving_prototypes: output directory, or None to skip saving.
        prototype_img_filename_prefix: prefix for the receptive-field crop, or None.
        prototype_original_img_filename_prefix: prefix for the whole image, or None.
    """
    module = prototype_network_parallel.module

    with torch.no_grad():
        proto_dist_torch = module.prototype_distances(conv_features, name)

    # Read the features straight to host memory. The previous
    # copy.deepcopy(conv_features) duplicated the tensor on the GPU for no
    # benefit, and raises RuntimeError for non-leaf tensors that require grad.
    protoL_input_ = np.copy(conv_features.detach().cpu().numpy())
    proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())

    del proto_dist_torch

    class_specific = n_prototypes_per_class is not None
    if class_specific:
        class_to_img_index_dict = {key: [] for key in range(num_classes)}
        for img_index, img_y in enumerate(search_y):
            img_label = img_y.item()
            class_to_img_index_dict[img_label].append(img_index)

    n_prototypes = prototype_shape[0]
    proto_h = prototype_shape[2]
    proto_w = prototype_shape[3]
    protoL_rf_info = module.proto_layer_rf_info

    for j in range(n_prototypes):
        if class_specific:
            target_class = j // n_prototypes_per_class
            # no images of this prototype's class in the batch
            if len(class_to_img_index_dict[target_class]) == 0:
                continue
            proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:, j, :, :]
        else:
            proto_dist_j = proto_dist_[:, j, :, :]

        batch_min_proto_dist_j = np.amin(proto_dist_j)
        if batch_min_proto_dist_j < global_min_proto_dist[j]:
            batch_argmin_proto_dist_j = list(
                np.unravel_index(np.argmin(proto_dist_j, axis=None), proto_dist_j.shape))

            # retrieve the corresponding feature map patch; with class
            # filtering the argmin image index is relative to the class subset
            if class_specific:
                img_index_in_batch = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]]
            else:
                img_index_in_batch = batch_argmin_proto_dist_j[0]
            fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride
            fmap_height_end_index = fmap_height_start_index + proto_h
            fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride
            fmap_width_end_index = fmap_width_start_index + proto_w

            batch_min_fmap_patch_j = protoL_input_[img_index_in_batch, :,
                                                   fmap_height_start_index:fmap_height_end_index,
                                                   fmap_width_start_index:fmap_width_end_index]

            global_min_proto_dist[j] = batch_min_proto_dist_j
            global_min_fmap_patches[j] = np.reshape(batch_min_fmap_patch_j, -1)

            # convert the argmin image index to the index in the whole batch
            if class_specific:
                batch_argmin_proto_dist_j[0] = img_index_in_batch

            rf_prototype_j = compute_rf_prototype(search_batch_input.size(2),
                                                  batch_argmin_proto_dist_j,
                                                  protoL_rf_info)

            img_j = search_batch_input[rf_prototype_j[0], :,
                                       rf_prototype_j[1]:rf_prototype_j[2],
                                       rf_prototype_j[3]:rf_prototype_j[4]]
            img_j = img_j.numpy()

            proto_bound_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch
            proto_bound_boxes[j, 1] = rf_prototype_j[1]
            proto_bound_boxes[j, 2] = rf_prototype_j[2]
            proto_bound_boxes[j, 3] = rf_prototype_j[3]
            proto_bound_boxes[j, 4] = rf_prototype_j[4]
            if proto_bound_boxes.shape[1] == 6 and search_y is not None:
                proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item()

            if dir_for_saving_prototypes is not None:
                if prototype_img_filename_prefix is not None:
                    np.save(os.path.join(dir_for_saving_prototypes,
                                         prototype_img_filename_prefix + str(j) + '.npy'),
                            img_j)
                    img_j = np.transpose(img_j, (1, 2, 0))
                    matplotlib.image.imsave(os.path.join(dir_for_saving_prototypes,
                                                         prototype_img_filename_prefix + str(j) + '.png'),
                                            img_j,
                                            vmin=0.0,
                                            vmax=1.0)

                if prototype_original_img_filename_prefix is not None:
                    original_img_j = search_batch_input[rf_prototype_j[0]]
                    original_img_j = original_img_j.numpy()
                    np.save(os.path.join(dir_for_saving_prototypes,
                                         prototype_original_img_filename_prefix + str(j) + '.npy'),
                            original_img_j)
                    original_img_j = np.transpose(original_img_j, (1, 2, 0))
                    matplotlib.image.imsave(os.path.join(dir_for_saving_prototypes,
                                                         prototype_original_img_filename_prefix + str(j) + '.png'),
                                            original_img_j,
                                            vmin=0.0,
                                            vmax=1.0)

    if class_specific:
        del class_to_img_index_dict
def find_k_nearest_patches_to_prototypes(
        dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
        prototype_network_parallel,  # pytorch network with prototype_vectors
        k=5,
        preprocess_input_function=None,  # normalize if needed
        prototype_layer_stride=1,
        root_dir_for_saving_images='./nearest_',
        save_image_class_identity=True,
        log=print):
    """For every prototype of every internal tree node, save its k nearest patches.

    Each image contributes one candidate per prototype: its minimum-distance
    patch. At most ``k`` candidates are kept per prototype in a ``heapq``
    heap. Which candidates survive depends on ``ImagePatch`` ordering (not
    shown here); presumably closer patches rank higher — verify.

    For node ``n`` and prototype ``j``, the following are written to
    ``root_dir_for_saving_images/<n>/<j>/``:

    - each kept patch and its source image;
    - the raw activation map and patch indices;
    - grayscale + activation overlays for the whole image and for the patch;
    - optionally ``class_id.npy``.

    Args:
        dataloader: yields ``(images, labels)``; images must be CPU tensors,
            since they are converted with ``.numpy()``.
        prototype_network_parallel: wrapper whose ``.module`` provides
            ``proto_layer_rf_info``, ``root.nodes_with_children()``,
            ``num_<node>_prototypes``, ``conv_features`` and
            ``prototype_distances``.
        k: number of nearest patches kept per prototype.
        preprocess_input_function: optional transform, applied to a deep copy
            of each batch before the forward pass.
        prototype_layer_stride: unused here; kept for interface compatibility.
        root_dir_for_saving_images: output root directory.
        save_image_class_identity: also save the kept patches' labels.
        log: logging callable used for the timing messages.
    """
    log('find nearest patches')
    start = time.time()
    module = prototype_network_parallel.module
    protoL_rf_info = module.proto_layer_rf_info

    parent_names = [node.name for node in module.root.nodes_with_children()]
    cmap = "jet"

    # for each parent node, organize a heap for every prototype
    node2heaps = {name: [] for name in parent_names}
    for node in module.root.nodes_with_children():
        name = node.name
        num_prototypes = getattr(module, "num_" + node.name + "_prototypes")
        for _ in range(num_prototypes):
            node2heaps[name].append([])

    for (search_batch_input, search_y) in dataloader:
        if preprocess_input_function is not None:
            # copy first so the un-normalized images are kept for saving
            search_batch = copy.deepcopy(search_batch_input)
            search_batch = preprocess_input_function(search_batch)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            conv_features = module.conv_features(search_batch)

        for node in module.root.nodes_with_children():
            num_prototypes = getattr(module, "num_" + node.name + "_prototypes")
            with torch.no_grad():
                proto_dist_torch = module.prototype_distances(conv_features, node.name)
            proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())
            heaps = node2heaps[node.name]

            for img_idx, distance_map in enumerate(proto_dist_):
                for j in range(num_prototypes):
                    # find the closest patch to prototype j
                    closest_patch_distance_to_prototype_j = np.amin(distance_map[j])
                    closest_patch_indices_in_distance_map_j = list(
                        np.unravel_index(np.argmin(distance_map[j], axis=None),
                                         distance_map[j].shape))
                    closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j
                    closest_patch_indices_in_img = compute_rf_prototype(
                        search_batch.size(2),
                        closest_patch_indices_in_distance_map_j,
                        protoL_rf_info)
                    closest_patch = search_batch_input[img_idx, :,
                                                       closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                                                       closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                    closest_patch = closest_patch.numpy()
                    closest_patch = np.transpose(closest_patch, (1, 2, 0))

                    original_img = search_batch_input[img_idx].numpy()
                    original_img = np.transpose(original_img, (1, 2, 0))

                    # similarity that grows as the distance shrinks
                    act_pattern = np.log(1 + (1 / (distance_map[j] + 1e-4)))

                    # 4 numbers: height_start, height_end, width_start, width_end
                    patch_indices = closest_patch_indices_in_img[1:5]

                    # construct the closest patch object
                    closest_patch = ImagePatch(patch=closest_patch,
                                               label=search_y[img_idx],
                                               distance=closest_patch_distance_to_prototype_j,
                                               original_img=original_img,
                                               act_pattern=act_pattern,
                                               patch_indices=patch_indices)

                    # add to the j-th heap
                    if len(heaps[j]) < k:
                        heapq.heappush(heaps[j], closest_patch)
                    else:
                        heapq.heappushpop(heaps[j], closest_patch)

        del conv_features, search_batch

    for node in module.root.nodes_with_children():
        heaps = node2heaps[node.name]
        num_prototypes = getattr(module, "num_" + node.name + "_prototypes")
        for j in range(num_prototypes):
            heaps[j].sort()
            heaps[j] = heaps[j][::-1]  # reverses
            dir_for_saving_images = os.path.join(root_dir_for_saving_images, node.name, str(j))
            makedir(dir_for_saving_images)

            for i, patch in enumerate(heaps[j]):
                # save the patch itself
                plt.imsave(fname=os.path.join(dir_for_saving_images, 'nearest-' + str(i) + '.png'),
                           arr=patch.patch,
                           vmin=0.0,
                           vmax=1.0)
                # save the original image where the patch comes from
                plt.imsave(fname=os.path.join(dir_for_saving_images, 'nearest-' + str(i) + '_original.png'),
                           arr=patch.original_img,
                           vmin=0.0,
                           vmax=1.0)
                # save the activation pattern and the patch indices
                np.save(os.path.join(dir_for_saving_images, 'nearest-' + str(i) + '_act.npy'),
                        patch.act_pattern)
                np.save(os.path.join(dir_for_saving_images, 'nearest-' + str(i) + '_indices.npy'),
                        patch.patch_indices)

                # upsample the activation pattern
                img_size = patch.original_img.shape[0]
                # The images come from a torch dataloader in RGB order (the
                # other visualizers in this source flip cv2's BGR heatmap to
                # RGB before blending), so RGB2GRAY gives correct luminance
                # weights; BGR2GRAY swapped the red and blue weights.
                original_img_gray = cv2.cvtColor(patch.original_img, cv2.COLOR_RGB2GRAY)
                upsampled_act_pattern = cv2.resize(patch.act_pattern,
                                                   dsize=(img_size, img_size),
                                                   interpolation=cv2.INTER_CUBIC)

                # overlay heatmap on the original image and save the result
                # NOTE(review): the activation is not rescaled to [0, 1] before
                # blending, so values above vmax saturate the colormap — confirm
                # this is intended.
                overlayed_original_img = 0.7 * original_img_gray + 0.3 * upsampled_act_pattern
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i) + '_original_with_heatmap.png'),
                           arr=overlayed_original_img,
                           cmap=cmap,
                           vmin=0.0,
                           vmax=1.0)

                # save the patch with heatmap
                overlayed_patch = overlayed_original_img[patch.patch_indices[0]:patch.patch_indices[1],
                                                         patch.patch_indices[2]:patch.patch_indices[3]]
                plt.imsave(fname=os.path.join(dir_for_saving_images,
                                              'nearest-' + str(i) + '_patch_with_heatmap.png'),
                           arr=overlayed_patch,
                           cmap=cmap,
                           vmin=0.0,
                           vmax=1.0)

            if save_image_class_identity:
                labels = np.array([patch.label for patch in heaps[j]])
                np.save(os.path.join(dir_for_saving_images, 'class_id.npy'), labels)

    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))