def prune_prototypes(
    dataloader,
    prototype_network_parallel,
    k,
    prune_threshold,
    preprocess_input_function,
    original_model_dir,
    epoch_number,
    log=print,
    copy_prototype_imgs=True,
):
    """Prune prototypes whose nearest training patches mostly belong to other classes.

    For each prototype j the k nearest training patches are looked up with
    ``find_nearest.find_k_nearest_patches_to_prototypes``. Prototype j is pruned
    when fewer than ``prune_threshold`` of those k patches carry the class that
    ``prototype_class_identity[j]`` assigns to it.

    Outputs go to
    ``<original_model_dir>/pruned_prototypes_epoch{epoch}_k{k}_pt{threshold}/``:
    ``prune_info.npy`` always. When ``copy_prototype_imgs`` is set, the kept
    prototypes' images and bounding boxes are copied there as well, renumbered
    0..n_kept-1.

    Args:
        dataloader: loader over the images searched for nearest patches.
        prototype_network_parallel: wrapped network; its ``.module`` is pruned in place.
        k: number of nearest patches inspected per prototype.
        prune_threshold: minimum number of same-class patches (among the k) needed
            to keep a prototype.
        preprocess_input_function: forwarded to the nearest-patch search.
        original_model_dir: model directory that holds ``img/epoch-<epoch_number>``.
        epoch_number: epoch used in output names and to locate the source images.
        log: logging callable.
        copy_prototype_imgs: whether to copy the kept prototypes' visualisations.

    Returns:
        numpy array with one row per pruned prototype: ``[prototype_index, class]``.
    """
    module = prototype_network_parallel.module

    ### run global analysis
    nearest_train_patch_class_ids = find_nearest.find_k_nearest_patches_to_prototypes(
        dataloader=dataloader,
        prototype_network_parallel=prototype_network_parallel,
        k=k,
        preprocess_input_function=preprocess_input_function,
        full_save=False,
        log=log,
    )

    ### find prototypes to prune
    original_num_prototypes = module.num_prototypes
    prototypes_to_prune = []
    for j in range(original_num_prototypes):
        class_j = torch.argmax(module.prototype_class_identity[j]).item()
        # Counter returns 0 for a class that never appears among the k neighbours.
        class_counts_j = Counter(nearest_train_patch_class_ids[j])
        if class_counts_j[class_j] < prune_threshold:
            prototypes_to_prune.append(j)

    log("k = {}, prune_threshold = {}".format(k, prune_threshold))
    log("{} prototypes will be pruned".format(len(prototypes_to_prune)))

    ### bookkeeping of prototypes to be pruned
    class_of_prototypes_to_prune = (
        torch.argmax(module.prototype_class_identity[prototypes_to_prune], dim=1)
        .cpu()  # .numpy() only works on CPU tensors
        .numpy()
        .reshape(-1, 1)
    )
    prototypes_to_prune_np = np.array(prototypes_to_prune).reshape(-1, 1)
    prune_info = np.hstack((prototypes_to_prune_np, class_of_prototypes_to_prune))

    pruned_dir = os.path.join(
        original_model_dir,
        "pruned_prototypes_epoch{}_k{}_pt{}".format(epoch_number, k, prune_threshold),
    )
    makedir(pruned_dir)
    np.save(os.path.join(pruned_dir, "prune_info.npy"), prune_info)

    ### prune prototypes
    module.prune_prototypes(prototypes_to_prune)

    if copy_prototype_imgs:
        epoch_subdir = "epoch-%d" % epoch_number
        original_img_dir = os.path.join(original_model_dir, "img", epoch_subdir)
        dst_img_dir = os.path.join(pruned_dir, "img", epoch_subdir)
        makedir(dst_img_dir)

        # sorted(): set iteration order is an implementation detail, and the
        # renumbering below must be deterministic. Assumes module.prune_prototypes
        # keeps the survivors in ascending original order — TODO confirm.
        pruned = set(prototypes_to_prune)
        prototypes_to_keep = sorted(set(range(original_num_prototypes)) - pruned)

        per_prototype_files = (
            "prototype-img%d.png",
            "prototype-img-original%d.png",
            "prototype-img-original_with_self_act%d.png",
            "prototype-self-act%d.npy",
        )
        for new_idx, old_idx in enumerate(prototypes_to_keep):
            for template in per_prototype_files:
                shutil.copyfile(
                    src=os.path.join(original_img_dir, template % old_idx),
                    dst=os.path.join(dst_img_dir, template % new_idx),
                )

        # Bounding-box arrays have one row per prototype; keep only the survivors.
        for bb_template in ("bb%d.npy", "bb-receptive_field%d.npy"):
            boxes = np.load(os.path.join(original_img_dir, bb_template % epoch_number))
            np.save(
                os.path.join(dst_img_dir, bb_template % epoch_number),
                boxes[prototypes_to_keep],
            )

    return prune_info
