Ejemplo n.º 1
0
def plot_saliency(model, model_input, baseline_input=None, decode_dictionary=None, color_map="inferno", smooth=7):
    """Build a figure comparing several saliency mask interpretations of the given input

    Every element of the batch gets one row of six panels: the decoded predictions, the raw input, the input
    weighted by the smoothed integrated mask, and then the vanilla, smoothed, and smoothed integrated masks.

    Args:
        model: A model to evaluate. Should be a classifier which takes the 0th axis as the batch axis
        model_input: Input tensor, shaped for the model ex. (1, 299, 299, 3)
        baseline_input: An example of what a blank model input would be.
                        Should be a tensor with the same shape as model_input
        decode_dictionary: A dictionary of "class_idx" -> "class_name" associations
        color_map: The color map to use to visualize the saliency maps.
                        Consider "Greys_r", "plasma", or "magma" as alternatives
        smooth: The number of samples to use when generating a smoothed image

    Returns:
        The matplotlib figure holding the grid of panels
    """
    class_scores = np.asarray(model(model_input))
    top_predictions = decode_predictions(class_scores, top=3, dictionary=decode_dictionary)

    gradient_saliency = GradientSaliency(model)
    integrated_gradients = IntegratedGradients(model)

    # Compute the three masks and turn each into something displayable
    vanilla_ims = convert_for_visualization(gradient_saliency.get_mask(model_input))
    smooth_ims = convert_for_visualization(gradient_saliency.get_smoothed_mask(model_input, nsamples=smooth))
    integrated_ims = convert_for_visualization(
        integrated_gradients.get_smoothed_mask(model_input, nsamples=smooth, input_baseline=baseline_input))

    # Shift the input up by one, weight it by the integrated mask (broadcast over the last axis), shift back down.
    # NOTE(review): presumably the inputs are scaled to [-1, 1] - confirm against the callers
    shifted_input = model_input + 1
    filtered_inputs = shifted_input * np.asarray(integrated_ims)[:, :, :, None] - 1

    n_rows = model_input.shape[0]
    n_cols = 6
    dpi = 96.0

    # Every panel is at least 220 pixels on a side; figsize is given in inches, hence the division by dpi
    cell_width = max(220, model_input.shape[2])
    cell_height = max(220, model_input.shape[1])
    fig, axs = plt.subplots(n_rows,
                            n_cols,
                            figsize=(n_cols * (cell_width / dpi), n_rows * (cell_height / dpi)),
                            dpi=dpi,
                            squeeze=False)  # always get a 2D grid of axes, even when there's only one row

    def header(row, text):
        # Column titles are only drawn above the first row
        return text if row == 0 else None

    for row, row_axes in enumerate(axs):
        show_text(np.ones_like(model_input[row]),
                  top_predictions[row],
                  axis=row_axes[0],
                  title=header(row, "Predictions"))
        show_image(model_input[row], axis=row_axes[1], title=header(row, "Raw"))
        show_image(filtered_inputs[row], axis=row_axes[2], title=header(row, "Filtered"))
        show_gray_image(vanilla_ims[row], axis=row_axes[3], color_map=color_map, title=header(row, "Vanilla"))
        show_gray_image(smooth_ims[row], axis=row_axes[4], color_map=color_map, title=header(row, "Smoothed"))
        show_gray_image(integrated_ims[row],
                        axis=row_axes[5],
                        color_map=color_map,
                        title=header(row, "Integrated Smoothed"))

    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)
    return fig
