def get_similarity_maps(tree: ProtoTree, project_info: dict, log: Log = None):
    """Compute one similarity map per projected prototype.

    For every key ``j`` of ``project_info`` the stored ``'nearest_input'`` is
    passed through ``tree.forward_partial`` and the distances of prototype
    ``j`` for the first batch item are turned into similarities with
    ``exp(-distance)``.

    Args:
        tree: Model exposing ``forward_partial``; the second element of its
            return value is indexed as ``[0, j, :, :]``.
        project_info: Mapping from prototype index to an info dict that holds
            a ``'nearest_input'`` entry. It is modified in place: the
            ``'nearest_input'`` entry is removed from every info dict once
            its similarity map has been computed.
        log: Optional logger exposing ``log_message``. When ``None`` (the
            default) nothing is logged.

    Returns:
        A tuple ``(sim_maps, project_info)`` where ``sim_maps`` maps each
        prototype index to a NumPy array and ``project_info`` is the same
        (now stripped) dict object that was passed in.
    """
    # ``log`` defaults to None, so it must not be dereferenced blindly.
    if log is not None:
        log.log_message("\nCalculating similarity maps (after projection)...")

    sim_maps = dict()
    with torch.no_grad():
        for j, info in project_info.items():
            _, distances_batch, _ = tree.forward_partial(info['nearest_input'])
            sim_maps[j] = torch.exp(-distances_batch[0, j, :, :]).cpu().numpy()
            # Drop the stored input tensor only after its map was computed;
            # this edits the nested dict, not the dict being iterated.
            del info['nearest_input']
    return sim_maps, project_info
def _rescale_to_unit_range(values):
    """Shift and scale ``values`` so that its minimum is 0 and its maximum is 1.

    A constant array has no range to stretch; zeros are returned for it
    instead of the NaNs a 0/0 division would produce.
    """
    shifted = values - np.amin(values)
    peak = np.amax(shifted)
    if peak > 0:
        return shifted / peak
    return np.zeros_like(shifted)


def _jet_heatmap_rgb(values):
    """Render ``values`` as a float JET heatmap with channels in RGB order."""
    colored = cv2.applyColorMap(np.uint8(255 * _rescale_to_unit_range(values)),
                                cv2.COLORMAP_JET)
    # OpenCV color maps come back as BGR; reverse the channel axis for RGB.
    return (np.float32(colored) / 255)[..., ::-1]