def push_prototypes(
    dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
    prototype_network_parallel,  # pytorch network with prototype_vectors
    class_specific=True,
    preprocess_input_function=None,  # normalize if needed
    prototype_layer_stride=1,
    root_dir_for_saving_prototypes=None,  # if not None, prototypes will be saved here
    epoch_number=None,  # if not provided, prototypes saved previously will be overwritten
    prototype_img_filename_prefix=None,
    prototype_self_act_filename_prefix=None,
    proto_bound_boxes_filename_prefix=None,
    save_prototype_class_identity=True,  # which class the prototype image comes from
    log=print,
    prototype_activation_function_in_numpy=None,
):
    """Replace every prototype vector with its closest feature-map patch from the data.

    Iterates over ``dataloader``. ``update_prototypes_on_batch`` keeps, per
    prototype, the smallest distance seen so far and the corresponding patch
    representation. Afterwards the network's ``prototype_vectors`` are
    overwritten in place with those patches.

    The receptive-field boxes and bounding boxes collected during the search are
    saved as ``<prefix>-receptive_field<epoch>.npy`` and ``<prefix><epoch>.npy``.
    This happens only when both a save directory and
    ``proto_bound_boxes_filename_prefix`` are given.

    Columns of ``proto_rf_boxes`` and ``proto_bound_boxes``:
        0: image index in the entire dataset
        1: height start index
        2: height end index
        3: width start index
        4: width end index
        5: (optional) class identity
    """
    prototype_network_parallel.eval()
    log("\tpush")

    start = time.time()
    module = prototype_network_parallel.module
    prototype_shape = module.prototype_shape
    n_prototypes = module.num_prototypes
    # saves the closest distance seen so far
    global_min_proto_dist = np.full(n_prototypes, np.inf)
    # saves the patch representation that gives the current smallest distance
    global_min_fmap_patches = np.zeros(
        [n_prototypes, prototype_shape[1], prototype_shape[2], prototype_shape[3]]
    )

    n_box_columns = 6 if save_prototype_class_identity else 5
    proto_rf_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)
    proto_bound_boxes = np.full(shape=[n_prototypes, n_box_columns], fill_value=-1)

    if root_dir_for_saving_prototypes is None:
        proto_epoch_dir = None
    elif epoch_number is None:
        proto_epoch_dir = root_dir_for_saving_prototypes
    else:
        proto_epoch_dir = os.path.join(
            root_dir_for_saving_prototypes, "epoch-" + str(epoch_number)
        )
        makedir(proto_epoch_dir)

    num_classes = module.num_classes
    # Dataset index of the first image of the current batch (assumes an
    # unshuffled loader, as before). Counting the images actually seen, rather
    # than push_iter * dataloader.batch_size, also works when the loader is
    # driven by a batch_sampler and its batch_size is None.
    start_index_of_search_batch = 0
    for search_batch_input, search_y in dataloader:
        batch_len = len(search_batch_input)
        update_prototypes_on_batch(
            search_batch_input,
            start_index_of_search_batch,
            prototype_network_parallel,
            global_min_proto_dist,
            global_min_fmap_patches,
            proto_rf_boxes,
            proto_bound_boxes,
            class_specific=class_specific,
            search_y=search_y,
            num_classes=num_classes,
            preprocess_input_function=preprocess_input_function,
            prototype_layer_stride=prototype_layer_stride,
            dir_for_saving_prototypes=proto_epoch_dir,
            prototype_img_filename_prefix=prototype_img_filename_prefix,
            prototype_self_act_filename_prefix=prototype_self_act_filename_prefix,
            prototype_activation_function_in_numpy=prototype_activation_function_in_numpy,
        )
        start_index_of_search_batch += batch_len

    if proto_epoch_dir is not None and proto_bound_boxes_filename_prefix is not None:
        np.save(
            os.path.join(
                proto_epoch_dir,
                proto_bound_boxes_filename_prefix
                + "-receptive_field"
                + str(epoch_number)
                + ".npy",
            ),
            proto_rf_boxes,
        )
        np.save(
            os.path.join(
                proto_epoch_dir,
                proto_bound_boxes_filename_prefix + str(epoch_number) + ".npy",
            ),
            proto_bound_boxes,
        )

    log("\tExecuting push ...")
    prototype_update = np.reshape(global_min_fmap_patches, tuple(prototype_shape))
    prototype_vectors = module.prototype_vectors
    # Build the tensor on the prototypes' own device instead of assuming CUDA.
    prototype_vectors.data.copy_(
        torch.tensor(
            prototype_update, dtype=torch.float32, device=prototype_vectors.device
        )
    )
    end = time.time()
    log("\tpush time: \t{0}".format(end - start))
def _activation_from_distance_map(
    module, distances, max_dist, prototype_activation_function_in_numpy
):
    """Turn one prototype's distance map into an activation map.

    Follows ``module.prototype_activation_function``:
    "log" -> log((d + 1) / (d + epsilon)), "linear" -> max_dist - d,
    anything else -> ``prototype_activation_function_in_numpy(d)``.
    """
    if module.prototype_activation_function == "log":
        return np.log((distances + 1) / (distances + module.epsilon))
    if module.prototype_activation_function == "linear":
        return max_dist - distances
    return prototype_activation_function_in_numpy(distances)


def _save_nearest_patches(patches, dir_for_saving_images):
    """Save the images, activation maps and class ids of one prototype's nearest patches."""
    makedir(dir_for_saving_images)
    for rank, patch in enumerate(patches, start=1):
        # every file of this patch shares the "nearest-<rank>" prefix
        prefix = os.path.join(dir_for_saving_images, "nearest-" + str(rank))

        # save the activation pattern of the original image where the patch comes from
        np.save(prefix + "_act.npy", patch.act_pattern)
        # save the original image where the patch comes from
        plt.imsave(
            fname=prefix + "_original.png",
            arr=patch.original_img,
            vmin=0.0,
            vmax=1.0,
        )

        # overlay (upsampled) activation on original image and save the result
        img_size = patch.original_img.shape[0]
        upsampled_act_pattern = cv2.resize(
            patch.act_pattern,
            dsize=(img_size, img_size),
            interpolation=cv2.INTER_CUBIC,
        )
        rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
        peak = np.amax(rescaled_act_pattern)
        if peak > 0:  # a constant map would otherwise produce 0/0 = NaN
            rescaled_act_pattern = rescaled_act_pattern / peak
        heatmap = cv2.applyColorMap(
            np.uint8(255 * rescaled_act_pattern), cv2.COLORMAP_JET
        )
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]  # OpenCV colour maps are BGR; matplotlib wants RGB
        overlayed_original_img = 0.5 * patch.original_img + 0.3 * heatmap
        plt.imsave(
            fname=prefix + "_original_with_heatmap.png",
            arr=overlayed_original_img,
            vmin=0.0,
            vmax=1.0,
        )

        # if different from original image, save the patch (i.e. receptive field)
        if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size:
            np.save(prefix + "_receptive_field_indices.npy", patch.patch_indices)
            plt.imsave(
                fname=prefix + "_receptive_field.png",
                arr=patch.patch,
                vmin=0.0,
                vmax=1.0,
            )
            # save the receptive field patch with heatmap
            overlayed_patch = overlayed_original_img[
                patch.patch_indices[0] : patch.patch_indices[1],
                patch.patch_indices[2] : patch.patch_indices[3],
                :,
            ]
            plt.imsave(
                fname=prefix + "_receptive_field_with_heatmap.png",
                arr=overlayed_patch,
                vmin=0.0,
                vmax=1.0,
            )

        # save the highly activated patch
        high_act_patch_indices = find_high_activation_crop(upsampled_act_pattern)
        high_act_patch = patch.original_img[
            high_act_patch_indices[0] : high_act_patch_indices[1],
            high_act_patch_indices[2] : high_act_patch_indices[3],
            :,
        ]
        np.save(prefix + "_high_act_patch_indices.npy", high_act_patch_indices)
        plt.imsave(
            fname=prefix + "_high_act_patch.png",
            arr=high_act_patch,
            vmin=0.0,
            vmax=1.0,
        )
        # save the original image with bounding box showing high activation patch
        imsave_with_bbox(
            fname=prefix + "_high_act_patch_in_original_img.png",
            img_rgb=patch.original_img,
            bbox_height_start=high_act_patch_indices[0],
            bbox_height_end=high_act_patch_indices[1],
            bbox_width_start=high_act_patch_indices[2],
            bbox_width_end=high_act_patch_indices[3],
            color=(0, 255, 255),
        )

    labels = np.array([patch.label for patch in patches])
    np.save(os.path.join(dir_for_saving_images, "class_id.npy"), labels)


