Example #1
    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        data, loss = data
        # Compute the gradient of the loss with respect to the input data
        grad = get_gradient(target=loss, sources=data, tape=state['tape'], retain_graph=self.retain_graph)
        # FGSM: perturb the input one epsilon-step along the gradient's sign, then clip
        # back into a valid range (the data's own min/max when clip_low/clip_high are unset)
        adverse_data = clip_by_value(data + self.epsilon * sign(grad),
                                     min_value=self.clip_low or reduce_min(data),
                                     max_value=self.clip_high or reduce_max(data))
        return adverse_data
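
For reference, the update this op performs is the fast gradient sign method (FGSM): x_adv = clip(x + epsilon * sign(grad_x(loss))). A minimal PyTorch-only sketch of the same attack, where `model`, `x`, and `y` are hypothetical stand-ins for a classifier, an input batch, and its labels:

import torch
import torch.nn.functional as F

def fgsm(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor,
         epsilon: float = 0.03) -> torch.Tensor:
    # Track gradients with respect to the input itself
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    # Step one epsilon along the gradient's sign to increase the loss, then
    # clip back into the input's original value range
    x_adv = x + epsilon * x.grad.sign()
    return torch.clamp(x_adv, x.min().item(), x.max().item()).detach()
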
Example #2
    @staticmethod
    def _convert_for_visualization(tensor: Tensor,
                                   tile: int = 99) -> np.ndarray:
        """Modify the range of data in a given input `tensor` to be appropriate for visualization.

        Args:
            tensor: Input masks, whose channel values are to be reduced by absolute value summation.
            tile: The percentile [0-100] used to set the max value of the image.

        Returns:
            A (batch X width X height) image after visualization clipping is applied.
        """
        # PyTorch tensors are channel-first; other inputs are assumed channel-last
        if isinstance(tensor, torch.Tensor):
            channel_axis = 1
        else:
            channel_axis = -1
        flattened_mask = reduce_sum(abs(tensor),
                                    axis=channel_axis,
                                    keepdims=True)

        non_batch_axes = list(range(len(flattened_mask.shape)))[1:]

        vmax = percentile(flattened_mask,
                          tile,
                          axis=non_batch_axes,
                          keepdims=True)
        vmin = reduce_min(flattened_mask, axis=non_batch_axes, keepdims=True)

        return clip_by_value((flattened_mask - vmin) / (vmax - vmin), 0, 1)
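
To see the normalization numerically, here is a NumPy-only rerun of the same steps on a fake channels-last mask batch; `np.percentile` stands in for the framework's `percentile` helper:

import numpy as np

masks = np.random.rand(2, 8, 8, 3).astype(np.float32)  # batch x H x W x C
# Collapse channels by absolute-value summation
flat = np.sum(np.abs(masks), axis=-1, keepdims=True)
# Per-sample ceiling at the 99th percentile, per-sample min as the floor
vmax = np.percentile(flat, 99, axis=(1, 2, 3), keepdims=True)
vmin = np.min(flat, axis=(1, 2, 3), keepdims=True)
viz = np.clip((flat - vmin) / (vmax - vmin), 0, 1)  # values now in [0, 1]
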
Example #3
    def get_smoothed_masks(
            self,
            batch: Dict[str, Any],
            stdev_spread: float = .15,
            nsamples: int = 25,
            nintegration: Optional[int] = None,
            magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Generates smoothed greyscale saliency mask(s) from a given `batch` of data.

        Args:
            batch: An input batch of data.
            stdev_spread: Amount of noise to add to the input, as a fraction of the total spread (x_max - x_min).
            nsamples: Number of samples to average across to get the smooth gradient.
            nintegration: Number of samples to compute when integrating (None to disable).
            magnitude: If true, computes the sum of squares of gradients instead of just the sum.

        Returns:
            Greyscale saliency mask(s) smoothed via the SmoothGrad method.
        """
        # Shallow copy batch since we're going to modify its contents later
        batch = {key: val for key, val in batch.items()}
        model_inputs = [batch[ins] for ins in self.model_inputs]
        stdevs = [
            to_number(stdev_spread *
                      (reduce_max(ins) - reduce_min(ins))).item()
            for ins in model_inputs
        ]

        # Adding noise to the image might cause the max likelihood class value to change, so need to keep track of
        # which class we're comparing to
        response = self._get_mask(batch)
        for gather_key, output_key in zip(self.gather_keys,
                                          self.model_outputs):
            batch[gather_key] = response[output_key]

        if magnitude:
            for key in self.outputs:
                response[key] = response[key] * response[key]

        for _ in range(nsamples - 1):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            for idx, input_name in enumerate(self.model_inputs):
                noise = random_normal_like(model_inputs[idx], std=stdevs[idx])
                x_plus_noise = model_inputs[idx] + noise
                noisy_batch[input_name] = x_plus_noise
            if nintegration:
                grads_and_preds = self._get_integrated_masks(noisy_batch, nsamples=nintegration)
            else:
                grads_and_preds = self._get_mask(noisy_batch)
            for name in self.outputs:
                grad = grads_and_preds[name]
                if magnitude:
                    response[name] += grad * grad
                else:
                    response[name] += grad
        for key in self.outputs:
            grad = response[key]
            response[key] = self._convert_for_visualization(grad / nsamples)
        return response
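
Stripped of the framework bookkeeping (gather keys, tracking the predicted class), the SmoothGrad core is just gradient averaging over noisy copies of the input. A minimal sketch, assuming a hypothetical `saliency(x)` callable that returns a raw gradient mask:

import torch

def smoothgrad(saliency, x: torch.Tensor, stdev_spread: float = 0.15,
               nsamples: int = 25, magnitude: bool = True) -> torch.Tensor:
    # Noise scale is a fraction of the input's total value spread
    stdev = (stdev_spread * (x.max() - x.min())).item()
    total = torch.zeros_like(x)
    for _ in range(nsamples):
        grad = saliency(x + torch.randn_like(x) * stdev)
        # Squaring emphasizes pixels that are salient consistently across samples
        total += grad * grad if magnitude else grad
    return total / nsamples
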
Example #4
    def _get_integrated_masks(self,
                              batch: Dict[str, Any],
                              nsamples: int = 25) -> Dict[str, Tensor]:
        """Generates raw integrated saliency mask(s) from a given `batch` of data.

        This method assumes that the Network is already loaded.

        Args:
            batch: A batch of input data to be fed to the model.
            nsamples: How many samples to consider during integration.

        Returns:
            The raw integrated saliency mask(s) for the given `batch` of data.
        """
        model_inputs = [batch[ins] for ins in self.model_inputs]

        # Use a random uniform baseline as advised in https://distill.pub/2020/attribution-baselines/
        input_baselines = [
            random_uniform_like(ins,
                                minval=reduce_min(ins),
                                maxval=reduce_max(ins)) for ins in model_inputs
        ]
        input_diffs = [
            model_input - input_baseline for model_input, input_baseline in
            zip(model_inputs, input_baselines)
        ]

        response = {}

        for alpha in np.linspace(0.0, 1.0, nsamples):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            for idx, input_name in enumerate(self.model_inputs):
                x_step = input_baselines[idx] + alpha * input_diffs[idx]
                noisy_batch[input_name] = x_step
            grads_and_preds = self._get_mask(noisy_batch)
            for key in self.outputs:
                if key in response:
                    response[key] += grads_and_preds[key]
                else:
                    response[key] = grads_and_preds[key]

        for key in self.outputs:
            grad = response[key]
            for diff in input_diffs:
                grad = grad * diff
            response[key] = grad

        return response
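
The same integration without the multi-input plumbing, for a single tensor. A sketch under the same assumption as above (a hypothetical `saliency(x)` returning a gradient mask), using the random uniform baseline recommended by the distill.pub article cited in the code:

import torch

def integrated_gradients(saliency, x: torch.Tensor,
                         nsamples: int = 25) -> torch.Tensor:
    # Random uniform baseline drawn from the input's own value range
    baseline = torch.empty_like(x).uniform_(x.min().item(), x.max().item())
    diff = x - baseline
    total = torch.zeros_like(x)
    # Accumulate gradients along the straight-line path from baseline to input
    for alpha in torch.linspace(0.0, 1.0, nsamples):
        total += saliency(baseline + alpha * diff)
    # Scale the path-averaged gradient by the input difference
    return (total / nsamples) * diff
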
Example #5
    def on_epoch_end(self, data: Data) -> None:
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {
                    key: val[0]
                    for key, val in self.samples[mode].items()
                }
            # Even if n_required samples haven't been found, the epoch is over, so there's no point collecting more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        masks = self.salnet.get_masks(self.samples[mode])
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(self.samples[mode],
                                                      nsamples=self.smoothing)
        if self.integrating:
            if isinstance(self.integrating, tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(
                self.samples[mode], nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(
                    self.samples[mode],
                    nsamples=n_smoothing,
                    nintegration=n_integration)

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = self.samples[mode][self.class_key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[self.class_key] = classes
        for key in self.model_outputs:
            classes = masks[key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[key] = classes
        # Prefer the most refined saliency available: smoothed+integrated first,
        # then integrated, then smoothed, then the raw masks
        sal = smint or integrated or smoothed or masks
        for key, val in self.samples[mode].items():
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two
                # in order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    args["{} {}".format(key, outkey)] = (
                        0.3 * (sal[outkey] * (val - min_val) + min_val) +
                        0.3 * val + 0.4 * sal[outkey] * diff + min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
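
The image/saliency blend in the loop above can be isolated into a small helper that makes the weighting explicit. A NumPy restatement of the same linear combination, where `image` and `mask` are hypothetical arrays shaped like `val` and `sal[outkey]`:

import numpy as np

def blend(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Mix the image, the mask, and their product so high-saliency regions
    # stand out while the original structure remains visible
    min_val = image.min()
    diff = image.max() - min_val
    return (0.3 * (mask * (image - min_val) + min_val)
            + 0.3 * image + 0.4 * mask * diff + min_val)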