예제 #1
0
 def on_end(self, data: Data) -> None:
     """Gather score histories for the extreme and explicitly requested indices.

     For each mode in ``self.mode``, the indices tracked in ``self.index_history[mode]`` are ranked by the score of
     their last recorded ``(step, score)`` entry. The following indices each get their full history copied into a
     per-index ``Summary`` under ``self.metric_key``:

     - the ``self.n_max_to_keep`` highest-ranked indices,
     - the ``self.n_min_to_keep`` lowest-ranked indices,
     - every index in ``self.idx_to_keep``.

     The summaries are passed to ``self.system.add_graph`` and written, without logging, to ``self.outputs[0]``.

     Args:
         data: The data dictionary that receives the list of summaries.
     """
     index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
     for mode in self.mode:
         history = self.index_history[mode]
         # (idx, final score) pairs, sorted ascending by the last recorded score of each index
         final_scores = sorted([(idx, elem[-1][1]) for idx, elem in history.items()], key=lambda x: x[1])
         # Walk backwards from the end to take the n_max_to_keep highest scores. This slice is empty when
         # n_max_to_keep is 0 and returns everything when it exceeds the number of tracked indices.
         max_idx_list = {idx for idx, _ in final_scores[-1:-self.n_max_to_keep - 1:-1]}
         min_idx_list = {idx for idx, _ in final_scores[:self.n_min_to_keep]}
         # Use the builtin set directly instead of relying on typing.Set forwarding attribute access to it
         target_idx_list = set.union(min_idx_list, max_idx_list, self.idx_to_keep)
         for idx in target_idx_list:
             for step, score in history[idx]:
                 index_summaries[idx].history[mode][self.metric_key][step] = score
     self.system.add_graph(self.outputs[0], list(index_summaries.values()))  # So traceability can draw it
     data.write_without_log(self.outputs[0], list(index_summaries.values()))
예제 #2
0
 def on_epoch_end(self, data: Data) -> None:
     """Fit a marginal Platt-binner calibrator on this epoch's predictions and emit the calibrated output.

     ``self.y_true`` is stacked and squeezed, and ``self.y_pred`` is stacked. Both results are stored back on the
     instance. A ``cal.PlattBinnerMarginalCalibrator`` with 10 bins is trained on these arrays.

     When ``self.save_path`` is set, the calibrator's ``calibrate`` function is serialized there with ``dill``. If a
     ``self.save_key`` is configured, saving happens only when ``data[self.save_key]`` is zero.

     Finally, the calibrated predictions are written, without logging, to ``self.outputs[0]``.

     Args:
         data: The data dictionary for the current epoch; also consulted for ``self.save_key`` when saving.
     """
     self.y_true = np.squeeze(np.stack(self.y_true))
     self.y_pred = np.stack(self.y_pred)
     labels, probs = self.y_true, self.y_pred

     calibrator = cal.PlattBinnerMarginalCalibrator(num_calibration=len(labels), num_bins=10)
     calibrator.train_calibration(probs=probs, labels=labels)

     # A configured save_key gates saving: write only when its value in `data` is zero
     should_save = bool(self.save_path) and (not self.save_key or to_number(data[self.save_key]) == 0)
     if should_save:
         with open(self.save_path, 'wb') as handle:
             dill.dump(calibrator.calibrate, file=handle)
         print(f"FastEstimator-PBMCalibrator: Calibrator written to {self.save_path}")

     data.write_without_log(self.outputs[0], calibrator.calibrate(probs))
예제 #3
0
 def on_end(self, data: Data) -> None:
     """Publish the accumulated label summaries.

     The summaries are registered through ``self.system.add_graph`` and then written, without logging, to
     ``self.outputs[0]``. Each of the two calls receives its own freshly built list.

     Args:
         data: The data dictionary that receives the list of summaries.
     """
     target = self.outputs[0]
     summaries = self.label_summaries
     self.system.add_graph(target, list(summaries.values()))  # So traceability can draw it
     data.write_without_log(target, list(summaries.values()))