def find_k_nearest_patches_to_prototypes(
    dataloader,  # pytorch dataloader (must be unnormalized in [0,1])
    prototype_network_parallel,  # pytorch network with prototype_vectors
    k=5,
    preprocess_input_function=None,  # normalize if needed
    full_save=False,  # save all the images
    root_dir_for_saving_images="./nearest",
    log=print,
    prototype_activation_function_in_numpy=None,
):
    """Find, for every prototype, the k closest patches over the whole dataset.

    For each image and prototype, the minimum of the prototype's distance map
    gives the closest patch. A bounded heap of size k per prototype keeps the
    best candidates, with ordering defined by ``ImagePatch`` / ``ImagePatchInfo``.

    full_save=False will only return the class identity of the closest patches,
    but it will not save anything. With full_save=True, visualisations are
    written to ``<root_dir_for_saving_images>/<prototype index>/``, and all labels
    go to ``<root_dir_for_saving_images>/full_class_id.npy``.

    Returns:
        numpy array whose row j holds the labels of prototype j's k patches.
        Each heap is sorted and then reversed (presumably nearest first, given
        how ImagePatch orders itself — confirm).
    """
    prototype_network_parallel.eval()
    log("find nearest patches")
    start = time.time()
    module = prototype_network_parallel.module
    n_prototypes = module.num_prototypes
    prototype_shape = module.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]
    protoL_rf_info = module.proto_layer_rf_info

    # One heap per prototype (a Python heap is just a list maintained by heapq).
    heaps = [[] for _ in range(n_prototypes)]

    for batch_idx, (search_batch_input, search_y) in enumerate(dataloader):
        log("batch {}".format(batch_idx))
        if preprocess_input_function is not None:
            search_batch = preprocess_input_function(search_batch_input)
        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            _, proto_dist_torch = module.push_forward(search_batch)

        proto_dist_ = proto_dist_torch.detach().cpu().numpy()
        # Minimum of every (image, prototype) distance map in one vectorised call.
        batch_min_dists = proto_dist_.reshape(
            proto_dist_.shape[0], proto_dist_.shape[1], -1
        ).min(axis=2)

        for img_idx, distance_map in enumerate(proto_dist_):
            if full_save:
                # H x W x C view of the (unnormalised) input image
                original_img = np.transpose(
                    search_batch_input[img_idx].numpy(), (1, 2, 0)
                )
            for j in range(n_prototypes):
                closest_distance = batch_min_dists[img_idx, j]
                if full_save:
                    # [0] prepends the batch index expected by compute_rf_prototype.
                    patch_indices_in_map = [0] + list(
                        np.unravel_index(
                            np.argmin(distance_map[j], axis=None),
                            distance_map[j].shape,
                        )
                    )
                    rf_indices = compute_rf_prototype(
                        search_batch.size(2), patch_indices_in_map, protoL_rf_info
                    )
                    patch = search_batch_input[
                        img_idx,
                        :,
                        rf_indices[1] : rf_indices[2],
                        rf_indices[3] : rf_indices[4],
                    ]
                    closest_patch = ImagePatch(
                        patch=np.transpose(patch.numpy(), (1, 2, 0)),
                        label=search_y[img_idx],
                        distance=closest_distance,
                        original_img=original_img,
                        act_pattern=_activation_from_distance_map(
                            module,
                            distance_map[j],
                            max_dist,
                            prototype_activation_function_in_numpy,
                        ),
                        # 4 numbers: height_start, height_end, width_start, width_end
                        patch_indices=rf_indices[1:5],
                    )
                else:
                    closest_patch = ImagePatchInfo(
                        label=search_y[img_idx], distance=closest_distance
                    )

                # add to the j-th heap
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)

    # After the full pass every heap holds the k best patches, but unranked.
    for j, heap in enumerate(heaps):
        heap.sort()
        heap.reverse()
        if full_save:
            _save_nearest_patches(heap, os.path.join(root_dir_for_saving_images, str(j)))

    labels_all_prototype = np.array([[patch.label for patch in heap] for heap in heaps])

    if full_save:
        np.save(
            os.path.join(root_dir_for_saving_images, "full_class_id.npy"),
            labels_all_prototype,
        )

    end = time.time()
    log("\tfind nearest patches time: \t{0}".format(end - start))

    return labels_all_prototype
# Pick the saved-models root for the current environment and append
# "<base_architecture>/<experiment_run>/" (with a trailing slash).
if not colab:
    model_dir = "/".join(
        [
            "/cluster/scratch",
            username,
            "PPNet/saved_models",
            base_architecture,
            experiment_run,
            "",
        ]
    )
else:
    model_dir = "/".join(
        ["/content/PPNet/saved_models", base_architecture, experiment_run, ""]
    )
makedir(model_dir)

# Snapshot the source files that define this run next to its outputs.
shutil.copy(os.path.join(os.getcwd(), __file__), model_dir)
shutil.copy(os.path.join(os.getcwd(), "settings.py"), model_dir)
shutil.copy(
    os.path.join(os.getcwd(), "src/models/", base_architecture_type + "_features.py"),
    model_dir,
)
shutil.copy(os.path.join(os.getcwd(), "src/models/", "model.py"), model_dir)
shutil.copy(
    os.path.join(os.getcwd(), "src/training/", "train_and_test.py"), model_dir
)