def upsample_local(tree: ProtoTree, sample: torch.Tensor, sample_dir: str,
                   folder_name: str, img_name: str, decision_path: list,
                   args: argparse.Namespace):
    """Save similarity visualizations for every decision node on a path.

    The sample is passed through ``tree.forward_partial`` once. For each node
    in ``decision_path`` except the last one, the following files are written
    to ``<args.log_dir>/<folder_name>/<img_name>/<args.dir_for_saving_images>``
    (each prefixed with the node's ``index``):

    * ``_heatmap_latent_similaritymap.png`` - heatmap of the latent map
    * ``_heatmap_original_image.png`` - upsampled heatmap over the image
    * ``_masked_upsampled_heatmap.png`` - upsampled mask of the maximum
    * ``_nearest_patch_of_image.png`` - image crop of the high-activation area
    * ``_bounding_box_nearest_patch_of_image.png`` - image with that area boxed

    Args:
        tree: Model exposing ``forward_partial`` and ``_out_map``.
        sample: Input tensor handed to ``tree.forward_partial``.
        sample_dir: Path of the image file to visualize on (despite the
            name it is opened as a file).
        folder_name: Sub-folder of ``args.log_dir``.
        img_name: Sub-folder of ``folder_name``.
        decision_path: Nodes visited for this sample; the final element is
            skipped.
        args: Namespace providing ``log_dir``, ``dir_for_saving_images`` and
            ``upsample_threshold``.
    """
    out_dir = os.path.join(args.log_dir, folder_name, img_name,
                           args.dir_for_saving_images)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        _, distances_batch, _ = tree.forward_partial(sample)
        sim_map = torch.exp(-distances_batch[0, :, :, :]).cpu().numpy()

    decision_nodes = decision_path[:-1]
    if not decision_nodes:
        # Nothing to visualize, so the image file is never touched.
        return

    # The image is the same for every node: decode it once and close the file.
    with Image.open(sample_dir) as img:
        x_np = np.float32(np.asarray(img)) / 255
    if x_np.ndim == 2:  # convert grayscale to RGB
        x_np = np.stack((x_np, ) * 3, axis=-1)
    img_size = x_np.shape[:2]
    upsampled_size = (img_size[1], img_size[0])  # cv2 expects (width, height)

    for node in decision_nodes:
        decision_node_idx = node.index
        similarity_map = sim_map[tree._out_map[node]]

        plt.imsave(fname=os.path.join(
            out_dir, f'{decision_node_idx}_heatmap_latent_similaritymap.png'),
                   arr=_jet_heatmap_rgb(similarity_map),
                   vmin=0.0,
                   vmax=1.0)

        upsampled_act_pattern = cv2.resize(similarity_map,
                                           dsize=upsampled_size,
                                           interpolation=cv2.INTER_CUBIC)
        overlayed_original_img = (
            0.5 * x_np + 0.2 * _jet_heatmap_rgb(upsampled_act_pattern))
        plt.imsave(fname=os.path.join(
            out_dir, f'{decision_node_idx}_heatmap_original_image.png'),
                   arr=overlayed_original_img,
                   vmin=0.0,
                   vmax=1.0)

        # save the highly activated patch
        # mask similarity map such that only the nearest patch z* is visualized
        masked_similarity_map = np.ones(similarity_map.shape)
        masked_similarity_map[similarity_map < np.max(similarity_map)] = 0
        upsampled_prototype_pattern = cv2.resize(
            masked_similarity_map,
            dsize=upsampled_size,
            interpolation=cv2.INTER_CUBIC)
        plt.imsave(fname=os.path.join(
            out_dir, f'{decision_node_idx}_masked_upsampled_heatmap.png'),
                   arr=upsampled_prototype_pattern,
                   vmin=0.0,
                   vmax=1.0)

        high_act_patch_indices = find_high_activation_crop(
            upsampled_prototype_pattern, args.upsample_threshold)
        high_act_patch = x_np[
            high_act_patch_indices[0]:high_act_patch_indices[1],
            high_act_patch_indices[2]:high_act_patch_indices[3], :]
        plt.imsave(fname=os.path.join(
            out_dir, f'{decision_node_idx}_nearest_patch_of_image.png'),
                   arr=high_act_patch,
                   vmin=0.0,
                   vmax=1.0)

        # save the original image with bounding box showing high activation patch
        # imsave_with_bbox is defined elsewhere; it gets a copy so the image
        # shared across nodes cannot be altered by in-place drawing.
        imsave_with_bbox(fname=os.path.join(
            out_dir,
            f'{decision_node_idx}_bounding_box_nearest_patch_of_image.png'),
                         img_rgb=x_np.copy(),
                         bbox_height_start=high_act_patch_indices[0],
                         bbox_height_end=high_act_patch_indices[1],
                         bbox_width_start=high_act_patch_indices[2],
                         bbox_width_end=high_act_patch_indices[3],
                         color=(0, 255, 255))
