def plot_saliency(model,
                  model_input,
                  baseline_input=None,
                  decode_dictionary=None,
                  color_map="inferno",
                  smooth=7):
    """Build a figure comparing several saliency-mask interpretations of the given input.

    Each sample in the batch gets one row with six columns: the top-3 decoded predictions, the raw input, the input
    filtered by the smoothed integrated-gradients mask, and the vanilla, smoothed, and smoothed integrated-gradients
    saliency maps. The figure is only built and returned; showing or saving it is left to the caller.

    Args:
        model: A model to evaluate. Should be a classifier which takes the 0th axis as the batch axis
        model_input: Input tensor, shaped for the model ex. (1, 299, 299, 3)
        baseline_input: An example of what a blank model input would be. Should be a tensor with the same shape as
            model_input. Only used by the integrated-gradients interpretation
        decode_dictionary: A dictionary of "class_idx" -> "class_name" associations
        color_map: The color map to use to visualize the saliency maps. Consider "Greys_r", "plasma", or "magma" as
            alternatives
        smooth: The number of samples to use when generating a smoothed image

    Returns:
        The matplotlib figure handle
    """
    predictions = np.asarray(model(model_input))
    decoded = decode_predictions(predictions, top=3, dictionary=decode_dictionary)

    grad_sal = GradientSaliency(model)
    grad_int = IntegratedGradients(model)

    vanilla_ims = convert_for_visualization(grad_sal.get_mask(model_input))
    smooth_ims = convert_for_visualization(grad_sal.get_smoothed_mask(model_input, nsamples=smooth))
    smooth_integrated_ims = convert_for_visualization(
        grad_int.get_smoothed_mask(model_input, nsamples=smooth, input_baseline=baseline_input))

    # NOTE(review): the "+1 ... -1" shift presumably assumes inputs scaled to [-1, 1], so that masked-out pixels map
    # to -1 -- confirm against the model's preprocessing. The mask gets a trailing channel axis so it broadcasts
    # across the input channels.
    filtered_inputs = (model_input + 1) * np.asarray(smooth_integrated_ims)[:, :, :, None] - 1

    num_rows = model_input.shape[0]
    num_cols = 6
    dpi = 96.0
    box_width = max(220, model_input.shape[2])
    box_height = max(220, model_input.shape[1])
    # squeeze=False keeps axs two-dimensional even for a single-sample batch, so axs[row][col] always works.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=(num_cols * (box_width / dpi), num_rows * (box_height / dpi)),
                            dpi=dpi,
                            squeeze=False)
    for i in range(num_rows):
        # Column titles are only drawn on the top row.
        first_row = i == 0
        show_text(np.ones_like(model_input[i]),
                  decoded[i],
                  axis=axs[i][0],
                  title="Predictions" if first_row else None)
        show_image(model_input[i], axis=axs[i][1], title="Raw" if first_row else None)
        show_image(filtered_inputs[i], axis=axs[i][2], title="Filtered" if first_row else None)
        show_gray_image(vanilla_ims[i],
                        axis=axs[i][3],
                        color_map=color_map,
                        title="Vanilla" if first_row else None)
        show_gray_image(smooth_ims[i],
                        axis=axs[i][4],
                        color_map=color_map,
                        title="Smoothed" if first_row else None)
        show_gray_image(smooth_integrated_ims[i],
                        axis=axs[i][5],
                        color_map=color_map,
                        title="Integrated Smoothed" if first_row else None)
    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)
    return fig