log, logclose = create_logger(log_filename=os.path.join(model_dir, "train.log"))
img_dir = os.path.join(model_dir, "img")
def jpeg_visualization2(
    self,
    pil_img,
    img_name,
    test_image_label,
    preprocess_clean,
    preprocess_compressed,
    show_images=False,
    idx=0,
    max_prototypes=5,
    top_n=None,
    save_histogram=False,
    save_name="histogram",
):
    """
    Perform detailed local analysis comparing compressed and uncompressed images.

    Both versions of ``pil_img`` are passed through the network. For each of the
    ``max_prototypes`` most activated prototypes (or only rank ``top_n`` if given),
    four images are saved and collected: the prototype, the training image it
    comes from, the test image with the most activated patch boxed, and the
    activation map. Afterwards, a bar plot compares both images' similarity
    scores on each image's highest-scoring prototypes (up to 75).

    Args:
        pil_img: is the PIL test image which is to be inspected
        img_name: name of the sub-folder of ``self.save_analysis_path`` to save into
        test_image_label: label of the test image
        preprocess_clean: first torch vision transform pipeline (here, without compression)
        preprocess_compressed: second torch vision transform pipeline (here, with compression)
        show_images (default = False): Boolean value to show the log table and images
        idx (default = 0): Index Value, when retrieving results of the PPNet
        max_prototypes (int): number of most similar prototypes to display (default: 5)
        top_n (default = None): Visualize only the top_n-th most activating prototype.
            It is ranked on the compressed image and then inspected on the
            uncompressed image too. Both heatmaps share the compressed image's
            normalisation.
        save_histogram (default = False): save the histogram figure as "<save_name>.pdf"
        save_name (default = "histogram"): file name (without extension) of the histogram

    Returns:
        A flat list of images. For every inspected rank, 8 images are appended:
            1. Training image the prototype is taken from, with bounding box
            2. Prototype inspected on the compressed image
            3. Compressed image passed through, with activated patch in bounding box
            4. Corresponding activation map of the compressed image
            5-8. The same four images for the uncompressed image
        Without ``top_n``, each image is paired with its own i-th most activated
        prototype. With ``top_n``, both images use the prototype selected on the
        compressed image.
    """
    # How to save the images
    specific_folder = self.save_analysis_path + "/" + img_name
    makedir(specific_folder)
    # save_prototype* below write into the shared analysis folder; make sure it exists.
    shared_prototype_folder = os.path.join(
        self.save_analysis_path, "most_activated_prototypes"
    )
    makedir(shared_prototype_folder)

    # Preprocess the clean image first, then the compressed one.
    img_tensor_clean = preprocess_clean(pil_img)
    img_tensor_compressed = preprocess_compressed(pil_img)
    # Key 0: compressed image, key 1: clean image.
    img_variables = [img_tensor_compressed.unsqueeze(0), img_tensor_clean.unsqueeze(0)]

    dict_prototype_activations = {}
    dict_prototype_activation_patterns = {}
    dict_sorted_indices_act = {}
    dict_original_img = {}
    for k, img_variable in enumerate(img_variables):
        # Forward the image variable through the network
        images_test = img_variable.cuda()
        labels_test = torch.tensor([test_image_label])
        # Inference only. no_grad also lets the histogram call .numpy() on the results.
        with torch.no_grad():
            logits, min_distances = self.ppnet_multi(images_test)
            _, distances = self.ppnet.push_forward(images_test)
            prototype_activations = self.ppnet.distance_2_similarity(min_distances)
            prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
        if self.ppnet.prototype_activation_function == "linear":
            prototype_activations = prototype_activations + self.max_dist
            prototype_activation_patterns = prototype_activation_patterns + self.max_dist

        predictions = torch.argmax(logits, dim=1)
        tables = [
            (predictions[i].item(), labels_test[i].item())
            for i in range(logits.size(0))
        ]
        predicted_cls, correct_cls = tables[idx]

        self.log("Uncompressed image" if k == 1 else "JPEG compressed image")
        if predicted_cls == correct_cls:
            pred_text = "Prediction is correct."
        else:
            pred_text = "Prediction is wrong."
        self.log(
            "Predicted: "
            + str(predicted_cls)
            + "\t Actual: "
            + str(correct_cls)
            + "\t "
            + pred_text
        )
        self.log("------------------------------")

        # NOTE(review): both images are written to the same path, so the clean
        # image overwrites the compressed one on disk.
        dict_original_img[k] = save_preprocessed_img(
            os.path.join(specific_folder, "original_img.png"), images_test, idx
        )

        ##### MOST ACTIVATED (NEAREST) PROTOTYPES OF THIS IMAGE
        makedir(os.path.join(specific_folder, "most_activated_prototypes"))
        dict_prototype_activations[k] = prototype_activations
        dict_prototype_activation_patterns[k] = prototype_activation_patterns
        # Ascending sort: the most activated prototype is at position -1.
        _, dict_sorted_indices_act[k] = torch.sort(prototype_activations[idx])

    # In top_n mode these are filled while examining the first image, then reused.
    inspected_index = None
    inspected_min = None
    inspected_max = None
    logs = {0: [], 1: []}
    display_images = []
    for i in range(1, max_prototypes + 1):
        if top_n is not None and top_n != i:
            continue
        for k in [0, 1]:
            sorted_indices_act = dict_sorted_indices_act[k]
            if top_n is None:
                # Each image gets its own i-th prototype and its own normalisation.
                inspected_min, inspected_max = None, None
                inspected_index = sorted_indices_act[-i].item()
                rank = i
            elif inspected_index is None:
                inspected_index = sorted_indices_act[-i].item()
                rank = i
            else:
                # Same prototype on the other image. The sort is ascending, so
                # its rank there is (number of prototypes - position).
                position = int(
                    np.where(sorted_indices_act.cpu().numpy() == inspected_index)[0][0]
                )
                rank = sorted_indices_act.numel() - position

            prototype_box = self.prototype_info[inspected_index]
            p_img = save_prototype(
                self.load_img_dir,
                os.path.join(
                    shared_prototype_folder, "top-%d_activated_prototype.png" % rank
                ),
                self.start_epoch_number,
                inspected_index,
            )
            p_oimg_with_bbox = save_prototype_original_img_with_bbox(
                self.load_img_dir,
                fname=os.path.join(
                    shared_prototype_folder,
                    "top-%d_activated_prototype_in_original_pimg.png" % rank,
                ),
                epoch=self.start_epoch_number,
                index=inspected_index,
                bbox_height_start=prototype_box[1],
                bbox_height_end=prototype_box[2],
                bbox_width_start=prototype_box[3],
                bbox_width_end=prototype_box[4],
                color=(0, 255, 255),
            )
            save_prototype_self_activation(
                self.load_img_dir,
                os.path.join(
                    shared_prototype_folder,
                    "top-%d_activated_prototype_self_act.png" % rank,
                ),
                self.start_epoch_number,
                inspected_index,
            )

            activation_pattern = (
                dict_prototype_activation_patterns[k][idx][inspected_index]
                .detach()
                .cpu()
                .numpy()
            )
            upsampled_activation_pattern = cv2.resize(
                activation_pattern,
                dsize=(self.img_size, self.img_size),
                interpolation=cv2.INTER_CUBIC,
            )

            # show the most highly activated patch of the image by this prototype
            high_act_patch_indices = find_high_activation_crop(
                upsampled_activation_pattern
            )
            high_act_patch = dict_original_img[k][
                high_act_patch_indices[0] : high_act_patch_indices[1],
                high_act_patch_indices[2] : high_act_patch_indices[3],
                :,
            ]
            plt.imsave(
                os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "most_highly_activated_patch_by_top-%d_prototype.png" % rank,
                ),
                high_act_patch,
            )
            p_img_with_bbox = imsave_with_bbox(
                fname=os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png"
                    % rank,
                ),
                img_rgb=dict_original_img[k],
                bbox_height_start=high_act_patch_indices[0],
                bbox_height_end=high_act_patch_indices[1],
                bbox_width_start=high_act_patch_indices[2],
                bbox_width_end=high_act_patch_indices[3],
                color=(0, 255, 255),
            )

            # show the image overlayed with prototype activation map, normalised
            # with the first run's values when they are shared
            if inspected_min is None:
                inspected_min = np.amin(upsampled_activation_pattern)
            rescaled_activation_pattern = upsampled_activation_pattern - inspected_min
            if inspected_max is None:
                inspected_max = np.amax(rescaled_activation_pattern)
            if inspected_max > 0:  # a constant map would otherwise divide by zero
                rescaled_activation_pattern = rescaled_activation_pattern / inspected_max
            # A shared normalisation can leave [0, 1]; clip so uint8 does not wrap.
            rescaled_activation_pattern = np.clip(rescaled_activation_pattern, 0.0, 1.0)
            heatmap = cv2.applyColorMap(
                np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
            )
            heatmap = np.float32(heatmap) / 255
            heatmap = heatmap[..., ::-1]  # OpenCV colour maps are BGR; matplotlib wants RGB
            overlayed_img = 0.5 * dict_original_img[k] + 0.3 * heatmap
            plt.imsave(
                os.path.join(
                    specific_folder,
                    "most_activated_prototypes",
                    "prototype_activation_map_by_top-%d_prototype.png" % rank,
                ),
                overlayed_img,
            )

            display_images += [
                p_oimg_with_bbox,
                p_img,
                p_img_with_bbox,
                overlayed_img,
            ]
            logs[k].append(
                {
                    "Prototype Id": inspected_index,
                    "Rank": rank,
                    "Prototype Class": self.prototype_img_identity[inspected_index],
                    "Similarity Score": dict_prototype_activations[k][idx][
                        inspected_index
                    ].item(),
                }
            )

    # Visualize Logs and Images
    if show_images:
        logs_df = pd.concat(
            [pd.DataFrame(logs[0]), pd.DataFrame(logs[1])],
            axis=1,
            keys=["Compressed Image", "Clean Image"],
        )
        display(logs_df)
        display_titles = [
            "Training Image from which \nprototype is taken",
            "Prototype",
            "Test Image + BBox",
            "Test Image + Activation Map",
        ]
        display_titles += [x + "\n(uncompressed)" for x in display_titles]
        visualize_image_grid(images=display_images, titles=display_titles, ncols=8)
        plt.tight_layout()
        plt.show()

    # Display Histograms: for each image, its highest-scoring prototypes, with
    # both images' similarity scores side by side.
    n_top = min(75, dict_prototype_activations[0][idx].numel())
    f = plt.figure(figsize=(12, 3.5))
    for k in range(2):
        plt.subplot(1, 2, k + 1)
        _, pids = torch.topk(dict_prototype_activations[k][idx], n_top)
        pids = pids.cpu().numpy()
        positions = np.arange(0, len(pids))
        plt.bar(
            positions - 0.2,
            dict_prototype_activations[1][idx][pids].detach().cpu().numpy(),
            width=0.4,
            label="Clean",
        )
        plt.bar(
            positions + 0.2,
            dict_prototype_activations[0][idx][pids].detach().cpu().numpy(),
            width=0.4,
            label="Compressed",
        )
        if k == 0:
            plt.title("Top {} for compressed image".format(n_top))
        else:
            plt.title("Top {} for clean image".format(n_top))
        plt.ylabel("Similarities")
        plt.xlabel("Prototype")
        plt.legend()
    plt.tight_layout()
    # Save before show(): some backends release the figure once it is shown.
    if save_histogram:
        f.savefig("{}.pdf".format(save_name), bbox_inches="tight")
    plt.show()

    return display_images
