Example #1
import torch

# Project-internal imports; module paths assume the ProtoTree repository layout.
from prototree.prototree import ProtoTree
from util.log import Log


def get_similarity_maps(tree: ProtoTree, project_info: dict, log: Log = None):
    if log is not None:
        log.log_message("\nCalculating similarity maps (after projection)...")

    sim_maps = dict()
    for j in project_info.keys():
        nearest_x = project_info[j]['nearest_input']
        with torch.no_grad():
            # Distances of every latent patch to prototype j, turned into
            # similarities via exp(-distance)
            _, distances_batch, _ = tree.forward_partial(nearest_x)
            sim_maps[j] = torch.exp(-distances_batch[0, j, :, :]).cpu().numpy()
        # Free the stored input tensor once its similarity map is computed
        del nearest_x
        del project_info[j]['nearest_input']
    return sim_maps, project_info
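
A minimal usage sketch (hedged): it assumes project_info was produced by a
projection routine such as project_with_class_constraints below, mapping each
prototype index to a dict that includes the 'nearest_input' image tensor.

# Usage sketch (assumed names; not part of the original file)
sim_maps, project_info = get_similarity_maps(tree, project_info, log)
heatmap_0 = sim_maps[0]  # 2D numpy array of exp(-distance) per latent patch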
Example #2
import argparse
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Project-internal import; module path assumes the ProtoTree repository layout.
# The helpers find_high_activation_crop and imsave_with_bbox used below are
# defined elsewhere in the project.
from prototree.prototree import ProtoTree


def upsample_local(tree: ProtoTree, sample: torch.Tensor, sample_dir: str,
                   folder_name: str, img_name: str, decision_path: list,
                   args: argparse.Namespace):

    save_dir = os.path.join(args.log_dir, folder_name, img_name,
                            args.dir_for_saving_images)
    os.makedirs(save_dir, exist_ok=True)
    with torch.no_grad():
        # Similarity map per prototype for this sample: exp(-distance)
        _, distances_batch, _ = tree.forward_partial(sample)
        sim_map = torch.exp(-distances_batch[0, :, :, :]).cpu().numpy()
    # The input image is the same for every node, so load and normalize it once
    img = Image.open(sample_dir)
    x_np = np.float32(np.asarray(img)) / 255
    if x_np.ndim == 2:  # convert grayscale to RGB
        x_np = np.stack((x_np,) * 3, axis=-1)
    img_size = x_np.shape[:2]

    # Iterate over the internal (decision) nodes on the path; the last element
    # of decision_path is a leaf, which has no prototype
    for node in decision_path[:-1]:
        decision_node_idx = node.index
        node_id = tree._out_map[node]
        similarity_map = sim_map[node_id]

        # Rescale similarities to [0, 1] and save them as a colour heatmap
        rescaled_sim_map = similarity_map - np.amin(similarity_map)
        rescaled_sim_map = rescaled_sim_map / np.amax(rescaled_sim_map)
        similarity_heatmap = cv2.applyColorMap(
            np.uint8(255 * rescaled_sim_map), cv2.COLORMAP_JET)
        similarity_heatmap = np.float32(similarity_heatmap) / 255
        similarity_heatmap = similarity_heatmap[..., ::-1]  # BGR -> RGB
        plt.imsave(fname=os.path.join(
            save_dir, f'{decision_node_idx}_heatmap_latent_similaritymap.png'),
                   arr=similarity_heatmap, vmin=0.0, vmax=1.0)

        # Upsample the latent similarity map to the input resolution and
        # overlay it on the original image
        upsampled_act_pattern = cv2.resize(similarity_map,
                                           dsize=(img_size[1], img_size[0]),
                                           interpolation=cv2.INTER_CUBIC)
        rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
        rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern)
        heatmap = cv2.applyColorMap(np.uint8(255 * rescaled_act_pattern),
                                    cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]  # BGR -> RGB
        overlayed_original_img = 0.5 * x_np + 0.2 * heatmap
        plt.imsave(fname=os.path.join(
            save_dir, f'{decision_node_idx}_heatmap_original_image.png'),
                   arr=overlayed_original_img, vmin=0.0, vmax=1.0)

        # Mask the similarity map so that only the nearest patch z* is visualized
        masked_similarity_map = np.ones(similarity_map.shape)
        masked_similarity_map[similarity_map < np.max(similarity_map)] = 0

        upsampled_prototype_pattern = cv2.resize(masked_similarity_map,
                                                 dsize=(img_size[1], img_size[0]),
                                                 interpolation=cv2.INTER_CUBIC)
        plt.imsave(fname=os.path.join(
            save_dir, f'{decision_node_idx}_masked_upsampled_heatmap.png'),
                   arr=upsampled_prototype_pattern, vmin=0.0, vmax=1.0)

        # Crop out the highly activated patch and save it
        high_act_patch_indices = find_high_activation_crop(
            upsampled_prototype_pattern, args.upsample_threshold)
        high_act_patch = x_np[high_act_patch_indices[0]:high_act_patch_indices[1],
                              high_act_patch_indices[2]:high_act_patch_indices[3], :]
        plt.imsave(fname=os.path.join(
            save_dir, f'{decision_node_idx}_nearest_patch_of_image.png'),
                   arr=high_act_patch, vmin=0.0, vmax=1.0)

        # Save the original image with a bounding box around the high-activation patch
        imsave_with_bbox(fname=os.path.join(
            save_dir, f'{decision_node_idx}_bounding_box_nearest_patch_of_image.png'),
                         img_rgb=x_np,
                         bbox_height_start=high_act_patch_indices[0],
                         bbox_height_end=high_act_patch_indices[1],
                         bbox_width_start=high_act_patch_indices[2],
                         bbox_width_end=high_act_patch_indices[3],
                         color=(0, 255, 255))
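
The helpers find_high_activation_crop and imsave_with_bbox ship with the
project; minimal sketches of their likely behaviour follow, assuming the
threshold is a fraction of the map's maximum and the crop indices are returned
as (h_start, h_end, w_start, w_end). These are hedged stand-ins, not the
project's actual implementations.

# Hedged sketch: bounding box of the above-threshold region
def find_high_activation_crop_sketch(activation_map, threshold):
    mask = activation_map >= threshold * np.amax(activation_map)
    rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
    h_start, h_end = np.where(rows)[0][[0, -1]]
    w_start, w_end = np.where(cols)[0][[0, -1]]
    return h_start, h_end + 1, w_start, w_end + 1

# Hedged sketch: draw a rectangle on a [0, 1] RGB image and save it
def imsave_with_bbox_sketch(fname, img_rgb, bbox_height_start, bbox_height_end,
                            bbox_width_start, bbox_width_end, color=(0, 255, 255)):
    img_bgr = cv2.cvtColor(np.uint8(255 * img_rgb), cv2.COLOR_RGB2BGR)
    cv2.rectangle(img_bgr, (bbox_width_start, bbox_height_start),
                  (bbox_width_end - 1, bbox_height_end - 1), color, thickness=2)
    img_out = np.float32(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)) / 255
    plt.imsave(fname=fname, arr=img_out, vmin=0.0, vmax=1.0)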
Example #3
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-internal imports; module paths assume the ProtoTree repository layout.
from prototree.prototree import ProtoTree
from util.log import Log


def project_with_class_constraints(
        tree: ProtoTree,
        project_loader: DataLoader,
        device,
        args: argparse.Namespace,
        log: Log,
        log_prefix: str = 'log_projection_with_constraints',  # TODO
        progress_prefix: str = 'Projection') -> tuple:

    log.log_message(
        "\nProjecting prototypes to nearest training patch (with class restrictions)..."
    )
    # Set the model to evaluation mode
    tree.eval()
    torch.cuda.empty_cache()
    # The goal is to find the latent patch that minimizes the L2 distance to each prototype
    # To do this we iterate through the train dataset and store for each prototype the closest latent patch seen so far
    # Also store info about the image that was used for projection
    global_min_proto_dist = {j: np.inf for j in range(tree.num_prototypes)}
    global_min_patches = {j: None for j in range(tree.num_prototypes)}
    global_min_info = {j: None for j in range(tree.num_prototypes)}

    # Get the shape of the prototypes
    W1, H1, D = tree.prototype_shape

    # Build a progress bar for showing the status
    projection_iter = tqdm(enumerate(project_loader),
                           total=len(project_loader),
                           desc=progress_prefix,
                           ncols=0)

    with torch.no_grad():
        # Peek at one batch only to record the batch size, which is needed
        # below to compute each image's global index in the dataset
        xs, ys = next(iter(project_loader))
        batch_size = xs.shape[0]
        # For each internal node, collect the leaf labels in the subtree with this node as root.
        # Only images from these classes can be used for projection.
        leaf_labels_subtree = dict()

        for branch, j in tree._out_map.items():
            leaf_labels_subtree[branch.index] = set()
            for leaf in branch.leaves:
                leaf_labels_subtree[branch.index].add(
                    torch.argmax(leaf.distribution()).item())

        for i, (xs, ys) in projection_iter:
            xs, ys = xs.to(device), ys.to(device)
            # Get the features and distances
            # - features_batch: features tensor (shared by all prototypes)
            #   shape: (batch_size, D, W, H)
            # - distances_batch: distances tensor (for all prototypes)
            #   shape: (batch_size, num_prototypes, W, H)
            # - out_map: a dict mapping decision nodes to distances (indices)
            features_batch, distances_batch, out_map = tree.forward_partial(xs)

            # Get the features dimensions
            bs, D, W, H = features_batch.shape

            # Get a tensor containing the individual latent patches
            # Create the patches by unfolding over both the W and H dimensions
            # TODO -- support for strides in the prototype layer? (corresponds to step size here)
            # Shape: (batch_size, D, W', H', W1, H1) with W' = W - W1 + 1 and
            # H' = H - H1 + 1 (so W' = W and H' = H for 1x1 prototypes)
            patches_batch = features_batch.unfold(2, W1, 1).unfold(3, H1, 1)

            # Iterate over all decision nodes/prototypes
            for node, j in out_map.items():
                leaf_labels = leaf_labels_subtree[node.index]
                # Iterate over all items in the batch
                # Select the features/distances that are relevant to this prototype
                # - distances: distances of the prototype to the latent patches
                #   shape: (W, H)
                # - patches: latent patches
                #   shape: (D, W, H, W1, H1)
                for batch_i, (distances, patches) in enumerate(
                        zip(distances_batch[:, j, :, :], patches_batch)):
                    #Check if label of this image is in one of the leaves of the subtree
                    if ys[batch_i].item() in leaf_labels:
                        # Find the index of the latent patch that is closest to the prototype
                        min_distance = distances.min()
                        min_distance_ix = distances.argmin()
                        # Use the index to get the closest latent patch; the
                        # view assumes 1x1 prototypes, so W' * H' == W * H
                        closest_patch = patches.view(
                            D, W * H, W1, H1)[:, min_distance_ix, :, :]

                        # Check if the latent patch is closest for all data samples seen so far
                        if min_distance < global_min_proto_dist[j]:
                            global_min_proto_dist[j] = min_distance
                            global_min_patches[j] = closest_patch
                            global_min_info[j] = {
                                'input_image_ix': i * batch_size + batch_i,
                                # Index into the flattened feature map
                                'patch_ix': min_distance_ix.item(),
                                'W': W,
                                'H': H,
                                'W1': W1,
                                'H1': H1,
                                'distance': min_distance.item(),
                                'nearest_input': torch.unsqueeze(xs[batch_i], 0),
                                'node_ix': node.index,
                            }

            # Update the progress bar if required
            projection_iter.set_postfix_str(
                f'Batch: {i + 1}/{len(project_loader)}')

            del features_batch
            del distances_batch
            del out_map

        # Copy the nearest patches into the prototype layer weights; passing
        # out= makes torch.cat write directly into prototype_vectors in place
        projection = torch.cat(tuple(global_min_patches[j].unsqueeze(0)
                                     for j in range(tree.num_prototypes)),
                               dim=0,
                               out=tree.prototype_layer.prototype_vectors)
        del projection

    return global_min_info, tree
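
A hedged end-to-end sketch tying the three functions together; the loader,
sample, and path names are placeholders, not from the original files.

# Usage sketch (assumed names)
tree = tree.to(device)
project_info, tree = project_with_class_constraints(tree, project_loader,
                                                    device, args, log)
sim_maps, project_info = get_similarity_maps(tree, project_info, log)
upsample_local(tree, sample, sample_path, 'local_explanations', 'img0',
               decision_path, args)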