Esempio n. 1
0
    def _get_mask(self, batch: Dict[str, Any]) -> Dict[str, Tensor]:
        """Run the network on `batch` and return its outputs, with model predictions reduced via argmax.

        The Network is expected to be loaded before this method is called. Note that `batch` is modified in place:
        every key in `self.gather_keys` that is absent from it is inserted with an empty list as its value.

        Args:
            batch: A batch of input data to be fed to the model.

        Returns:
            The dictionary produced by `self.network.transform` for the given `batch` (model outputs and raw saliency
            mask(s)), in which every entry listed in `self.model_outputs` has been replaced by its argmax over axis 1.
        """
        for gather_key in self.gather_keys:
            if gather_key not in batch:
                # Without a target key, an empty array is supplied so that the max-likelihood class gets selected
                batch[gather_key] = []
        outputs = self.network.transform(data=batch, mode=self.mode)
        for output_key in self.model_outputs:
            raw_prediction = outputs[output_key]
            outputs[output_key] = argmax(raw_prediction, axis=1)
        return outputs
Esempio n. 2
0
 def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
     """Apply `argmax` along `self.axis` to every tensor in `data`.

     Args:
         data: The tensors to be reduced.
         state: Not used by this method.

     Returns:
         A list with one `argmax` result per input tensor, in the same order as `data`.
     """
     reduced = []
     for item in data:
         reduced.append(argmax(tensor=item, axis=self.axis))
     return reduced
Esempio n. 3
0
    def on_epoch_end(self, data: Data) -> None:
        """Build the gradient heat-map overlay for the collected samples and write it to `data`.

        The stored images, gradients and (when present) labels / predictions are trimmed to `self.n_samples` entries
        (or `self.n_found` when `self.n_samples` is falsy). Gradients are averaged over axis 1, negative values are
        clipped to zero, and each resulting map is resized to the image size, min-max normalized and colored with the
        JET color map. The colored maps are added onto the images and the sum is rescaled by its maximum.

        Args:
            data: The object whose `write_without_log` receives the resulting `ImgData` under `self.outputs[0]`.
        """
        # Keep only the user-specified number of samples
        limit = self.n_samples or self.n_found
        images = concat(self.images)[:limit]
        _, height, width = get_image_dims(images)
        grads = to_number(concat(self.grads)[:limit])
        if tf.is_tensor(images):
            grads = np.moveaxis(grads, source=-1, destination=1)  # grads should be channel first
        args = {}
        # True labels and predictions go through exactly the same post-processing, so handle them in one loop
        for cached, out_key in ((self.labels, self.true_label_key), (self.preds, self.pred_label_key)):
            if not cached:
                continue
            classes = concat(cached)[:limit]
            if len(classes.shape) > 1:
                classes = argmax(classes, axis=-1)
            if self.label_mapping:
                # NOTE(review): squeeze presumably drops every length-1 axis, in which case a single sample would
                # become a 0-d array that cannot be iterated. atleast_1d keeps it iterable (no-op for >= 1-d input).
                class_ids = np.atleast_1d(to_number(squeeze(classes)))
                classes = np.array([self.label_mapping[clazz] for clazz in class_ids])
            args[out_key] = classes
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
        heatmaps = np.maximum(np.mean(grads, axis=1), 0)
        colored = []
        for heatmap in heatmaps:
            heatmap = cv2.resize(heatmap, (width, height))
            heatmap = heatmap - np.min(heatmap)
            peak = np.max(heatmap)
            if peak > 0:
                # A flat map (e.g. every gradient was <= 0) has a peak of 0; dividing would fill it with NaN
                heatmap = heatmap / peak
            heatmap = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
            colored.append(np.float32(heatmap) / 255)
        mask = np.array(colored, dtype=np.float32)
        # Switch to channel first for pytorch
        if isinstance(images, torch.Tensor):
            mask = np.moveaxis(mask, source=-1, destination=1)

        overlay = images + mask  # This seems to work even if the image is 1 channel instead of 3
        args[self.grad_key] = overlay / reduce_max(overlay)

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
Esempio n. 4
0
    def on_epoch_end(self, data: Data) -> None:
        """Build one overlay image per projected component for the collected samples and write them to `data`.

        The stored images, activations and (when present) labels / predictions are trimmed to `self.n_samples`
        entries (or `self.n_found` when `self.n_samples` is falsy). The activations are handed to `self._project_2d`,
        and every component map it returns is resized to the image size, min-max normalized, colored with the JET
        color map, added onto its base image and rescaled by the maximum. Images that have no map for a given
        component index get an all-ones placeholder instead.

        Args:
            data: The object whose `write_without_log` receives the resulting `ImgData` under `self.outputs[0]`.
        """
        # Keep only the user-specified number of samples
        limit = self.n_samples or self.n_found
        images = concat(self.images)[:limit]
        _, height, width = get_image_dims(images)
        activations = to_number(concat(self.activations)[:limit])
        if tf.is_tensor(images):
            activations = np.moveaxis(activations, source=-1, destination=1)  # Activations should be channel first
        args = {}
        # True labels and predictions go through exactly the same post-processing, so handle them in one loop
        for cached, out_key in ((self.labels, self.true_label_key), (self.preds, self.pred_label_key)):
            if not cached:
                continue
            classes = concat(cached)[:limit]
            if len(classes.shape) > 1:
                classes = argmax(classes, axis=-1)
            if self.label_mapping:
                # NOTE(review): squeeze presumably drops every length-1 axis, in which case a single sample would
                # become a 0-d array that cannot be iterated. atleast_1d keeps it iterable (no-op for >= 1-d input).
                class_ids = np.atleast_1d(to_number(squeeze(classes)))
                classes = np.array([self.label_mapping[clazz] for clazz in class_ids])
            args[out_key] = classes
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        n_components, batch_component_image = self._project_2d(activations)
        components = []  # component x image (batch x image)
        for component_idx in range(n_components):
            batch = []
            for base_image, component_image in zip(images, batch_component_image):
                if len(component_image) <= component_idx:
                    # There's no component for this image, so display an empty image here
                    batch.append(np.ones_like(base_image))
                    continue
                mask = cv2.resize(component_image[component_idx], (width, height))
                mask = mask - np.min(mask)
                peak = np.max(mask)
                if peak > 0:
                    # A constant map has a peak of 0 after the shift; dividing would fill it with NaN
                    mask = mask / peak
                mask = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                mask = np.float32(mask) / 255
                # Switch to channel first for pytorch. The mask here is a single (height, width, 3) image rather
                # than a batch, so the channel axis has to move to position 0 (position 1 would give H x 3 x W).
                if isinstance(base_image, torch.Tensor):
                    mask = np.moveaxis(mask, source=-1, destination=0)
                new_image = base_image + mask
                batch.append(new_image / reduce_max(new_image))
            components.append(np.array(batch, dtype=np.float32))

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)