def project_with_class_constraints(
        tree: ProtoTree,
        project_loader: DataLoader,
        device,
        args: argparse.Namespace,
        log: Log,
        log_prefix: str = 'log_projection_with_constraints',  # TODO
        progress_prefix: str = 'Projection') -> tuple:
    """Replace every prototype by its nearest class-compatible latent patch.

    The loader is iterated once. For each decision node only images whose
    label is the arg-max class of at least one leaf below that node are
    considered. Among those, the latent patch with the smallest distance to
    the node's prototype is kept and finally written into
    ``tree.prototype_layer.prototype_vectors``.

    Args:
        tree: Model to project; it is switched to eval mode and its prototype
            vectors are overwritten in place.
        project_loader: Loader yielding ``(xs, ys)`` batches.
        device: Device the batches are moved to.
        args: Not read by this function.
        log: Logger exposing ``log_message``.
        log_prefix: Not read by this function.
        progress_prefix: Description shown on the progress bar.

    Returns:
        A tuple ``(global_min_info, tree)``. ``global_min_info`` maps each
        prototype index to a dict with the keys ``'input_image_ix'``,
        ``'patch_ix'``, ``'W'``, ``'H'``, ``'W1'``, ``'H1'``, ``'distance'``,
        ``'nearest_input'`` and ``'node_ix'``.

    Raises:
        RuntimeError: If at least one prototype never saw an image with an
            allowed label, so there is no patch to project it onto. The
            prototype vectors are left untouched in that case.
    """
    log.log_message(
        "\nProjecting prototypes to nearest training patch (with class restrictions)..."
    )
    # Set the model to evaluation mode
    tree.eval()
    torch.cuda.empty_cache()

    # The goal is to find the latent patch that minimizes the L2 distance to each prototype
    # To do this we iterate through the train dataset and store for each prototype the closest latent patch seen so far
    # Also store info about the image that was used for projection
    num_prototypes = tree.num_prototypes
    global_min_proto_dist = {j: np.inf for j in range(num_prototypes)}
    global_min_patches = {j: None for j in range(num_prototypes)}
    global_min_info = {j: None for j in range(num_prototypes)}

    # Get the shape of the prototypes
    W1, H1, D = tree.prototype_shape

    # Build a progress bar for showing the status
    projection_iter = tqdm(enumerate(project_loader),
                           total=len(project_loader),
                           desc=progress_prefix,
                           ncols=0)

    with torch.no_grad():
        # For each internal node, collect the leaf labels in the subtree with this node as root.
        # Only images from these classes can be used for projection.
        leaf_labels_subtree = {
            branch.index: {
                torch.argmax(leaf.distribution()).item()
                for leaf in branch.leaves
            }
            for branch in tree._out_map
        }

        # Running offset of the current batch within the loader. It equals
        # i * batch_size for equally sized batches, without spending an extra
        # loader pass just to read the batch size.
        samples_seen = 0
        for i, (xs, ys) in projection_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Get the features and distances
            # - features_batch: features tensor (shared by all prototypes)
            #   shape: (batch_size, D, W, H)
            # - distances_batch: distances tensor (for all prototypes)
            #   shape: (batch_size, num_prototypes, W, H)
            # - out_map: a dict mapping decision nodes to distances (indices)
            features_batch, distances_batch, out_map = tree.forward_partial(xs)

            # Get the features dimensions
            _, D, W, H = features_batch.shape

            # Get a tensor containing the individual latent patches
            # Create the patches by unfolding over both the W and H dimensions
            # TODO -- support for strides in the prototype layer? (corresponds to step size here)
            patches_batch = features_batch.unfold(2, W1, 1).unfold(
                3, H1, 1)  # Shape: (batch_size, D, W, H, W1, H1)

            # Read the labels once per batch rather than once per node and item.
            labels = ys.reshape(-1).tolist()

            # Iterate over all decision nodes/prototypes
            for node, j in out_map.items():
                leaf_labels = leaf_labels_subtree[node.index]
                for batch_i, label in enumerate(labels):
                    # Check if label of this image is in one of the leaves of the subtree
                    if label not in leaf_labels:
                        continue
                    # distances of prototype j to the latent patches of this image
                    distances = distances_batch[batch_i, j, :, :]
                    min_distance = distances.min()
                    # Check if the latent patch is closest for all data samples seen so far
                    if min_distance < global_min_proto_dist[j]:
                        # Index of the closest patch in the flattened distance map
                        min_distance_ix = distances.argmin()
                        # reshape (not view) also handles the non-contiguous
                        # unfold result of prototypes larger than 1x1; clone
                        # so the stored patch does not keep the whole batch
                        # of features alive.
                        closest_patch = patches_batch[batch_i].reshape(
                            D, -1, W1, H1)[:, min_distance_ix, :, :].clone()
                        global_min_proto_dist[j] = min_distance
                        global_min_patches[j] = closest_patch
                        global_min_info[j] = {
                            'input_image_ix': samples_seen + batch_i,
                            # Index in a flattened array of the feature map
                            'patch_ix': min_distance_ix.item(),
                            'W': W,
                            'H': H,
                            'W1': W1,
                            'H1': H1,
                            'distance': min_distance.item(),
                            # Cloned so it does not pin the whole input batch.
                            'nearest_input': torch.unsqueeze(xs[batch_i],
                                                             0).clone(),
                            'node_ix': node.index,
                        }

            samples_seen += len(labels)

            # Update the progress bar if required
            projection_iter.set_postfix_str(
                f'Batch: {i + 1}/{len(project_loader)}')

            del features_batch
            del distances_batch
            del out_map

        # Fail with a clear message before any weight is overwritten.
        unprojected = [
            j for j in range(num_prototypes) if global_min_patches[j] is None
        ]
        if unprojected:
            raise RuntimeError(
                "No image with an allowed class label was found for "
                f"prototype(s) {unprojected}; cannot project them.")

        # Copy the patches to the prototype layer weights
        torch.cat(tuple(global_min_patches[j].unsqueeze(0)
                        for j in range(num_prototypes)),
                  dim=0,
                  out=tree.prototype_layer.prototype_vectors)

    return global_min_info, tree