def on_end(self, data: Data) -> None:
    """Collect the score histories of selected indices into per-index ``Summary`` objects and publish them.

    For every mode in ``self.mode``, the indices in ``self.index_history[mode]`` are ranked by the score of their
    last ``(step, score)`` entry. The following indices have their full history copied into
    ``summary.history[mode][self.metric_key][step]``:

    - the ``self.n_max_to_keep`` highest-scoring indices,
    - the ``self.n_min_to_keep`` lowest-scoring indices,
    - every index in ``self.idx_to_keep``.

    The summaries are registered with ``self.system`` (so traceability can draw them) and written to ``data``
    under ``self.outputs[0]``.

    Args:
        data: The data dictionary for this event; the summaries are written to it without logging.
    """
    index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
    for mode in self.mode:
        history = self.index_history[mode]
        # Each history value is a sequence of (step, score) pairs; rank indices by the score of their last entry
        final_scores = sorted(((idx, steps[-1][1]) for idx, steps in history.items()), key=lambda pair: pair[1])
        # Reverse so the highest scores come first; the leading n_max_to_keep entries are the top performers
        max_ids = {idx for idx, _ in final_scores[::-1][:self.n_max_to_keep]}
        min_ids = {idx for idx, _ in final_scores[:self.n_min_to_keep]}
        # Builtin set.union: typing.Set is a type-hint alias and should not be used as a runtime object
        target_ids = set.union(min_ids, max_ids, self.idx_to_keep)
        for idx in target_ids:
            for step, score in history[idx]:
                index_summaries[idx].history[mode][self.metric_key][step] = score
    # So traceability can draw it
    self.system.add_graph(self.outputs[0], list(index_summaries.values()))
    data.write_without_log(self.outputs[0], list(index_summaries.values()))
def on_epoch_end(self, data: Data) -> None:
    """Fit a Platt-binner marginal calibrator on this epoch's gathered outputs and emit calibrated predictions.

    The entries gathered in ``self.y_true`` / ``self.y_pred`` are stacked into arrays, with the labels also
    squeezed, and stored back on ``self``. A ``cal.PlattBinnerMarginalCalibrator`` with 10 bins is trained on
    them.

    When ``self.save_path`` is set, the calibrator's ``calibrate`` function is serialized there with ``dill``:
    always if there is no ``self.save_key``, otherwise only when ``data[self.save_key]`` converts to 0.

    Finally the calibrated predictions are written to ``data`` under ``self.outputs[0]``.

    Args:
        data: The data dictionary for this epoch.
    """
    self.y_true = np.squeeze(np.stack(self.y_true))
    self.y_pred = np.stack(self.y_pred)
    pbm = cal.PlattBinnerMarginalCalibrator(num_calibration=len(self.y_true), num_bins=10)
    pbm.train_calibration(probs=self.y_pred, labels=self.y_true)

    # `not key or (key and cond)` is logically the same as `not key or cond`
    if self.save_path and (not self.save_key or to_number(data[self.save_key]) == 0):
        with open(self.save_path, 'wb') as out_file:
            dill.dump(pbm.calibrate, file=out_file)
        print(f"FastEstimator-PBMCalibrator: Calibrator written to {self.save_path}")

    calibrated = pbm.calibrate(self.y_pred)
    data.write_without_log(self.outputs[0], calibrated)
def on_end(self, data: Data) -> None:
    """Register the accumulated label summaries with the system and write them to ``data``.

    Each consumer receives its own list built from ``self.label_summaries.values()``.

    Args:
        data: The data dictionary for this event; the summaries are written to it without logging.
    """
    target_key = self.outputs[0]
    # So traceability can draw it
    self.system.add_graph(target_key, list(self.label_summaries.values()))
    data.write_without_log(target_key, list(self.label_summaries.values()))
