Beispiel #1
0
    def forward_batch(
            self, data: Union[Tensor, List[Tensor]],
            state: Dict[str, Any]) -> Union[np.ndarray, List[np.ndarray]]:
        """Run `forward` on each instance of a batch and stack the per-instance results back together.

        Invoked on batches of data during network postprocessing. The inputs may be numpy arrays or TF/Torch
        tensors; they are converted with `to_number`, split along their first axis, and each instance is passed to
        `self.forward`. Outputs are expected to be Numpy arrays, though this is not enforced. Developers should
        rarely need to override this unless building an Op specifically intended for postprocessing.

        Args:
            data: The arrays from the data dictionary corresponding to whatever keys this Op declares as its
                `inputs` - a single array/tensor, or a list of them (expected to share the same first dimension).
            state: Information about the current execution context, for example {"mode": "train"}.

        Returns:
            One stacked array when `self.out_list` is False; otherwise a list holding one stacked array per position
            of the per-instance `forward` outputs. It is written into the data dictionary under this Op's `outputs`.
        """
        if isinstance(data, List):
            columns = [to_number(elem) for elem in data]
            n_rows = columns[0].shape[0]
            # Transpose: one [input_0[i], input_1[i], ...] list per instance i
            instances = []
            for row_idx in range(n_rows):
                instances.append([column[row_idx] for column in columns])
        else:
            array = to_number(data)
            instances = [array[row_idx] for row_idx in range(array.shape[0])]
        outputs = [self.forward(instance, state) for instance in instances]
        if not self.out_list:
            return np.array(outputs)
        # Each per-instance output is a sequence; regroup it into one stacked array per position
        width = len(outputs[0])
        return [np.array([out[pos] for out in outputs]) for pos in range(width)]
 def on_batch_end(self, data: Data) -> None:
     """Append this batch's labels and predictions to the `self.y_true` / `self.y_pred` accumulators.

     Labels whose last axis is wider than 1 (and which have more than one dimension) are collapsed to indices with
     argmax; predictions are stored as-is. Both are extended element-wise along their first axis.
     """
     y_true = to_number(data[self.true_key])
     y_pred = to_number(data[self.pred_key])
     labels_are_multi_column = y_true.shape[-1] > 1 and y_true.ndim > 1
     if labels_are_multi_column:
         y_true = np.argmax(y_true, axis=-1)
     # Both arrays must describe the same number of instances
     assert y_true.shape[0] == y_pred.shape[0]
     for store, batch_values in ((self.y_true, y_true), (self.y_pred, y_pred)):
         store.extend(batch_values)
Beispiel #3
0
 def on_batch_end(self, data: Data) -> None:
     """Score each instance of the batch and write the scores as a per-instance log.

     Reads the 'pred' and 'target_real' entries. Arrays with more than two dimensions whose last axis is wider
     than 1 are reduced with argmax. The result of `batch_precision_parameters` is logged under `self.outputs[0]`.
     """
     predictions = to_number(data['pred'])
     targets = to_number(data['target_real'])

     def _collapse(arr):
         # Only collapse rank > 2 arrays that carry more than one value on their last axis
         if arr.shape[-1] > 1 and arr.ndim > 2:
             return np.argmax(arr, axis=-1)
         return arr

     targets = _collapse(targets)
     predictions = _collapse(predictions)
     per_instance = self.batch_precision_parameters(targets, predictions)
     data.write_per_instance_log(self.outputs[0], per_instance)
Beispiel #4
0
    def on_batch_end(self, data: Data) -> None:
        """Append one Dice score per instance of this batch to `self.dice`.

        Each instance is flattened, predictions are binarized with `>= self.threshold`, and the score is
        (2 * intersection + smooth) / (sum(true) + sum(pred) + smooth).
        """
        y_true = to_number(data[self.true_key])
        y_pred = to_number(data[self.pred_key])
        n = y_true.shape[0]
        flat_true = y_true.reshape((n, -1))
        flat_pred = y_pred.reshape((n, -1))

        binary_pred = (flat_pred >= self.threshold).astype(np.int32)

        overlap = np.sum(flat_true * binary_pred, axis=-1)
        total = np.sum(flat_true, axis=-1) + np.sum(binary_pred, axis=-1)
        scores = (2. * overlap + self.smooth) / (total + self.smooth)
        self.dice.extend(list(scores))