def plot_caricature(model,
                    model_input,
                    layer_ids=None,
                    decode_dictionary=None,
                    n_steps=512,
                    learning_rate=0.05,
                    blur=1,
                    cossim_pow=0.5,
                    sd=0.01,
                    fft=True,
                    decorrelate=True,
                    sigmoid=True):
    """Build a figure showing Caricature visualizations of the given inputs for a set of model layers.

    Each sample in the batch gets one row: the top-3 decoded predictions, the raw input, then one caricature per
    requested layer. The figure is only built and returned; showing or saving it is left to the caller.

    Args:
        model (model): The keras model to be inspected by the Caricature visualization
        model_input (tensor): The input images to be fed to the model
        layer_ids (list): The layer indices of the model to be inspected by the Caricature visualization. If None or
            empty, every layer of the model is used
        decode_dictionary (dict): A dictionary mapping model outputs to class names
        n_steps (int): How many steps of optimization to run when computing caricatures (quality vs time trade)
        learning_rate (float): The learning rate of the caricature optimizer. Should be higher than usual
        blur (float): How much blur to add to images during caricature generation
        cossim_pow (float): How much should similarity in form be valued versus creative license
        sd (float): The standard deviation of the noise used to seed the caricature
        fft (bool): Whether to use fft space (True) or image space (False) to create caricatures
        decorrelate (bool): Whether to use an ImageNet-derived color correlation matrix to de-correlate colors in
            the caricature. Parameter has no effect on grey scale images.
        sigmoid (bool): Whether to use sigmoid (True) or clipping (False) to bound the caricature pixel values

    Returns:
        The matplotlib figure handle
    """
    if layer_ids is None or len(layer_ids) == 0:
        layer_ids = list(range(len(model.layers)))

    predictions = np.asarray(model.predict(model_input))
    decoded = decode_predictions(predictions, top=3, dictionary=decode_dictionary)

    # caricatures[j] holds the caricatures of every input sample for layer_ids[j].
    caricatures = [
        generate_caricatures(model,
                             model_input,
                             layer,
                             n_steps=n_steps,
                             learning_rate=learning_rate,
                             blur=blur,
                             cossim_pow=cossim_pow,
                             sd=sd,
                             fft=fft,
                             decorrelate=decorrelate,
                             sigmoid=sigmoid) for layer in layer_ids
    ]
    # Column titles do not depend on the row, so build them once.
    layer_titles = ["Layer {}: {}".format(layer_id, model.layers[layer_id].name) for layer_id in layer_ids]

    num_rows = model_input.shape[0]
    num_cols = len(layer_ids) + 2
    dpi = 96.0
    box_width = max(220, model_input.shape[2])
    box_height = max(220, model_input.shape[1])
    # squeeze=False keeps axs two-dimensional even for a single-sample batch, so axs[row][col] always works.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=(num_cols * (box_width / dpi), num_rows * (box_height / dpi)),
                            dpi=dpi,
                            squeeze=False)
    for i in range(num_rows):
        # Column titles are only drawn on the top row.
        first_row = i == 0
        show_text(np.ones_like(model_input[i]),
                  decoded[i],
                  axis=axs[i][0],
                  title="Predictions" if first_row else None)
        show_image(model_input[i], axis=axs[i][1], title="Raw" if first_row else None)
        for j, (layer_caricatures, layer_title) in enumerate(zip(caricatures, layer_titles)):
            show_image(layer_caricatures[i], axis=axs[i][2 + j], title=layer_title if first_row else None)
    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)
    return fig
def plot_gradcam(inputs,
                 model,
                 layer_id=None,
                 target_class=None,
                 decode_dictionary=None,
                 colormap=cv2.COLORMAP_INFERNO):
    """Creates a GradCam interpretation of the given input

    Samples are laid out two per row, three columns each: the top-3 decoded predictions, the raw input, and the
    GradCam heatmap. With an odd number of samples the right half of the last row is left blank.

    Args:
        inputs (tf.tensor): Model input, with batch along the zero axis
        model (tf.keras.model): tf.keras model to inspect
        layer_id (int, str, None): Which layer to inspect. Should be a convolutional layer. An integer (Python or
            numpy) is treated as an index into model.layers, a string as a layer name. If None, the last layer of the
            model with a 4-dimensional output will be selected
        target_class (int, None): Which output class to try to explain. None will default to explaining the maximum
            likelihood prediction
        decode_dictionary (dict): A dictionary of "class_idx" -> "class_name" associations
        colormap (int): Which colormap to use when generating the heatmaps

    Returns:
        The matplotlib figure handle

    Raises:
        ValueError: If layer_id is None and the model has no layer with a 4-dimensional output
    """
    gradcam = FEGradCAM()
    # Accept numpy integer indices as well as Python ints; both index model.layers.
    if isinstance(layer_id, (int, np.integer)):
        layer_id = model.layers[layer_id].name
    if layer_id is None:
        for layer in reversed(model.layers):
            if layer.output.shape.ndims == 4:
                layer_id = layer.name
                break
        else:
            # Fail loudly here rather than handing layer_name=None to the explainer.
            raise ValueError("Could not find a layer with a 4-dimensional output to use for GradCam; "
                             "please specify layer_id explicitly")

    heatmaps, predictions = gradcam.explain(model_input=inputs,
                                            model=model,
                                            layer_name=layer_id,
                                            class_index=target_class,
                                            colormap=colormap)
    decoded = decode_predictions(np.asarray(predictions), top=3, dictionary=decode_dictionary)

    num_inputs = int(inputs.shape[0])
    num_rows = math.ceil(num_inputs / 2.0)
    num_cols = 6
    dpi = 96.0
    box_width = max(220, inputs.shape[2])
    box_height = max(220, inputs.shape[1])
    # squeeze=False keeps axs two-dimensional even when there is a single row, so axs[row][col] always works.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=(num_cols * (box_width / dpi), num_rows * (box_height / dpi)),
                            dpi=dpi,
                            squeeze=False)

    # With an odd number of samples, the right-hand half of the last row has nothing to show.
    if num_inputs % 2 == 1:
        for col in (3, 4, 5):
            axs[num_rows - 1][col].axis('off')

    for sample in range(num_inputs):
        # Two samples per row: even samples use columns 0-2, odd samples use columns 3-5.
        row, half = divmod(sample, 2)
        col = 3 * half
        first_row = row == 0
        show_text(np.ones_like(inputs[sample]),
                  decoded[sample],
                  axis=axs[row][col],
                  title="Predictions" if first_row else None)
        show_image(inputs[sample], axis=axs[row][col + 1], title="Raw" if first_row else None)
        show_image(heatmaps[sample], axis=axs[row][col + 2], title="GradCam" if first_row else None)
    plt.subplots_adjust(top=0.95, bottom=0.01, left=0.01, right=0.99, hspace=0.03, wspace=0.03)
    return fig