def on_epoch_end(self, data: Data) -> None:
    """Compute masks for the collected samples of the current mode and publish them as an ``ImgData``.

    If samples were found this epoch, the per-mode sample lists are collapsed into single values. Each key is
    truncated to ``n_required[mode]`` when that is positive; otherwise only the first entry is kept. The
    counters are then reset.

    Masks come from ``self.salnet``, plus smoothed / integrated / smoothed-integrated variants when
    ``self.smoothing`` / ``self.integrating`` are set. They are blended with each input image and the result is
    written to ``data`` under ``self.outputs[0]``.

    Args:
        data: The data dictionary for this epoch; the ``ImgData`` is written to it without logging.
    """
    mode = self.system.mode
    if self.n_found[mode] > 0:
        if self.n_required[mode] > 0:
            # We are keeping a user-specified number of samples
            self.samples[mode] = {
                key: concat(val)[:self.n_required[mode]]
                for key, val in self.samples[mode].items()
            }
        else:
            # We are keeping one batch of data
            self.samples[mode] = {key: val[0] for key, val in self.samples[mode].items()}
        # even if you haven't found n_required samples, you're at end of epoch so no point trying to collect more
        self.n_found[mode] = 0
        self.n_required[mode] = 0
    samples = self.samples[mode]

    def _map_labels(values):
        # Replace numeric class ids with their user-facing names when a label mapping was supplied
        if not self.label_mapping:
            return values
        return np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])

    masks = self.salnet.get_masks(samples)
    smoothed, integrated, smint = {}, {}, {}
    if self.smoothing:
        smoothed = self.salnet.get_smoothed_masks(samples, nsamples=self.smoothing)
    if self.integrating:
        # `integrating` is either an integration count, or an (n_integration, n_smoothing) pair
        if isinstance(self.integrating, tuple):
            n_integration, n_smoothing = self.integrating
        else:
            n_integration = self.integrating
            n_smoothing = self.smoothing
        integrated = self.salnet.get_integrated_masks(samples, nsamples=n_integration)
        if n_smoothing:
            smint = self.salnet.get_smoothed_masks(samples, nsamples=n_smoothing, nintegration=n_integration)

    # Arrange the outputs
    args = {}
    if self.class_key:
        args[self.class_key] = _map_labels(samples[self.class_key])
    for key in self.model_outputs:
        args[key] = _map_labels(masks[key])
    # Overlays prefer SmInt, then integrated, then smoothed, then the raw masks
    sal = smint or integrated or smoothed or masks
    for key, val in samples.items():
        # Compare by value: two equal key strings need not be the same object, so an identity test could treat
        # the class key as an image and overwrite its (mapped) labels
        if key == self.class_key:
            continue
        args[key] = val
        # Create a linear combination of the original image, the saliency mask, and the product of the two in
        # order to highlight regions of importance
        # NOTE(review): the weights sum to 1 but a further `+ min_val` is added; no effect when min_val is 0 —
        # confirm this offset is intended
        min_val = reduce_min(val)
        diff = reduce_max(val) - min_val
        for outkey in self.outputs:
            args[f"{key} {outkey}"] = (0.3 * (sal[outkey] * (val - min_val) + min_val) + 0.3 * val +
                                       0.4 * sal[outkey] * diff + min_val)
    for key in self.outputs:
        args[key] = masks[key]
        if smoothed:
            args[f"Smoothed {key}"] = smoothed[key]
        if integrated:
            args[f"Integrated {key}"] = integrated[key]
        if smint:
            args[f"SmInt {key}"] = smint[key]
    result = ImgData(colormap="inferno", **args)
    data.write_without_log(self.outputs[0], result)
def on_epoch_end(self, data: Data) -> None:
    """Overlay channel-averaged, rectified gradient heatmaps on the collected images and publish an ``ImgData``.

    The collected images, gradients, and optional true labels / predictions are truncated to
    ``self.n_samples``, or to ``self.n_found`` when ``n_samples`` is falsy. The gradients are then processed as
    follows:

    1. Average over the channel axis and clip at zero.
    2. Colorize with OpenCV's JET colormap.
    3. Add to the images and divide by the maximum.

    The collected buffers are cleared via ``self._reset()`` and the result is written to ``data`` under
    ``self.outputs[0]``.

    Args:
        data: The data dictionary for this epoch; the ``ImgData`` is written to it without logging.
    """
    # Keep only the user-specified number of samples
    n_keep = self.n_samples or self.n_found
    images = concat(self.images)[:n_keep]
    _, height, width = get_image_dims(images)
    grads = to_number(concat(self.grads)[:n_keep])
    if tf.is_tensor(images):
        grads = np.moveaxis(grads, source=-1, destination=1)  # grads should be channel first

    def _class_ids(collected):
        # Collapse vector-valued labels/predictions to class indices, then map them to names if requested
        values = concat(collected)[:n_keep]
        if len(values.shape) > 1:
            values = argmax(values, axis=-1)
        if self.label_mapping:
            values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
        return values

    def _heatmap(img):
        # Resize one 2D map to the image size and render it as an RGB float32 image scaled by 1/255
        img = cv2.resize(img, (width, height))
        img = img - np.min(img)
        peak = np.max(img)
        # A constant map (e.g. every gradient negative before the clip at zero) would otherwise give 0/0 = NaN,
        # and casting NaN to uint8 is undefined
        if peak > 0:
            img = img / peak
        img = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        return np.float32(img) / 255

    args = {}
    if self.labels:
        args[self.true_label_key] = _class_ids(self.labels)
    if self.preds:
        args[self.pred_label_key] = _class_ids(self.preds)
    args[self.image_key] = images
    # Clear memory
    self._reset()

    # Make the image
    # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
    components = [np.maximum(np.mean(grads, axis=1), 0)]
    masks = []
    for component_batch in components:
        img_batch = np.array([_heatmap(img) for img in component_batch], dtype=np.float32)
        # Switch to channel first for pytorch
        if isinstance(images, torch.Tensor):
            img_batch = np.moveaxis(img_batch, source=-1, destination=1)
        masks.append(img_batch)
    # This seems to work even if the image is 1 channel instead of 3
    overlays = [images + mask for mask in masks]
    overlays = [overlay / reduce_max(overlay) for overlay in overlays]
    if overlays:
        args[self.grad_key] = overlays[-1]
    result = ImgData(**args)
    data.write_without_log(self.outputs[0], result)