Beispiel #5
0
 def on_batch_end(self, data: Data) -> None:
     """Convert the batch to hard class labels and append them to `self.y_pred` / `self.y_true`.

     `self.binary_classification` is set when predictions have exactly one column. Multi-column labels (ndim > 1)
     and multi-column predictions are reduced with argmax; single-column predictions are rounded. Both are
     flattened before being stored.
     """
     labels = to_number(data[self.true_key])
     preds = to_number(data[self.pred_key])
     self.binary_classification = preds.shape[-1] == 1
     if labels.shape[-1] > 1 and labels.ndim > 1:
         labels = np.argmax(labels, axis=-1)
     preds = np.argmax(preds, axis=-1) if preds.shape[-1] > 1 else np.round(preds)
     assert preds.size == labels.size
     self.y_pred.extend(preds.ravel())
     self.y_true.extend(labels.ravel())
Beispiel #6
0
 def on_batch_end(self, data: Data) -> None:
     """Add this batch's number of correct hard-label predictions to `self.correct` and its size to `self.total`.

     Multi-column labels (ndim > 1) and multi-column predictions are reduced with argmax; single-column
     predictions are rounded.
     """
     labels = to_number(data[self.true_key])
     preds = to_number(data[self.pred_key])
     if labels.shape[-1] > 1 and labels.ndim > 1:
         labels = np.argmax(labels, axis=-1)
     preds = np.argmax(preds, axis=-1) if preds.shape[-1] > 1 else np.round(preds)
     assert preds.size == labels.size
     flat_preds, flat_labels = preds.ravel(), labels.ravel()
     self.correct += np.sum(flat_preds == flat_labels)
     self.total += flat_preds.size
Beispiel #7
0
 def on_batch_end(self, data: Data) -> None:
     """Add this batch's number of correct hard-label predictions to `self.correct` and its size to `self.total`.

     Multi-column labels (ndim > 1) and multi-column predictions are reduced with argmax. Single-column
     predictions are treated as binary: when `self.from_logits` is set they are first passed through a sigmoid,
     then rounded to 0/1.
     """
     y_true, y_pred = to_number(data[self.true_key]), to_number(
         data[self.pred_key])
     if y_true.shape[-1] > 1 and y_true.ndim > 1:
         y_true = np.argmax(y_true, axis=-1)
     if y_pred.shape[-1] > 1:
         y_pred = np.argmax(y_pred, axis=-1)
     else:  # binary classification (pred shape is [batch, 1])
         if self.from_logits:
             # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). Unlike 1 / (1 + exp(-x)), this form cannot overflow for
             # large negative logits, so no spurious "overflow encountered in exp" RuntimeWarning is raised.
             y_pred = 0.5 * (1.0 + np.tanh(0.5 * y_pred))
         y_pred = np.round(y_pred)
     assert y_pred.size == y_true.size
     self.correct += np.sum(y_pred.ravel() == y_true.ravel())
     self.total += len(y_pred.ravel())
Beispiel #8
0
 def on_batch_end(self, data: Data) -> None:
     """Evaluate every per-instance test case on this batch and record the outcomes.

     For each case in `self.instance_cases`, `case.criteria` is called with the batch values named by
     `case.criteria_inputs`. Its return must be an ndarray of dtype bool; it is flattened and appended to
     `case.result`. When `self.data_id` is set, the IDs of the instances whose entry is False are appended to
     `case.fail_id`.

     Raises:
         TypeError: If a criteria function does not return an ndarray with dtype bool.
         ValueError: If the flattened criteria return and the ID array differ in size.
     """
     # Flattened ID array; computed lazily (only if some case needs it) and at most once per batch
     data_id = None
     for case in self.instance_cases:
         result = case.criteria(
             *[data[var_name] for var_name in case.criteria_inputs])
         if not isinstance(result, np.ndarray) or result.dtype != np.dtype("bool"):
             raise TypeError(
                 f"In test with description '{case.description}': "
                 "Criteria return of per-instance test needs to be ndarray with dtype bool."
             )
         result = result.reshape(-1)
         case.result.append(result)
         if self.data_id:
             if data_id is None:
                 data_id = to_number(data[self.data_id]).reshape((-1, ))
             if data_id.size != result.size:
                 raise ValueError(
                     f"In test with description '{case.description}': "
                     "Array size of criteria return doesn't match ID array size. Size of criteria "
                     "return should be equal to the batch_size such that each entry represents the test "
                     "result of its corresponding data instance.")
             # Boolean-mask indexing returns a copy, so sharing data_id across cases is safe
             case.fail_id.append(data_id[~result])