def jpeg_visualization(
    self,
    pil_img,
    img_name,
    test_image_label,
    preprocess_clean,
    preprocess_compressed,
    show_images=False,
    idx=0,
    top_n=1,
):
    """
    Perform local analysis comparing compressed and uncompressed images.

    The prototype ranked ``top_n`` on the compressed image is selected and then
    inspected on both images. Both activation maps are normalised with the
    compressed image's min/max so they can be compared directly.

    Args:
        pil_img: is the PIL test image which is to be inspected
        img_name: name of the sub-folder of ``self.save_analysis_path`` to save into
        test_image_label: label of the test image
        preprocess_clean: first torch vision transform pipeline (here, without compression)
        preprocess_compressed: second torch vision transform pipeline (here, with compression)
        show_images (default = False): Boolean value to show images
        idx (default = 0): Index Value, when retrieving results of the PPNet
        top_n (default = 1): Visualize the n_th most activating prototype

    Returns:
        A list containing 8 images in the following order:
            1. Training image the inspected prototype is taken from, with bounding box
            2. Inspected prototype
            3. Compressed image passed through, with activated patch in bounding box
            4. Corresponding activation map of the compressed image
            5. Training image the inspected prototype is taken from, with bounding box
            6. Inspected prototype (the same prototype as 2)
            7. Uncompressed image passed through, with activated patch in bounding box
            8. Corresponding activation map of the uncompressed image
    """
    # How to save the images
    specific_folder = self.save_analysis_path + "/" + img_name
    makedir(specific_folder)
    # save_prototype* below write into the shared analysis folder; make sure it exists.
    shared_prototype_folder = os.path.join(
        self.save_analysis_path, "most_activated_prototypes"
    )
    makedir(shared_prototype_folder)

    # Preprocess the clean image first, then the compressed one.
    img_tensor_clean = preprocess_clean(pil_img)
    img_tensor_compressed = preprocess_compressed(pil_img)
    # Compressed first, so the prototype and normalisation are chosen on it.
    img_variables = [img_tensor_compressed.unsqueeze(0), img_tensor_clean.unsqueeze(0)]

    # Initialize as none, and will be filled after examining the first image
    inspected_index = None
    inspected_min = None
    inspected_max = None
    rank = top_n
    display_images = []
    for img_variable in img_variables:
        # Forward the image variable through the network (inference only)
        images_test = img_variable.cuda()
        labels_test = torch.tensor([test_image_label])
        with torch.no_grad():
            logits, min_distances = self.ppnet_multi(images_test)
            _, distances = self.ppnet.push_forward(images_test)
            prototype_activations = self.ppnet.distance_2_similarity(min_distances)
            prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
        if self.ppnet.prototype_activation_function == "linear":
            prototype_activations = prototype_activations + self.max_dist
            prototype_activation_patterns = prototype_activation_patterns + self.max_dist

        predictions = torch.argmax(logits, dim=1)
        tables = [
            (predictions[i].item(), labels_test[i].item())
            for i in range(logits.size(0))
        ]
        predicted_cls, correct_cls = tables[idx]
        if predicted_cls == correct_cls:
            pred_text = "Prediction is correct."
        else:
            pred_text = "Prediction is wrong."
        self.log(
            "Predicted: "
            + str(predicted_cls)
            + "\t Actual: "
            + str(correct_cls)
            + "\t "
            + pred_text
        )

        # NOTE(review): both images are written to the same path, so the clean
        # image overwrites the compressed one on disk.
        original_img = save_preprocessed_img(
            os.path.join(specific_folder, "original_img.png"), images_test, idx
        )

        ##### MOST ACTIVATED (NEAREST) PROTOTYPE OF THIS IMAGE
        makedir(os.path.join(specific_folder, "most_activated_prototypes"))
        # Ascending sort: the most activated prototype is at position -1.
        _, sorted_indices_act = torch.sort(prototype_activations[idx])
        if inspected_index is None:
            inspected_index = sorted_indices_act[-rank].item()
        self.log("protoype index: " + str(inspected_index))

        prototype_box = self.prototype_info[inspected_index]
        p_img = save_prototype(
            self.load_img_dir,
            os.path.join(shared_prototype_folder, "top-%d_activated_prototype.png" % rank),
            self.start_epoch_number,
            inspected_index,
        )
        p_oimg_with_bbox = save_prototype_original_img_with_bbox(
            self.load_img_dir,
            fname=os.path.join(
                shared_prototype_folder,
                "top-%d_activated_prototype_in_original_pimg.png" % rank,
            ),
            epoch=self.start_epoch_number,
            index=inspected_index,
            bbox_height_start=prototype_box[1],
            bbox_height_end=prototype_box[2],
            bbox_width_start=prototype_box[3],
            bbox_width_end=prototype_box[4],
            color=(0, 255, 255),
        )
        save_prototype_self_activation(
            self.load_img_dir,
            os.path.join(
                shared_prototype_folder, "top-%d_activated_prototype_self_act.png" % rank
            ),
            self.start_epoch_number,
            inspected_index,
        )

        self.log(
            "prototype class identity: {0}".format(
                self.prototype_img_identity[inspected_index]
            )
        )
        # Compare the inspected prototype's connection with its own identity
        # (not with the identity of whichever prototype ranks top_n here).
        if (
            self.prototype_max_connection[inspected_index]
            != self.prototype_img_identity[inspected_index]
        ):
            self.log(
                "prototype connection identity: {0}".format(
                    self.prototype_max_connection[inspected_index]
                )
            )
        self.log(
            "activation value (similarity score): {0}".format(
                prototype_activations[idx][inspected_index]
            )
        )

        activation_pattern = (
            prototype_activation_patterns[idx][inspected_index].detach().cpu().numpy()
        )
        upsampled_activation_pattern = cv2.resize(
            activation_pattern,
            dsize=(self.img_size, self.img_size),
            interpolation=cv2.INTER_CUBIC,
        )

        # show the most highly activated patch of the image by this prototype
        high_act_patch_indices = find_high_activation_crop(upsampled_activation_pattern)
        high_act_patch = original_img[
            high_act_patch_indices[0] : high_act_patch_indices[1],
            high_act_patch_indices[2] : high_act_patch_indices[3],
            :,
        ]
        plt.imsave(
            os.path.join(
                specific_folder,
                "most_activated_prototypes",
                "most_highly_activated_patch_by_top-%d_prototype.png" % rank,
            ),
            high_act_patch,
        )
        p_img_with_bbox = imsave_with_bbox(
            fname=os.path.join(
                specific_folder,
                "most_activated_prototypes",
                "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png"
                % rank,
            ),
            img_rgb=original_img,
            bbox_height_start=high_act_patch_indices[0],
            bbox_height_end=high_act_patch_indices[1],
            bbox_width_start=high_act_patch_indices[2],
            bbox_width_end=high_act_patch_indices[3],
            color=(0, 255, 255),
        )

        # show the image overlayed with prototype activation map and use
        # normalization values of first run
        if inspected_min is None:
            inspected_min = np.amin(upsampled_activation_pattern)
        rescaled_activation_pattern = upsampled_activation_pattern - inspected_min
        if inspected_max is None:
            inspected_max = np.amax(rescaled_activation_pattern)
        if inspected_max > 0:  # a constant map would otherwise divide by zero
            rescaled_activation_pattern = rescaled_activation_pattern / inspected_max
        # The second image reuses the first image's range; clip so uint8 does not wrap.
        rescaled_activation_pattern = np.clip(rescaled_activation_pattern, 0.0, 1.0)
        heatmap = cv2.applyColorMap(
            np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
        )
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]  # OpenCV colour maps are BGR; matplotlib wants RGB
        overlayed_img = 0.5 * original_img + 0.3 * heatmap
        plt.imsave(
            os.path.join(
                specific_folder,
                "most_activated_prototypes",
                "prototype_activation_map_by_top-%d_prototype.png" % rank,
            ),
            overlayed_img,
        )
        display_images += [p_oimg_with_bbox, p_img, p_img_with_bbox, overlayed_img]
        self.log("------------------------------")

    # Visualize Images
    if show_images:
        display_titles = [
            "Training Image from which \nprototype is taken",
            "Prototype",
            "Test Image + BBox",
            "Test Image + Activation Map",
        ]
        display_titles += [x + "\n(uncompressed)" for x in display_titles]
        visualize_image_grid(images=display_images, titles=display_titles, ncols=8)
        plt.tight_layout()
        plt.show()

    return display_images