Ejemplo n.º 2
0
def plot_caricature(model,
                    model_input,
                    layer_ids=None,
                    decode_dictionary=None,
                    n_steps=512,
                    learning_rate=0.05,
                    blur=1,
                    cossim_pow=0.5,
                    sd=0.01,
                    fft=True,
                    decorrelate=True,
                    sigmoid=True):
    """Build a figure showing caricatures of the given input for the selected model layers

    Every element of the batch gets one row: the decoded predictions, the raw input, and then one caricature per
    entry of layer_ids.

    Args:
        model (model): The keras model to be inspected by the Caricature visualization
        model_input (tensor): The input images to be fed to the model
        layer_ids (list): The layers of the model to be inspected by the Caricature visualization. If None or
                            empty, every layer of the model is used
        decode_dictionary (dict): A dictionary mapping model outputs to class names
        n_steps (int): How many steps of optimization to run when computing caricatures (quality vs time trade)
        learning_rate (float): The learning rate of the caricature optimizer. Should be higher than usual
        blur (float): How much blur to add to images during caricature generation
        cossim_pow (float): How much should similarity in form be valued versus creative license
        sd (float): The standard deviation of the noise used to seed the caricature
        fft (bool): Whether to use fft space (True) or image space (False) to create caricatures
        decorrelate (bool): Whether to use an ImageNet-derived color correlation matrix to de-correlate
                            colors in the caricature. Parameter has no effect on grey scale images.
        sigmoid (bool): Whether to use sigmoid (True) or clipping (False) to bound the caricature pixel values

    Returns:
        The matplotlib figure holding the grid of panels
    """
    # Fall back to inspecting every layer when the caller didn't pick any
    if layer_ids is None or len(layer_ids) == 0:
        layer_ids = list(range(len(model.layers)))

    class_scores = np.asarray(model.predict(model_input))
    top_predictions = decode_predictions(class_scores, top=3, dictionary=decode_dictionary)

    # The optimization settings are the same for every layer, only the target layer changes
    caricature_options = dict(n_steps=n_steps,
                              learning_rate=learning_rate,
                              blur=blur,
                              cossim_pow=cossim_pow,
                              sd=sd,
                              fft=fft,
                              decorrelate=decorrelate,
                              sigmoid=sigmoid)
    caricatures = [generate_caricatures(model, model_input, layer, **caricature_options) for layer in layer_ids]

    n_rows = model_input.shape[0]
    n_cols = len(layer_ids) + 2  # predictions + raw input + one panel per layer
    dpi = 96.0

    # Every panel is at least 220 pixels on a side; figsize is given in inches, hence the division by dpi
    cell_width = max(220, model_input.shape[2])
    cell_height = max(220, model_input.shape[1])
    fig, axs = plt.subplots(n_rows,
                            n_cols,
                            figsize=(n_cols * (cell_width / dpi), n_rows * (cell_height / dpi)),
                            dpi=dpi,
                            squeeze=False)  # always get a 2D grid of axes, even when there's only one row

    for row, row_axes in enumerate(axs):
        first_row = row == 0  # column titles are only drawn above the first row
        show_text(np.ones_like(model_input[row]),
                  top_predictions[row],
                  axis=row_axes[0],
                  title="Predictions" if first_row else None)
        show_image(model_input[row], axis=row_axes[1], title="Raw" if first_row else None)
        # The caricature panels start at the third column
        for col, (layer_id, layer_images) in enumerate(zip(layer_ids, caricatures), start=2):
            layer_name = ": " + model.layers[layer_id].name
            show_image(layer_images[row],
                       axis=row_axes[col],
                       title="Layer {}{}".format(layer_id, layer_name) if first_row else None)

    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)
    return fig
Ejemplo n.º 3
0
def plot_gradcam(inputs,
                 model,
                 layer_id=None,
                 target_class=None,
                 decode_dictionary=None,
                 colormap=cv2.COLORMAP_INFERNO):
    """Build a figure showing a GradCam interpretation of the given input

    The samples are laid out two per row, each one taking three panels: the decoded predictions, the raw input,
    and the GradCam heatmap. When the batch size is odd, the unused right half of the last row is hidden.

    Args:
        inputs (tf.tensor): Model input, with batch along the zero axis
        model (tf.keras.model): tf.keras model to inspect
        layer_id (int, str, None): Which layer to inspect. Should be a convolutional layer. If None, the last \
                                    layer of the model with a 4-dimensional output will be selected
        target_class (int, None): Which output class to try to explain. None will default to explaining the maximum \
                                    likelihood prediction
        decode_dictionary (dict): A dictionary of "class_idx" -> "class_name" associations
        colormap (int): Which colormap to use when generating the heatmaps

    Returns:
        The matplotlib figure handle
    """
    explainer = FEGradCAM()

    # Integer ids are translated into the corresponding layer name
    if isinstance(layer_id, int):
        layer_id = model.layers[layer_id].name
    # With nothing specified, walk the model backwards and take the first layer which has a 4D output.
    # If no such layer exists then layer_id is left as None.
    if layer_id is None:
        layer_id = next((layer.name for layer in reversed(model.layers) if layer.output.shape.ndims == 4), None)

    heatmaps, predictions = explainer.explain(model_input=inputs,
                                              model=model,
                                              layer_name=layer_id,
                                              class_index=target_class,
                                              colormap=colormap)

    top_predictions = decode_predictions(np.asarray(predictions), top=3, dictionary=decode_dictionary)

    n_rows = math.ceil(inputs.shape[0] / 2.0)  # two samples share each row
    n_cols = 6
    dpi = 96.0

    # Every panel is at least 220 pixels on a side; figsize is given in inches, hence the division by dpi
    cell_width = max(220, inputs.shape[2])
    cell_height = max(220, inputs.shape[1])
    fig, axs = plt.subplots(n_rows,
                            n_cols,
                            figsize=(n_cols * (cell_width / dpi), n_rows * (cell_height / dpi)),
                            dpi=dpi,
                            squeeze=False)  # always get a 2D grid of axes, even when there's only one row

    # An odd batch size leaves the right half of the last row without a sample, so hide those axes
    if inputs.shape[0] % 2 == 1:
        for empty_col in (3, 4, 5):
            axs[n_rows - 1][empty_col].axis('off')

    for sample in range(inputs.shape[0]):
        # Even samples go in the left three columns of their row, odd samples in the right three
        row, half = divmod(sample, 2)
        text_col = 3 * half
        first_row = row == 0  # column titles are only drawn above the first row
        show_text(np.ones_like(inputs[sample]),
                  top_predictions[sample],
                  axis=axs[row][text_col],
                  title="Predictions" if first_row else None)
        show_image(inputs[sample], axis=axs[row][text_col + 1], title="Raw" if first_row else None)
        show_image(heatmaps[sample], axis=axs[row][text_col + 2], title="GradCam" if first_row else None)

    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)

    return fig