Beispiel #9
0
    def _print_message(self, header: str, data: Data, log_epoch: bool = False) -> None:
        """Print the logged values in `data` and mirror each of them into the `system` summary.

        Scalars and `ValWithError` values are printed inline as "key: value; " in human-sorted key order.
        Multi-element arrays are appended after the (stripped) inline part, each on its own lines.

        Args:
            header: Text placed at the start of the printed message.
            data: The collection whose logged entries are printed and recorded.
            log_epoch: If True, also print and record the current epoch index.
        """
        inline_parts = [header]
        if log_epoch:
            inline_parts.append("epoch: {}; ".format(self.system.epoch_idx))
            self.system.write_summary('epoch', self.system.epoch_idx)
        trailing_parts = []
        for key, val in humansorted(data.read_logs().items(), key=lambda x: x[0]):
            if not isinstance(val, ValWithError):
                val = to_number(val)
                if val.size > 1:
                    trailing_parts.append("\n{}:\n{};".format(key, np.array2string(val, separator=',')))
                else:
                    inline_parts.append("{}: {}; ".format(key, str(val)))
            else:
                inline_parts.append("{}: {}; ".format(key, str(val)))
            self.system.write_summary(key, val)
        print("".join(inline_parts).strip() + "".join(trailing_parts))
Beispiel #10
0
 def __init__(self,
              loss: LossOp,
              threshold: Union[float, str] = 'exp',
              regularization: float = 1.0,
              average_loss: bool = True,
              output_confidence: Optional[str] = None):
     """Wrap a single-output LossOp for use with SuperLoss.

     Args:
         loss: The LossOp to wrap. It must have exactly one output and must not use `out_list`. Its
             `average_loss` attribute is set to False; this Op's own `average_loss` is forwarded to the parent
             constructor instead.
         threshold: Either 'exp' or a numeric value (int, float, or a single-element array/tensor accepted by
             `to_number`), which is stored as a Python float in `self.tau_method`.
         regularization: A value greater than 0, stored in `self.lam`.
         average_loss: Forwarded to the parent constructor.
         output_confidence: Optional extra output key; when given, the outputs become (loss output, this key).

     Raises:
         ValueError: If `loss` does not have exactly one output or uses `out_list`, if `threshold` is neither 'exp'
             nor numeric, or if `regularization` is not greater than 0.
     """
     if len(loss.outputs) != 1 or loss.out_list:
         raise ValueError(
             "SuperLoss only supports lossOps which have a single output.")
     self.loss = loss
     self.loss.average_loss = False
     super().__init__(inputs=loss.inputs,
                      outputs=loss.outputs[0] if not output_confidence else
                      (loss.outputs[0], output_confidence),
                      mode=loss.mode,
                      ds_id=loss.ds_id,
                      average_loss=average_loss)
     if not isinstance(threshold, str):
         threshold = to_number(threshold).item()
         # `.item()` returns a Python int for integer inputs (e.g. threshold=1), which the float check below would
         # wrongly reject. Promote them; bool is excluded since True/False are not meaningful thresholds.
         if isinstance(threshold, int) and not isinstance(threshold, bool):
             threshold = float(threshold)
     if not isinstance(threshold, float) and threshold != 'exp':
         raise ValueError(
             f'SuperLoss threshold parameter must be "exp" or a float, but got {threshold}'
         )
     self.tau_method = threshold
     if regularization <= 0:
         raise ValueError(
             f"SuperLoss regularization parameter must be greater than 0, but got {regularization}"
         )
     self.lam = regularization
     self.cap = -1.9999998 / e  # Slightly more than -2 / e for numerical stability
     self.initialized = {}
     self.tau = {}