def local_analysis(
    self,
    img_variable,
    test_image_label,
    max_prototypes=10,
    idx=0,
    pid=None,
    verbose=False,
    show_images=True,
    normalize_sim_map=None,
):
    """
    Perform local analysis of one test image.

    For the ``max_prototypes`` prototypes most similar to image ``idx`` of the
    batch, save under ``<save_analysis_path>/most_activated_prototypes`` the
    prototype images, the most highly activated patch of the test image (alone
    and as a bounding box on the image) and the test image overlaid with the
    prototype's similarity map. Optionally display them.

    Arguments:
        img_variable (torch.Tensor): input image batch to test on.
        test_image_label (int): true label of test image ``idx``.
        max_prototypes (int): number of most similar prototypes to display
            (default: 10; clamped to the number of prototypes).
        idx (int): image id in the batch (default: 0).
        pid (int): prototype id; if given, only this prototype is visualized,
            and only if it is among the top ``max_prototypes``.
        verbose (bool): whether to print (default: False).
        show_images (bool): show images (default: True).
        normalize_sim_map (Tuple(2)): offset and scale resp. used to normalize
            the activation map as ``(map - offset) / scale``. If None, each map
            is min-max normalized on its own.

    Returns:
        tuple: ``(sorted_indices_act, prototype_activation_patterns,
        sim_map_stats)``.
            - ``sorted_indices_act``: prototype indices sorted by ascending
              similarity to image ``idx``.
            - ``prototype_activation_patterns``: the similarity maps derived
              from ``push_forward``.
            - ``sim_map_stats``: ``[min, max - min]`` of the upsampled
              similarity map of the last visualized prototype (suitable to pass
              back as ``normalize_sim_map``), or None if no prototype was
              visualized.
    """
    img_size = self.ppnet_multi.module.img_size

    images_test = img_variable.cuda()
    labels_test = torch.tensor([test_image_label])

    logits, min_distances = self.ppnet_multi(images_test)
    _, distances = self.ppnet.push_forward(images_test)
    prototype_activations = self.ppnet.distance_2_similarity(min_distances)
    prototype_activation_patterns = self.ppnet.distance_2_similarity(distances)
    if self.ppnet.prototype_activation_function == "linear":
        # NOTE(review): offsets "linear" similarities by self.max_dist (product
        # of prototype_shape[1:4], set in __init__), presumably to make them
        # non-negative -- confirm against distance_2_similarity.
        prototype_activations = prototype_activations + self.max_dist
        prototype_activation_patterns = prototype_activation_patterns + self.max_dist

    # Only image `idx` is analysed and a single label is supplied for it, so
    # the label is read at position 0 (indexing it by every batch position
    # raised IndexError for batches larger than one).
    predicted_cls = torch.argmax(logits, dim=1)[idx].item()
    correct_cls = labels_test[0].item()
    self.log("Predicted: " + str(predicted_cls))
    self.log("Actual: " + str(correct_cls))
    if predicted_cls == correct_cls:
        self.log("Prediction is correct.")
    else:
        self.log("Prediction is wrong.")

    original_img = save_preprocessed_img(
        os.path.join(self.save_analysis_path, "original_img.png"), images_test, idx
    )

    ##### MOST ACTIVATED (NEAREST) PROTOTYPES OF THIS IMAGE
    out_dir = os.path.join(self.save_analysis_path, "most_activated_prototypes")
    makedir(out_dir)
    array_act, sorted_indices_act = torch.sort(prototype_activations[idx])
    # Never index past the available prototypes (sorted_indices_act[-i]).
    num_shown = min(max_prototypes, sorted_indices_act.numel())
    self.log("Most activated {} prototypes of this image:".format(num_shown))
    self.log("--------------------------------------------------------------")

    sim_map_stats = None
    for i in range(1, num_shown + 1):
        # torch.sort is ascending, so the i-th most similar prototype is at -i.
        proto_idx = sorted_indices_act[-i].item()
        if pid is not None and pid != proto_idx:
            continue
        self.log("top {0} activated prototype for this image:".format(i))
        p_img = save_prototype(
            self.load_img_dir,
            os.path.join(out_dir, "top-%d_activated_prototype.png" % i),
            self.start_epoch_number,
            proto_idx,
        )
        proto_info = self.prototype_info[proto_idx]
        p_oimg_with_bbox = save_prototype_original_img_with_bbox(
            self.load_img_dir,
            fname=os.path.join(
                out_dir, "top-%d_activated_prototype_in_original_pimg.png" % i
            ),
            epoch=self.start_epoch_number,
            index=proto_idx,
            bbox_height_start=proto_info[1],
            bbox_height_end=proto_info[2],
            bbox_width_start=proto_info[3],
            bbox_width_end=proto_info[4],
            color=(0, 255, 255),
        )
        save_prototype_self_activation(
            self.load_img_dir,
            os.path.join(out_dir, "top-%d_activated_prototype_self_act.png" % i),
            self.start_epoch_number,
            proto_idx,
        )
        self.log("prototype index: {0}".format(proto_idx))
        self.log(
            "prototype class identity: {0}".format(
                self.prototype_img_identity[proto_idx]
            )
        )
        if self.prototype_max_connection[proto_idx] != self.prototype_img_identity[proto_idx]:
            self.log(
                "prototype connection identity: {0}".format(
                    self.prototype_max_connection[proto_idx]
                )
            )
        self.log("activation value (similarity score): {0}".format(array_act[-i]))
        self.log(
            "last layer connection with predicted class: {0}".format(
                self.ppnet.last_layer.weight[predicted_cls][proto_idx]
            )
        )

        activation_pattern = (
            prototype_activation_patterns[idx][proto_idx].detach().cpu().numpy()
        )
        upsampled_activation_pattern = cv2.resize(
            activation_pattern,
            dsize=(img_size, img_size),
            interpolation=cv2.INTER_CUBIC,
        )

        # show the most highly activated patch of the image by this prototype
        high_act_patch_indices = find_high_activation_crop(upsampled_activation_pattern)
        high_act_patch = original_img[
            high_act_patch_indices[0] : high_act_patch_indices[1],
            high_act_patch_indices[2] : high_act_patch_indices[3],
            :,
        ]
        if verbose:
            self.log("most highly activated patch of the chosen image by this prototype:")
        plt.imsave(
            os.path.join(out_dir, "most_highly_activated_patch_by_top-%d_prototype.png" % i),
            high_act_patch,
        )
        if verbose:
            # Was a bare `log(...)`, which is not defined in this scope.
            self.log(
                "most highly activated patch by this prototype shown in the original image:"
            )
        p_img_with_bbox = imsave_with_bbox(
            fname=os.path.join(
                out_dir,
                "most_highly_activated_patch_in_original_img_by_top-%d_prototype.png" % i,
            ),
            img_rgb=original_img,
            bbox_height_start=high_act_patch_indices[0],
            bbox_height_end=high_act_patch_indices[1],
            bbox_width_start=high_act_patch_indices[2],
            bbox_width_end=high_act_patch_indices[3],
            color=(0, 255, 255),
        )

        # Min and range of this prototype's own upsampled map: used for
        # min-max normalization and returned so a later call can reuse them as
        # `normalize_sim_map` (previously the post-division max, always 1.0,
        # was returned).
        sim_map_min = np.amin(upsampled_activation_pattern)
        sim_map_range = np.amax(upsampled_activation_pattern - sim_map_min)
        if normalize_sim_map is not None:
            rescaled_activation_pattern = (
                upsampled_activation_pattern - normalize_sim_map[0]
            ) / normalize_sim_map[1]
        else:
            rescaled_activation_pattern = (
                upsampled_activation_pattern - sim_map_min
            ) / sim_map_range
        sim_map_stats = [sim_map_min, sim_map_range]

        # show the image overlayed with prototype activation map
        heatmap = cv2.applyColorMap(
            np.uint8(255 * rescaled_activation_pattern), cv2.COLORMAP_JET
        )
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]  # OpenCV colormaps are BGR; convert to RGB
        overlayed_img = 0.5 * original_img + 0.3 * heatmap
        if verbose:
            self.log("prototype activation map of the chosen image:")
        plt.imsave(
            os.path.join(out_dir, "prototype_activation_map_by_top-%d_prototype.png" % i),
            overlayed_img,
        )

        if show_images:
            visualize_image_grid(
                images=[p_oimg_with_bbox, p_img, p_img_with_bbox, overlayed_img],
                titles=[
                    "Training Image from which \nprototype is taken",
                    "Prototype",
                    "Test Image + BBox",
                    "Test Image + Activation Map",
                ],
                ncols=4,
            )
            plt.tight_layout()
            plt.show()

        self.log("--------------------------------------------------------------")

    return (
        sorted_indices_act,
        prototype_activation_patterns,
        sim_map_stats,
    )