def on_epoch_end(self, data: Data) -> None:
    """Overlay per-component activation heatmaps on the collected images and publish an ``ImgData``.

    The collected images, activations, and optional true labels / predictions are truncated to
    ``self.n_samples``, or to ``self.n_found`` when ``n_samples`` is falsy. ``self._project_2d`` turns the
    activations into per-image component maps. Each map is colorized with OpenCV's JET colormap, added to its
    image, and divided by the maximum. Images lacking a given component get an all-ones placeholder. The buffers
    are cleared via ``self._reset()`` and the result is written to ``data`` under ``self.outputs[0]``.

    Args:
        data: The data dictionary for this epoch; the ``ImgData`` is written to it without logging.
    """
    # Keep only the user-specified number of samples
    n_keep = self.n_samples or self.n_found
    images = concat(self.images)[:n_keep]
    _, height, width = get_image_dims(images)
    activations = to_number(concat(self.activations)[:n_keep])
    if tf.is_tensor(images):
        activations = np.moveaxis(activations, source=-1, destination=1)  # Activations should be channel first

    def _class_ids(collected):
        # Collapse vector-valued labels/predictions to class indices, then map them to names if requested
        values = concat(collected)[:n_keep]
        if len(values.shape) > 1:
            values = argmax(values, axis=-1)
        if self.label_mapping:
            values = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(values))])
        return values

    def _heatmap(mask):
        # Resize one 2D component map to the image size and render it as an RGB float32 image scaled by 1/255
        mask = cv2.resize(mask, (width, height))
        mask = mask - np.min(mask)
        peak = np.max(mask)
        # A constant component map would otherwise give 0/0 = NaN, and casting NaN to uint8 is undefined
        if peak > 0:
            mask = mask / peak
        mask = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        return np.float32(mask) / 255

    args = {}
    if self.labels:
        args[self.true_label_key] = _class_ids(self.labels)
    if self.preds:
        args[self.pred_label_key] = _class_ids(self.preds)
    args[self.image_key] = images
    # Clear memory
    self._reset()

    # Make the image
    n_components, batch_component_image = self._project_2d(activations)
    components = []  # component x image (batch x image)
    for component_idx in range(n_components):
        batch = []
        for base_image, component_image in zip(images, batch_component_image):
            if len(component_image) <= component_idx:
                # There's no component for this image, so display an empty image here
                batch.append(np.ones_like(base_image))
                continue
            mask = _heatmap(component_image[component_idx])
            # Switch to channel first for pytorch. `base_image` is a single image taken from the batch, not a batch,
            # so the colour axis of the (H, W, 3) mask must move to position 0; moving it to 1 would give an
            # (H, 3, W) array that cannot broadcast against a channel-first image
            if isinstance(base_image, torch.Tensor):
                mask = np.moveaxis(mask, source=-1, destination=0)
            new_image = base_image + mask
            batch.append(new_image / reduce_max(new_image))
        components.append(np.array(batch, dtype=np.float32))
    for idx, elem in enumerate(components):
        args[f"Component {idx}"] = elem
    result = ImgData(**args)
    data.write_without_log(self.outputs[0], result)