Beispiel #11
0
    def write_embeddings(
        self,
        mode: str,
        embeddings: Iterable[Tuple[str, Tensor, Optional[List[Any]],
                                   Optional[Tensor]]],
        step: int,
    ):
        """Send embeddings (e.g. for a UMAP projection) to the TensorBoard writer of `mode`.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            embeddings: Quadruplets of ("key", <features>, [<label1>, ...], <label_images>). Features are batched
                and flattened to [batch, -1]. Labels and label images, when provided, share the features' batch
                dimension. Label images that are not already torch tensors are converted to torch; if they are 4-D
                they are permuted with [0, 3, 1, 2] (last axis moved to position 1).
            step: The current training step.
        """
        for tag, features, labels, label_imgs in embeddings:
            batch_dim = features.shape[0]
            flat = to_number(reshape(features, [batch_dim, -1]))
            needs_conversion = not isinstance(label_imgs, (torch.Tensor, type(None)))
            if needs_conversion:
                label_imgs = to_tensor(label_imgs, 'torch')
                if len(label_imgs.shape) == 4:
                    label_imgs = permute(label_imgs, [0, 3, 1, 2])
            writer = self.summary_writers[mode]
            writer.add_embedding(mat=flat, metadata=labels, label_img=label_imgs, tag=tag, global_step=step)