def __init__(
    self,
    load_model_dir,
    load_model_name,
    test_image_name,
    image_save_directory=None,
    attack=None,
):
    """
    Load a saved ProtoPNet model and prepare local analysis of one test image.

    Sets up:
        - the output directory and a log file (``local_analysis.log``);
        - the model, both plain and wrapped in ``DataParallel``;
        - the prototype bounding-box info of the model's epoch.

    It then logs whether every prototype connects most strongly to its own
    class.

    Arguments:
        load_model_dir (str): path to saved model directory, expected to end in
            ``<base_architecture>/<experiment_run>`` (trailing "/" optional).
        load_model_name (str): saved model file name; its first run of digits
            is taken as the epoch number.
        test_image_name (str): test image file name.
        image_save_directory (str): directory in which the analysis folder
            (named after ``test_image_name``) is created. If None, a per-user
            scratch directory is used.
        attack (int): type of attack (1 or 3 or None); only used in the default
            save path.

    Raises:
        ValueError: if ``load_model_name`` contains no epoch number.
    """
    # Parse the epoch first so a bad model name fails before any directories
    # or log files are created (re.search(...).group used to raise an opaque
    # AttributeError here).
    epoch_match = re.search(r"\d+", load_model_name)
    if epoch_match is None:
        raise ValueError(
            "Cannot infer the epoch number from model name {!r}".format(load_model_name)
        )
    epoch_number_str = epoch_match.group(0)
    self.start_epoch_number = int(epoch_number_str)

    # Strip trailing "/" so the last two components are always
    # <base_architecture>/<experiment_run>, with or without a trailing slash.
    dir_parts = load_model_dir.rstrip("/").split("/")
    model_base_architecture = dir_parts[-2]
    experiment_run = dir_parts[-1]
    self.model_base_architecture = model_base_architecture

    if image_save_directory is not None:
        self.save_analysis_path = os.path.join(image_save_directory, test_image_name)
    else:
        # splitext drops the extension whatever its length (".jpg", ".jpeg", ...).
        self.save_analysis_path = (
            "/cluster/scratch/{}/PPNet/local_analysis_attack{}/".format(
                username, attack
            )
            + model_base_architecture
            + "-"
            + experiment_run
            + "/"
            + os.path.splitext(test_image_name)[0]
        )
    makedir(self.save_analysis_path)
    self.log, self.logclose = create_logger(
        log_filename=os.path.join(self.save_analysis_path, "local_analysis.log")
    )

    load_model_path = os.path.join(load_model_dir, load_model_name)
    self.log("load model from " + load_model_path)
    self.log("model base architecture: " + model_base_architecture)
    self.log("experiment run: " + experiment_run)

    # NOTE: torch.load unpickles a whole model object; only load trusted files.
    self.ppnet = torch.load(load_model_path)
    self.ppnet = self.ppnet.cuda()
    self.ppnet_multi = torch.nn.DataParallel(self.ppnet)
    self.img_size = self.ppnet_multi.module.img_size
    prototype_shape = self.ppnet.prototype_shape
    self.max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    self.class_specific = True
    self.normalize = transforms.Normalize(mean=mean, std=std)

    # confirm prototype class identity
    self.load_img_dir = os.path.join(load_model_dir, "img")
    self.prototype_info = np.load(
        os.path.join(
            self.load_img_dir,
            "epoch-" + epoch_number_str,
            "bb" + epoch_number_str + ".npy",
        )
    )
    # Last column of the bounding-box file holds each prototype's class.
    self.prototype_img_identity = self.prototype_info[:, -1]
    self.log(
        "Prototypes are chosen from "
        + str(len(set(self.prototype_img_identity)))
        + " number of classes."
    )
    self.log("Their class identities are: " + str(self.prototype_img_identity))

    # confirm prototype connects most strongly to its own class
    prototype_max_connection = torch.argmax(self.ppnet.last_layer.weight, dim=0)
    self.prototype_max_connection = prototype_max_connection.cpu().numpy()
    if (
        np.sum(self.prototype_max_connection == self.prototype_img_identity)
        == self.ppnet.num_prototypes
    ):
        self.log("All prototypes connect most strongly to their respective classes.")
    else:
        self.log(
            "WARNING: Not all prototypes connect most strongly to their respective classes."
        )