def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
    """Perturb the inputs one epsilon-step along the sign of the loss gradient (fast gradient sign method).

    Args:
        data: A pair `[inputs, loss]`; the gradient of `loss` is taken with respect to `inputs`.
        state: Execution state; `state['tape']` is forwarded to `get_gradient`.

    Returns:
        `inputs + self.epsilon * sign(grad)`, clipped to `[self.clip_low, self.clip_high]`. A bound that is None
        falls back to the minimum / maximum of `inputs` itself.
    """
    inputs, loss = data
    grad = get_gradient(target=loss, sources=inputs, tape=state['tape'], retain_graph=self.retain_graph)
    # Compare against None explicitly: 0 / 0.0 is a legitimate clip bound and must not trigger the fallback.
    low = self.clip_low if self.clip_low is not None else reduce_min(inputs)
    high = self.clip_high if self.clip_high is not None else reduce_max(inputs)
    adverse_data = clip_by_value(inputs + self.epsilon * sign(grad), min_value=low, max_value=high)
    return adverse_data
def _convert_for_visualization(tensor: Tensor, tile: int = 99) -> np.ndarray:
    """Rescale saliency masks into the [0, 1] range so they can be displayed.

    The channel axis (axis 1 for a `torch.Tensor`, the last axis otherwise) is collapsed by summing absolute
    values, keeping a singleton channel axis. Each sample is then min-max scaled over all of its non-batch axes.
    The `tile`-th percentile is used as the upper bound instead of the true maximum, and the result is clipped
    to [0, 1].

    Args:
        tensor: Input masks, whose channel values are to be reduced by absolute value summation.
        tile: The percentile [0-100] used to set the max value of the image.

    Returns:
        The per-sample rescaled masks, with the channel axis reduced to size 1.
    """
    # PyTorch data is channel-first; anything else is treated as channel-last.
    channel_axis = 1 if isinstance(tensor, torch.Tensor) else -1
    collapsed = reduce_sum(abs(tensor), axis=channel_axis, keepdims=True)

    # Every axis except the batch axis (axis 0).
    per_sample_axes = list(range(1, len(collapsed.shape)))
    upper = percentile(collapsed, tile, axis=per_sample_axes, keepdims=True)
    lower = reduce_min(collapsed, axis=per_sample_axes, keepdims=True)

    # NOTE(review): if a sample's percentile equals its minimum (e.g. an all-zero mask) this is a 0/0
    # division; confirm whether that case needs guarding.
    scaled = (collapsed - lower) / (upper - lower)
    return clip_by_value(scaled, 0, 1)