Beispiel #12
0
    def get_smoothed_masks(
            self,
            batch: Dict[str, Any],
            stdev_spread: float = .15,
            nsamples: int = 25,
            nintegration: Optional[int] = None,
            magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Build SmoothGrad-style greyscale saliency masks for `batch`.

        The first sample is the mask of the un-noised batch. Each of the remaining `nsamples - 1` samples adds
        Gaussian noise (std = `stdev_spread` * (max - min) of each model input) before computing a mask, or an
        integrated mask when `nintegration` is set. Samples are summed (squared first when `magnitude` is True),
        divided by `nsamples`, and converted for visualization.

        Args:
            batch: An input batch of data. It is shallow-copied, so the caller's dict is not modified.
            stdev_spread: Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).
            nsamples: Number of samples to average across to get the smooth gradient.
            nintegration: Number of samples to compute when integrating (None to disable).
            magnitude: If true, computes the sum of squares of gradients instead of just the sum.

        Returns:
            Greyscale saliency mask(s) smoothed via the SmoothGrad method.
        """
        # Top-level copy only: entries are replaced below, but the caller's dict must stay untouched
        batch = dict(batch)
        inputs = [batch[name] for name in self.model_inputs]
        noise_stds = [to_number(stdev_spread * (reduce_max(x) - reduce_min(x))).item() for x in inputs]

        # The most likely class could change once noise is added, so pin the gather keys to the clean batch's
        # outputs before sampling.
        # NOTE(review): this first (clean) sample always uses _get_mask, even when nintegration is set - confirm
        # that mixing one non-integrated sample into the average is intended.
        totals = self._get_mask(batch)
        for gather_key, output_key in zip(self.gather_keys, self.model_outputs):
            batch[gather_key] = totals[output_key]
        if magnitude:
            for out_key in self.outputs:
                totals[out_key] = totals[out_key] * totals[out_key]

        for _ in range(nsamples - 1):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            for idx, name in enumerate(self.model_inputs):
                noisy_batch[name] = inputs[idx] + random_normal_like(inputs[idx], std=noise_stds[idx])
            if nintegration:
                sample = self._get_integrated_masks(noisy_batch, nsamples=nintegration)
            else:
                sample = self._get_mask(noisy_batch)
            for out_key in self.outputs:
                grad = sample[out_key]
                if magnitude:
                    totals[out_key] += grad * grad
                else:
                    totals[out_key] += grad
        for out_key in self.outputs:
            totals[out_key] = self._convert_for_visualization(totals[out_key] / nsamples)
        return totals
Beispiel #13
0
    def write_scalars(self, mode: str, scalars: Iterable[Tuple[str, Any]], step: int) -> None:
        """Log scalar values to the TensorBoard writer for `mode`.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            scalars: Pairs such as [("key", val), ("key2", val2), ...]; each value is passed through `to_number`.
            step: The current training step.
        """
        for tag, value in scalars:
            writer = self.summary_writers[mode]
            writer.add_scalar(tag=tag, scalar_value=to_number(value), global_step=step)
Beispiel #14
0
    def on_batch_end(self, data: Data) -> None:
        """Add this batch's confusion matrix over classes 0..num_classes-1 to the running `self.matrix`.

        Multi-column labels (ndim > 1) and multi-column predictions are reduced with argmax; single-column
        predictions are rounded.
        """
        labels = to_number(data[self.true_key])
        preds = to_number(data[self.pred_key])
        if labels.shape[-1] > 1 and labels.ndim > 1:
            labels = np.argmax(labels, axis=-1)
        preds = np.argmax(preds, axis=-1) if preds.shape[-1] > 1 else np.round(preds)
        assert preds.size == labels.size

        class_ids = [*range(self.num_classes)]
        batch_matrix = confusion_matrix(labels, preds, labels=class_ids)

        if self.matrix is not None:
            self.matrix += batch_matrix
        else:
            self.matrix = batch_matrix
Beispiel #15
0
 def forward_batch(self, data: List[Tensor],
                   state: Dict[str, Any]) -> List[np.ndarray]:
     """Apply the calibration function to every element of `data`.

     If `self.calibration_fn` is still a file path, the function is deserialized with dill on first use and
     replaces the path. While `state['warmup']` is truthy nothing is loaded and `data` is returned unchanged.
     """
     if isinstance(self.calibration_fn, str):
         if state.get('warmup'):
             # Don't attempt to load the calibration_fn during warmup
             return data
         # NOTE(review): dill.load can execute arbitrary code from the file; only load calibration files from
         # trusted locations.
         with open(self.calibration_fn, 'rb') as handle:
             notice = f"FastEstimator-Calibrate: calibration function loaded from {self.calibration_fn}"
             self.calibration_fn = dill.load(handle)
             print(notice)
     calibrate = self.calibration_fn
     return [calibrate(to_number(elem)) for elem in data]
    def on_batch_end(self, data: Data):
        """Record this batch's ground-truth and predicted boxes, then compute IoUs and per-image evaluations.

        Boxes are appended to `self.gt` / `self.det` keyed by (epoch-wide image id, label). IoUs are computed for
        every (image id in `self.batch_image_ids`, category) pair into `self.ious`, and `self.evaluate_img` is run
        for every (category, image id) pair with results stored in `self.evalimgs`.

        Args:
            data: The batch data; `self.pred_key` holds detections and `self.true_key` holds ground-truth boxes.
        """
        # begin of reading det and gt
        pred = to_number(data[self.pred_key])  # pred is [batch, nms_max_outputs, 7]
        # After reshaping, each row unpacks as (idx_in_batch, x1, y1, w, h, label, score) - see the loop below
        pred = self._reshape_pred(pred)

        gt = to_number(data[self.true_key])  # gt is np.array (batch, box, 5), box dimension is padded
        # After reshaping, each row unpacks as (idx_in_batch, x1, y1, w, h, label)
        gt = self._reshape_gt(gt)

        ground_truth_bb = []
        for gt_item in gt:
            idx_in_batch, x1, y1, w, h, label = gt_item
            label = int(label)
            # Map the in-batch index to an epoch-wide image id. Presumably this also fills
            # self.ids_batch_to_epoch, which the prediction loop below reads - confirm in _get_id_in_epoch
            id_epoch = self._get_id_in_epoch(idx_in_batch)
            # NOTE(review): one entry is appended per gt box, so an image with several boxes appears several
            # times in batch_image_ids and is recomputed redundantly in the IoU / evaluate_img passes below
            self.batch_image_ids.append(id_epoch)
            self.image_ids.append(id_epoch)
            tmp_dict = {'idx': id_epoch, 'x1': x1, 'y1': y1, 'w': w, 'h': h, 'label': label}
            ground_truth_bb.append(tmp_dict)

        predicted_bb = []
        for pred_item in pred:
            idx_in_batch, x1, y1, w, h, label, score = pred_item
            label = int(label)
            # NOTE(review): for an image with detections but no gt boxes, this lookup may raise KeyError unless
            # ids_batch_to_epoch is populated elsewhere - confirm
            id_epoch = self.ids_batch_to_epoch[idx_in_batch]
            self.image_ids.append(id_epoch)
            tmp_dict = {'idx': id_epoch, 'x1': x1, 'y1': y1, 'w': w, 'h': h, 'label': label, 'score': score}
            predicted_bb.append(tmp_dict)

        # self.gt / self.det are presumably defaultdict(list) keyed by (image id, label) - verify at initialization
        for dict_elem in ground_truth_bb:
            self.gt[dict_elem['idx'], dict_elem['label']].append(dict_elem)

        for dict_elem in predicted_bb:
            self.det[dict_elem['idx'], dict_elem['label']].append(dict_elem)
        # end of reading det and gt

        # compute iou matrix, matrix index is (img_id, cat_id), each element in matrix has shape (num_det, num_gt)
        # self.ious is replaced each batch, so it only covers the images seen in this batch
        self.ious = {(img_id, cat_id): self.compute_iou(self.det[img_id, cat_id], self.gt[img_id, cat_id])
                     for img_id in self.batch_image_ids for cat_id in self.categories}

        for cat_id in self.categories:
            for img_id in self.batch_image_ids:
                self.evalimgs[(cat_id, img_id)] = self.evaluate_img(cat_id, img_id)
 def on_epoch_end(self, data: Data) -> None:
     """Fit a Platt-binner marginal calibrator on the collected data and publish the calibrated predictions.

     `self.y_true` / `self.y_pred` are stacked into arrays, replacing the lists. The fitted calibrator's
     `calibrate` function is dill-dumped to `self.save_path` when a path is set and either no `save_key` is set
     or `data[self.save_key]` equals 0. Calibrated predictions are written to `self.outputs[0]` without logging.
     """
     # NOTE(review): the accumulators become ndarrays here; presumably they are reset to lists before the next
     # epoch collects into them - confirm
     self.y_true = np.squeeze(np.stack(self.y_true))
     self.y_pred = np.stack(self.y_pred)
     n_calibration = len(self.y_true)
     calibrator = cal.PlattBinnerMarginalCalibrator(num_calibration=n_calibration, num_bins=10)
     calibrator.train_calibration(probs=self.y_pred, labels=self.y_true)
     if self.save_path and (not self.save_key or to_number(data[self.save_key]) == 0):
         with open(self.save_path, 'wb') as f:
             dill.dump(calibrator.calibrate, file=f)
         print(f"FastEstimator-PBMCalibrator: Calibrator written to {self.save_path}")
     calibrated = calibrator.calibrate(self.y_pred)
     data.write_without_log(self.outputs[0], calibrated)
Beispiel #18
0
    def write_images(self, mode: str, images: Iterable[Tuple[str, Any]], step: int) -> None:
        """Write images or figures to the TensorBoard writer for `mode`.

        `ImgData` values are first painted into a figure. Matplotlib figures go through `add_figure`; anything
        else is converted with `to_number` and sent to `add_images`, as 'NCHW' when it was a torch tensor and
        'NHWC' otherwise.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            images: Pairs such as [("key", image1), ("key2", image2), ...].
            step: The current training step.
        """
        for tag, image in images:
            if isinstance(image, ImgData):
                image = image.paint_figure()
            writer = self.summary_writers[mode]
            if not isinstance(image, plt.Figure):
                layout = 'NCHW' if isinstance(image, torch.Tensor) else 'NHWC'
                writer.add_images(tag=tag, img_tensor=to_number(image), global_step=step, dataformats=layout)
            else:
                writer.add_figure(tag=tag, figure=image, global_step=step)
Beispiel #19
0
    def _print_message(self,
                       header: str,
                       data: Data,
                       log_epoch: bool = False) -> None:
        """Print the logged values in `data` and record each one in the `system` summary.

        Each value is converted with `to_number`. Single-element values are printed inline as "key: value; ";
        multi-element arrays are printed on their own lines in place.

        Args:
            header: Text placed at the start of the printed message.
            data: The collection whose logged entries are printed and recorded.
            log_epoch: If True, also print and record the current epoch index.
        """
        pieces = [header]
        if log_epoch:
            pieces.append("epoch: {}; ".format(self.system.epoch_idx))
            self.system.write_summary('epoch', self.system.epoch_idx)
        for key, raw in data.read_logs().items():
            val = to_number(raw)
            self.system.write_summary(key, val)
            if val.size <= 1:
                pieces.append("{}: {}; ".format(key, str(val)))
            else:
                pieces.append("\n{}:\n{};".format(key, np.array2string(val, separator=',')))
        print("".join(pieces))
Beispiel #20
0
    def on_epoch_end(self, data: Data) -> None:
        """Turn the gradients collected this epoch into heatmap overlays and write them to `data`.

        Only the first `self.n_samples` collected instances are kept (all `self.n_found` when `n_samples` is
        falsy). Gradients are averaged over axis 1 (TF gradients are first moved to channel-first), clipped at 0,
        resized to the image size, min-max normalized, and colorized with the JET colormap. They are then added
        onto the images and rescaled by the overlay maximum. True labels and predictions are included when
        collected: argmaxed if they have more than one dimension, then optionally mapped through
        `self.label_mapping`. The result is written to `self.outputs[0]` as an `ImgData` without logging, and the
        collected state is cleared with `self._reset()`.

        Args:
            data: The epoch-end data dictionary to write the visualization into.
        """
        # Keep only the user-specified number of samples
        n_keep = self.n_samples or self.n_found
        images = concat(self.images)[:n_keep]
        _, height, width = get_image_dims(images)
        grads = to_number(concat(self.grads)[:n_keep])
        if tf.is_tensor(images):
            grads = np.moveaxis(grads, source=-1, destination=1)  # grads should be channel first

        def _to_display_labels(values):
            # Collapse class scores to indices, then optionally replace indices with their mapped names
            if len(values.shape) > 1:
                values = argmax(values, axis=-1)
            if self.label_mapping:
                values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
            return values

        args = {}
        labels = None if not self.labels else concat(self.labels)[:n_keep]
        if labels is not None:
            args[self.true_label_key] = _to_display_labels(labels)
        preds = None if not self.preds else concat(self.preds)[:n_keep]
        if preds is not None:
            args[self.pred_label_key] = _to_display_labels(preds)
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
        components = [np.mean(grads, axis=1)]
        components = [np.maximum(component, 0) for component in components]
        masks = []
        for component_batch in components:
            img_batch = []
            for img in component_batch:
                img = cv2.resize(img, (width, height))
                img = img - np.min(img)
                peak = np.max(img)
                # A constant map (e.g. all zeros after the clipping above) would give 0 / 0 = NaN here, and casting
                # NaN to uint8 is undefined. Leave such maps at 0 instead of dividing.
                if peak > 0:
                    img = img / peak
                img = cv2.cvtColor(
                    cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET),
                    cv2.COLOR_BGR2RGB)
                img = np.float32(img) / 255
                img_batch.append(img)
            img_batch = np.array(img_batch, dtype=np.float32)
            # Switch to channel first for pytorch
            if isinstance(images, torch.Tensor):
                img_batch = np.moveaxis(img_batch, source=-1, destination=1)
            masks.append(img_batch)

        components = [
            images + mask for mask in masks
        ]  # This seems to work even if the image is 1 channel instead of 3
        components = [image / reduce_max(image) for image in components]

        # There is exactly one component today; if that changes, the last one wins
        for elem in components:
            args[self.grad_key] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
Beispiel #21
0
 def on_batch_end(self, data: Data) -> None:
     """Append this batch's (label values, metric values) pair, both converted with `to_number`, to `self.points`."""
     label_vals = to_number(data[self.label_key])
     metric_vals = to_number(data[self.metric_key])
     self.points.append((label_vals, metric_vals))
 def on_batch_end(self, data: Data) -> None:
     """Append this batch's (index values, metric values) to `self.points` on epochs divisible by `epoch_frequency`."""
     if self.system.epoch_idx % self.epoch_frequency != 0:
         return
     index_vals = to_number(data[self.index_key])
     metric_vals = to_number(data[self.metric_key])
     self.points.append((index_vals, metric_vals))