예제 #4
0
    def on_epoch_end(self, data: Data) -> None:
        """Compute saliency maps for the samples collected in the current mode and write them as an ``ImgData``.

        The stored samples for ``self.system.mode`` are first finalized. If ``self.n_required[mode]`` is positive,
        the concatenated samples are truncated to that many rows; otherwise only the first stored batch is kept.

        Plain masks come from ``self.salnet.get_masks``. Smoothed, integrated, and smoothed-integrated variants are
        added when ``self.smoothing`` / ``self.integrating`` are set. ``self.integrating`` may be a single count or an
        ``(n_integration, n_smoothing)`` tuple.

        Every sample entry other than ``self.class_key`` also gets one overlay per output key. The overlay blends the
        input with the most refined saliency map that was computed. The resulting ``ImgData`` (colormap "inferno")
        is written, without logging, to ``self.outputs[0]``.

        Args:
            data: The data dictionary that receives the resulting ``ImgData``.
        """
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {key: val[0] for key, val in self.samples[mode].items()}
            # Even if n_required samples were not found, it is the end of the epoch so there is no point collecting more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        samples = self.samples[mode]
        masks = self.salnet.get_masks(samples)
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(samples, nsamples=self.smoothing)
        if self.integrating:
            # Builtin tuple rather than typing.Tuple: isinstance checks against typing aliases are deprecated
            if isinstance(self.integrating, tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(samples, nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(samples, nsamples=n_smoothing, nintegration=n_integration)

        def _map_labels(values):
            # Replace each class index with its entry in self.label_mapping
            return np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = samples[self.class_key]
            args[self.class_key] = _map_labels(classes) if self.label_mapping else classes
        for key in self.model_outputs:
            classes = masks[key]
            args[key] = _map_labels(classes) if self.label_mapping else classes
        # Prefer the most refined saliency map that was computed
        sal = smint or integrated or smoothed or masks
        for key, val in samples.items():
            # Compare by value: `is not` on strings depends on interning and can treat an equal key as different
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two in
                # order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    sal_map = sal[outkey]
                    args["{} {}".format(key, outkey)] = (0.3 * (sal_map * (val - min_val) + min_val) + 0.3 * val +
                                                      0.4 * sal_map * diff + min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
예제 #5
0
    def on_epoch_end(self, data: Data) -> None:
        """Build gradient heatmap overlays for the samples collected this epoch and write them as an ``ImgData``.

        The stored images, gradients, and (optionally) true labels and predictions are concatenated and truncated.
        They keep ``self.n_samples`` rows, or ``self.n_found`` rows when ``self.n_samples`` is falsy.

        Producing the overlay involves these steps, in order:

        - average the channel-first gradients over the channel axis and clip negatives to zero;
        - resize each map to the image size, min-max normalize it, and colour it with the JET colormap;
        - add the coloured maps onto the images and renormalize by the maximum.

        Internal buffers are cleared via ``self._reset()`` before the overlay is built.

        Args:
            data: The data dictionary that receives the resulting ``ImgData``.
        """
        n_keep = self.n_samples or self.n_found

        def _to_classes(values):
            # Reduce multi-dimensional values with argmax over the last axis (presumably one-hot labels or class
            # scores), then optionally translate the indices through self.label_mapping
            if len(values.shape) > 1:
                values = argmax(values, axis=-1)
            if self.label_mapping:
                values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
            return values

        # Keep only the user-specified number of samples
        images = concat(self.images)[:n_keep]
        _, height, width = get_image_dims(images)
        grads = to_number(concat(self.grads)[:n_keep])
        if tf.is_tensor(images):
            grads = np.moveaxis(grads, source=-1, destination=1)  # grads should be channel first
        args = {}
        labels = None if not self.labels else concat(self.labels)[:n_keep]
        if labels is not None:
            args[self.true_label_key] = _to_classes(labels)
        preds = None if not self.preds else concat(self.preds)[:n_keep]
        if preds is not None:
            args[self.pred_label_key] = _to_classes(preds)
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
        components = [np.maximum(np.mean(grads, axis=1), 0)]
        masks = []
        for component_batch in components:
            img_batch = []
            for img in component_batch:
                img = cv2.resize(img, (width, height))
                img = img - np.min(img)
                peak = np.max(img)
                # A flat heatmap (e.g. every gradient clipped to 0) would otherwise produce 0/0 = NaN here
                if peak > 0:
                    img = img / peak
                img = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                img_batch.append(np.float32(img) / 255)
            img_batch = np.array(img_batch, dtype=np.float32)
            # Switch to channel first for pytorch
            if isinstance(images, torch.Tensor):
                img_batch = np.moveaxis(img_batch, source=-1, destination=1)
            masks.append(img_batch)

        # This seems to work even if the image is 1 channel instead of 3
        components = [images + mask for mask in masks]
        components = [image / reduce_max(image) for image in components]

        # Only one gradient component is produced (see TODO above), so it becomes the heatmap overlay
        args[self.grad_key] = components[-1]

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
예제 #6
0
    def on_epoch_end(self, data: Data) -> None:
        """Build one overlay per projected activation component and write them all as an ``ImgData``.

        The stored images, activations, and (optionally) true labels and predictions are concatenated and truncated.
        They keep ``self.n_samples`` rows, or ``self.n_found`` rows when ``self.n_samples`` is falsy.

        The channel-first activations are projected with ``self._project_2d``. For each component index, every image
        that has that component gets it resized, min-max normalized, coloured with the JET colormap, added onto the
        image, and renormalized. Images without that component get an all-ones placeholder. Component ``i`` is
        stored under ``"Component i"``.

        Args:
            data: The data dictionary that receives the resulting ``ImgData``.
        """
        n_keep = self.n_samples or self.n_found

        def _to_classes(values):
            # Reduce multi-dimensional values with argmax over the last axis (presumably one-hot labels or class
            # scores), then optionally translate the indices through self.label_mapping
            if len(values.shape) > 1:
                values = argmax(values, axis=-1)
            if self.label_mapping:
                values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
            return values

        # Keep only the user-specified number of samples
        images = concat(self.images)[:n_keep]
        _, height, width = get_image_dims(images)
        activations = to_number(concat(self.activations)[:n_keep])
        if tf.is_tensor(images):
            activations = np.moveaxis(activations, source=-1, destination=1)  # Activations should be channel first
        args = {}
        labels = None if not self.labels else concat(self.labels)[:n_keep]
        if labels is not None:
            args[self.true_label_key] = _to_classes(labels)
        preds = None if not self.preds else concat(self.preds)[:n_keep]
        if preds is not None:
            args[self.pred_label_key] = _to_classes(preds)
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        n_components, batch_component_image = self._project_2d(activations)
        components = []  # component x image (batch x image)
        for component_idx in range(n_components):
            batch = []
            for base_image, component_image in zip(images, batch_component_image):
                if len(component_image) > component_idx:
                    mask = cv2.resize(component_image[component_idx], (width, height))
                    mask = mask - np.min(mask)
                    peak = np.max(mask)
                    # A constant component would otherwise produce 0/0 = NaN here
                    if peak > 0:
                        mask = mask / peak
                    mask = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                    mask = np.float32(mask) / 255
                    if isinstance(base_image, torch.Tensor):
                        # `mask` is a single (H, W, 3) image, so channel first means moving the channel axis to 0.
                        # destination=1 would give (H, 3, W), which does not match a channel-first image.
                        mask = np.moveaxis(mask, source=-1, destination=0)
                    new_image = base_image + mask
                    new_image = new_image / reduce_max(new_image)
                else:
                    # There's no component for this image, so display an empty image here
                    new_image = np.ones_like(base_image)
                batch.append(new_image)
            components.append(np.array(batch, dtype=np.float32))

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)