def get_smoothed_masks(
        self,
        batch: Dict[str, Any],
        stdev_spread: float = .15,
        nsamples: int = 25,
        nintegration: Optional[int] = None,
        magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
    """Generates smoothed greyscale saliency mask(s) from a given `batch` of data.

    The masks are averaged over `nsamples` samples: the noise-free batch plus `nsamples - 1` copies with
    Gaussian noise added to every model input.

    Args:
        batch: An input batch of data.
        stdev_spread: Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).
        nsamples: Number of samples to average across to get the smooth gradient. Must be at least 1.
        nintegration: Number of samples to compute when integrating (None to disable). When enabled, every
            sample (including the noise-free one) is an integrated-gradient mask.
        magnitude: If true, computes the sum of squares of gradients instead of just the sum.

    Returns:
        Greyscale saliency mask(s) smoothed via the SmoothGrad method, keyed by `self.outputs`. Other entries
        come from `self._get_mask` on the noise-free batch.

    Raises:
        ValueError: If `nsamples` is less than 1.
    """
    if nsamples < 1:
        raise ValueError("nsamples must be at least 1, but got {}".format(nsamples))

    # Shallow copy batch since we're going to modify its contents later
    batch = dict(batch)

    model_inputs = [batch[ins] for ins in self.model_inputs]
    stdevs = [to_number(stdev_spread * (reduce_max(ins) - reduce_min(ins))).item() for ins in model_inputs]

    # Adding noise to the image might cause the max likelihood class value to change, so need to keep track of
    # which class we're comparing to
    response = self._get_mask(batch)
    for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
        batch[gather_key] = response[output_key]

    if nintegration:
        # The noise-free sample must be the same kind of mask as the noisy ones below. Otherwise a plain gradient
        # would be averaged together with integrated gradients, which live on a different scale.
        integrated = self._get_integrated_masks(batch, nsamples=nintegration)
        for key in self.outputs:
            response[key] = integrated[key]

    if magnitude:
        for key in self.outputs:
            response[key] = response[key] * response[key]

    for _ in range(nsamples - 1):
        noisy_batch = {key: batch[key] for key in self.gather_keys}
        for input_name, model_input, stdev in zip(self.model_inputs, model_inputs, stdevs):
            noisy_batch[input_name] = model_input + random_normal_like(model_input, std=stdev)

        if nintegration:
            grads_and_preds = self._get_integrated_masks(noisy_batch, nsamples=nintegration)
        else:
            grads_and_preds = self._get_mask(noisy_batch)

        for name in self.outputs:
            grad = grads_and_preds[name]
            if magnitude:
                response[name] += grad * grad
            else:
                response[name] += grad

    for key in self.outputs:
        response[key] = self._convert_for_visualization(response[key] / nsamples)
    return response
def _get_integrated_masks(self, batch: Dict[str, Any], nsamples: int = 25) -> Dict[str, Tensor]:
    """Generates raw integrated saliency mask(s) from a given `batch` of data.

    This method assumes that the Network is already loaded. Gradients are evaluated at `nsamples` evenly spaced
    points (endpoints included) on the straight line from a random-uniform baseline `x'` to the input `x`. They
    are averaged and then multiplied by `(x - x')`, which is a Riemann-sum approximation of integrated gradients.

    Args:
        batch: A batch of input data to be fed to the model.
        nsamples: How many samples to consider during integration. Must be at least 1.

    Returns:
        The raw integrated saliency mask(s) for the given `batch` of data, keyed by `self.outputs`.

    Raises:
        ValueError: If `nsamples` is less than 1.
    """
    if nsamples < 1:
        raise ValueError("nsamples must be at least 1, but got {}".format(nsamples))

    model_inputs = [batch[ins] for ins in self.model_inputs]
    # Use a random uniform baseline as advised in https://distill.pub/2020/attribution-baselines/
    input_baselines = [
        random_uniform_like(ins, minval=reduce_min(ins), maxval=reduce_max(ins)) for ins in model_inputs
    ]
    input_diffs = [
        model_input - input_baseline for model_input, input_baseline in zip(model_inputs, input_baselines)
    ]

    response = {}
    for alpha in np.linspace(0.0, 1.0, nsamples):
        noisy_batch = {key: batch[key] for key in self.gather_keys}
        for input_name, baseline, diff in zip(self.model_inputs, input_baselines, input_diffs):
            noisy_batch[input_name] = baseline + alpha * diff
        grads_and_preds = self._get_mask(noisy_batch)
        for key in self.outputs:
            if key in response:
                response[key] += grads_and_preds[key]
            else:
                response[key] = grads_and_preds[key]

    for key in self.outputs:
        # Average the path gradients (the 1/m factor of the Riemann sum) before scaling by the input difference.
        grad = response[key] / nsamples
        for diff in input_diffs:
            grad = grad * diff
        response[key] = grad
    return response
def on_epoch_end(self, data: Data) -> None:
    """Finalize the collected samples for the current mode, compute their saliency maps, and write an image.

    If samples were found during this epoch, they are either concatenated and truncated to
    `self.n_required[mode]` samples (when that is positive) or reduced to the first collected batch. The counters
    are then reset. Plain masks are always computed. Smoothed, integrated, and smoothed-integrated masks are
    computed depending on `self.smoothing` and `self.integrating`; `self.integrating` may be an int or an
    `(n_integration, n_smoothing)` pair. Everything is packed into an `ImgData` and written to
    `data` under `self.outputs[0]`.

    Args:
        data: The epoch-end data dictionary to write the visualization into.
    """
    mode = self.system.mode
    if self.n_found[mode] > 0:
        if self.n_required[mode] > 0:
            # We are keeping a user-specified number of samples
            self.samples[mode] = {
                key: concat(val)[:self.n_required[mode]]
                for key, val in self.samples[mode].items()
            }
        else:
            # We are keeping one batch of data
            self.samples[mode] = {key: val[0] for key, val in self.samples[mode].items()}
        # even if you haven't found n_required samples, you're at end of epoch so no point trying to collect more
        self.n_found[mode] = 0
        self.n_required[mode] = 0

    samples = self.samples[mode]

    def map_labels(classes):
        # atleast_1d: with a batch of one, squeeze() yields a 0-d array, which cannot be iterated.
        return np.array([self.label_mapping[clazz] for clazz in np.atleast_1d(to_number(squeeze(classes)))])

    masks = self.salnet.get_masks(samples)
    smoothed, integrated, smint = {}, {}, {}
    if self.smoothing:
        smoothed = self.salnet.get_smoothed_masks(samples, nsamples=self.smoothing)
    if self.integrating:
        if isinstance(self.integrating, (tuple, list)):
            n_integration, n_smoothing = self.integrating
        else:
            n_integration = self.integrating
            n_smoothing = self.smoothing
        integrated = self.salnet.get_integrated_masks(samples, nsamples=n_integration)
        if n_smoothing:
            smint = self.salnet.get_smoothed_masks(samples, nsamples=n_smoothing, nintegration=n_integration)

    # Arrange the outputs
    args = {}
    if self.class_key:
        classes = samples[self.class_key]
        if self.label_mapping:
            classes = map_labels(classes)
        args[self.class_key] = classes
    for key in self.model_outputs:
        classes = masks[key]
        if self.label_mapping:
            classes = map_labels(classes)
        args[key] = classes

    # Prefer the most refined saliency variant that was computed.
    sal = smint or integrated or smoothed or masks
    for key, val in samples.items():
        # Value comparison, not identity: equal strings are not guaranteed to be the same object.
        if key == self.class_key:
            continue
        args[key] = val
        # Create a linear combination of the original image, the saliency mask, and the product of the two in
        # order to highlight regions of importance
        min_val = reduce_min(val)
        diff = reduce_max(val) - min_val
        for outkey in self.outputs:
            args["{} {}".format(key, outkey)] = (0.3 * (sal[outkey] * (val - min_val) + min_val) + 0.3 * val +
                                                 0.4 * sal[outkey] * diff + min_val)

    for key in self.outputs:
        args[key] = masks[key]
        if smoothed:
            args["Smoothed {}".format(key)] = smoothed[key]
        if integrated:
            args["Integrated {}".format(key)] = integrated[key]
        if smint:
            args["SmInt {}".format(key)] = smint[key]
    result = ImgData(colormap="inferno", **args)
    data.write_without_log(self.outputs[0], result)