Beispiel #23
0
    def on_epoch_end(self, data: Data) -> None:
        """Compute saliency masks for the samples collected in the current mode and write them to `data`.

        The collected samples are consolidated first: the first `n_required` instances when a count was
        requested, otherwise the first collected batch. The collection counters are then reset, even if fewer than
        `n_required` samples were found.

        Plain masks are always computed. Smoothed, integrated, and smoothed-integrated ("SmInt") masks are added
        when `self.smoothing` / `self.integrating` are set. Every sample entry other than `self.class_key` is also
        blended with the first non-empty of (SmInt, integrated, smoothed, plain) masks. Everything is written under
        `self.outputs[0]` as an `ImgData` (inferno colormap) without logging.

        Args:
            data: The epoch-end data dictionary to write the result into.
        """
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {
                    key: val[0]
                    for key, val in self.samples[mode].items()
                }
            # even if you haven't found n_required samples, you're at end of epoch so no point trying to collect more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        masks = self.salnet.get_masks(self.samples[mode])
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(self.samples[mode],
                                                      nsamples=self.smoothing)
        if self.integrating:
            # `integrating` is either an (n_integration, n_smoothing) pair or a single value that is used as
            # n_integration while n_smoothing falls back to `self.smoothing`
            if isinstance(self.integrating, tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(
                self.samples[mode], nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(
                    self.samples[mode],
                    nsamples=n_smoothing,
                    nintegration=n_integration)

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = self.samples[mode][self.class_key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[self.class_key] = classes
        for key in self.model_outputs:
            values = masks[key]
            if self.label_mapping:
                values = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(values))
                ])
            args[key] = values
        sal = smint or integrated or smoothed or masks
        for key, val in self.samples[mode].items():
            # Compare by value: `is not` only tests identity, which equal strings are not guaranteed to share
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two in
                # order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    args["{} {}".format(
                        key, outkey)] = (0.3 * (sal[outkey] *
                                                (val - min_val) + min_val) +
                                         0.3 * val + 0.4 * sal[outkey] * diff +
                                         min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
Beispiel #24
0
    def on_epoch_end(self, data: Data) -> None:
        """Turn the activations collected this epoch into per-component heatmap overlays and write them to `data`.

        Only the first `self.n_samples` collected instances are kept (all `self.n_found` when `n_samples` is
        falsy). TF activations are first moved to channel-first. Activations are projected with
        `self._project_2d`. For each component, every image's 2-D map is resized, min-max normalized, colorized
        with the JET colormap, added onto the image, and rescaled by the overlay maximum. Images without that
        component get an all-ones placeholder. True labels and predictions are included when collected: argmaxed
        if they have more than one dimension, then optionally mapped through `self.label_mapping`. The result is
        written to `self.outputs[0]` as an `ImgData` without logging; the collected state is cleared with
        `self._reset()`.

        Args:
            data: The epoch-end data dictionary to write the visualization into.
        """
        # Keep only the user-specified number of samples
        n_keep = self.n_samples or self.n_found
        images = concat(self.images)[:n_keep]
        _, height, width = get_image_dims(images)
        activations = to_number(concat(self.activations)[:n_keep])
        if tf.is_tensor(images):
            activations = np.moveaxis(
                activations, source=-1,
                destination=1)  # Activations should be channel first

        def _to_display_labels(values):
            # Collapse class scores to indices, then optionally replace indices with their mapped names
            if len(values.shape) > 1:
                values = argmax(values, axis=-1)
            if self.label_mapping:
                values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
            return values

        args = {}
        labels = None if not self.labels else concat(self.labels)[:n_keep]
        if labels is not None:
            args[self.true_label_key] = _to_display_labels(labels)
        preds = None if not self.preds else concat(self.preds)[:n_keep]
        if preds is not None:
            args[self.pred_label_key] = _to_display_labels(preds)
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        n_components, batch_component_image = self._project_2d(activations)
        components = []  # component x image (batch x image)
        for component_idx in range(n_components):
            batch = []
            for base_image, component_image in zip(images,
                                                   batch_component_image):
                if len(component_image) > component_idx:
                    mask = component_image[component_idx]
                    mask = cv2.resize(mask, (width, height))
                    mask = mask - np.min(mask)
                    peak = np.max(mask)
                    # A constant map would give 0 / 0 = NaN here, and casting NaN to uint8 is undefined.
                    # Leave such maps at 0 instead of dividing.
                    if peak > 0:
                        mask = mask / peak
                    mask = cv2.cvtColor(
                        cv2.applyColorMap(np.uint8(255 * mask),
                                          cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                    mask = np.float32(mask) / 255
                    # switch to channel first for pytorch. `mask` is a single (H, W, 3) image here (no batch axis),
                    # so the channel axis must move to position 0 to match a (C, H, W) base image.
                    if isinstance(base_image, torch.Tensor):
                        mask = np.moveaxis(mask, source=-1, destination=0)
                    new_image = base_image + mask
                    new_image = new_image / reduce_max(new_image)
                else:
                    # There's no component for this image, so display an empty image here
                    new_image = np.ones_like(base_image)
                batch.append(new_image)
            components.append(np.array(batch, dtype=np.float32))

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)