Example #1
    def on_epoch_end(self, data: Data) -> None:
        # Write weight histograms (and optional visualizations) during training at the configured epoch frequency
        if self.system.mode == 'train' and self.histogram_freq.freq and not self.histogram_freq.is_step and \
                self.system.epoch_idx % self.histogram_freq.freq == 0:
            self.writer.write_weights(mode=self.system.mode,
                                      models=self.system.network.models,
                                      step=self.system.global_step,
                                      visualize=self.paint_weights)
        # Write out any embeddings which were aggregated over batches
        for name, val_list in self.collected_embeddings.items():
            embeddings = None if any(
                x[0] is None
                for x in val_list) else concat([x[0] for x in val_list])
            labels = None if any(
                x[1] is None
                for x in val_list) else concat([x[1] for x in val_list])
            imgs = None if any(
                x[2] is None
                for x in val_list) else concat([x[2] for x in val_list])
            self.writer.write_embeddings(mode=self.system.mode,
                                         step=self.system.global_step,
                                         embeddings=[(name, embeddings, labels,
                                                      imgs)])
        self.collected_embeddings.clear()
        # Get any embeddings which were generated externally on epoch end
        if self.embedding_freq.freq and (self.embedding_freq.is_step
                                         or self.system.epoch_idx %
                                         self.embedding_freq.freq == 0):
            self.writer.write_embeddings(
                mode=self.system.mode,
                step=self.system.global_step,
                embeddings=filter(
                    lambda x: x[1] is not None,
                    map(
                        lambda t:
                        (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])),
                        self.write_embeddings)))
        if self.update_freq.freq and (self.update_freq.is_step
                                      or self.system.epoch_idx %
                                      self.update_freq.freq == 0):
            self.writer.write_scalars(mode=self.system.mode,
                                      step=self.system.global_step,
                                      scalars=filter(lambda x: is_number(x[1]),
                                                     data.items()))
            self.writer.write_images(mode=self.system.mode,
                                     step=self.system.global_step,
                                     images=filter(
                                         lambda x: x[1] is not None,
                                         map(lambda y: (y, data.get(y)),
                                             self.write_images)))
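
Every write in the example above is gated on a small frequency object with freq and is_step fields. Below is a minimal standalone sketch of the epoch-based gate used for histogram_freq; the Freq dataclass and the helper name are made up for illustration and are not part of the library.

from dataclasses import dataclass


@dataclass
class Freq:
    freq: int = 0          # 0 disables writing entirely
    is_step: bool = False  # True means writes happen per step, not per epoch


def should_write_on_epoch_end(freq: Freq, epoch_idx: int) -> bool:
    # Write on epoch end only when a frequency is configured, it is epoch-based,
    # and the current epoch index is a multiple of that frequency
    return bool(freq.freq) and not freq.is_step and epoch_idx % freq.freq == 0


# Mirrors the histogram_freq condition in the example above
assert should_write_on_epoch_end(Freq(freq=2), epoch_idx=4)
assert not should_write_on_epoch_end(Freq(freq=2), epoch_idx=3)
assert not should_write_on_epoch_end(Freq(freq=2, is_step=True), epoch_idx=4)

Note that the embedding_freq and update_freq checks in the example flip the logic (is_step or epoch_idx % freq == 0), so step-based settings still flush their data on epoch end.
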
Example #2
    def on_epoch_end(self, data: Data) -> None:
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {
                    key: val[0]
                    for key, val in self.samples[mode].items()
                }
            # Even if n_required samples weren't found, this is the end of the epoch, so there's no point collecting more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        masks = self.salnet.get_masks(self.samples[mode])
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(self.samples[mode],
                                                      nsamples=self.smoothing)
        if self.integrating:
            if isinstance(self.integrating, Tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(
                self.samples[mode], nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(
                    self.samples[mode],
                    nsamples=n_smoothing,
                    nintegration=n_integration)

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = self.samples[mode][self.class_key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[self.class_key] = classes
        for key in self.model_outputs:
            classes = masks[key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[key] = classes
        sal = smint or integrated or smoothed or masks
        for key, val in self.samples[mode].items():
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two in
                # order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    args["{} {}".format(
                        key, outkey)] = (0.3 * (sal[outkey] *
                                                (val - min_val) + min_val) +
                                         0.3 * val + 0.4 * sal[outkey] * diff +
                                         min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
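
The interesting arithmetic in the middle of this example is the overlay that blends the original image, the saliency mask, and their product so that important regions stand out. The snippet below reproduces just that linear combination with NumPy on made-up arrays; the 0.3/0.3/0.4 weights are the ones used in the code above.

import numpy as np

rng = np.random.default_rng(seed=0)
val = rng.random((32, 32, 3)).astype(np.float32)  # stand-in input image
sal = rng.random((32, 32, 1)).astype(np.float32)  # stand-in saliency mask in [0, 1]

min_val = val.min()
diff = val.max() - min_val

# Weighted mix of (mask * image), the image itself, and the mask scaled to the image's range
overlay = (0.3 * (sal * (val - min_val) + min_val)
           + 0.3 * val
           + 0.4 * sal * diff
           + min_val)
print(overlay.shape, float(overlay.min()), float(overlay.max()))
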
Example #3
    def on_epoch_end(self, data: Data) -> None:
        # Keep only the user-specified number of samples
        images = concat(self.images)[:self.n_samples or self.n_found]
        _, height, width = get_image_dims(images)
        grads = to_number(concat(self.grads)[:self.n_samples or self.n_found])
        if tf.is_tensor(images):
            grads = np.moveaxis(grads, source=-1,
                                destination=1)  # grads should be channel first
        args = {}
        labels = None if not self.labels else concat(
            self.labels)[:self.n_samples or self.n_found]
        if labels is not None:
            if len(labels.shape) > 1:
                labels = argmax(labels, axis=-1)
            if self.label_mapping:
                labels = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(labels))
                ])
            args[self.true_label_key] = labels
        preds = None if not self.preds else concat(
            self.preds)[:self.n_samples or self.n_found]
        if preds is not None:
            if len(preds.shape) > 1:
                preds = argmax(preds, axis=-1)
            if self.label_mapping:
                preds = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(preds))
                ])
            args[self.pred_label_key] = preds
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
        components = [np.mean(grads, axis=1)]
        components = [np.maximum(component, 0) for component in components]
        masks = []
        for component_batch in components:
            img_batch = []
            for img in component_batch:
                img = cv2.resize(img, (width, height))
                img = img - np.min(img)
                img = img / np.max(img)
                img = cv2.cvtColor(
                    cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET),
                    cv2.COLOR_BGR2RGB)
                img = np.float32(img) / 255
                img_batch.append(img)
            img_batch = np.array(img_batch, dtype=np.float32)
            # Switch to channel first for pytorch
            if isinstance(images, torch.Tensor):
                img_batch = np.moveaxis(img_batch, source=-1, destination=1)
            masks.append(img_batch)

        components = [
            images + mask for mask in masks
        ]  # This seems to work even if the image is 1 channel instead of 3
        components = [image / reduce_max(image) for image in components]

        for elem in components:
            args[self.grad_key] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
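
The loop in the middle of this example converts a coarse gradient map into a colored heatmap and lays it over the input. The sketch below isolates that resize / normalize / colorize / blend sequence on dummy channel-last data (OpenCV expects channel-last); it assumes the map is not constant, since the normalization divides by its maximum.

import cv2
import numpy as np

height, width = 64, 64
image = np.random.rand(height, width, 3).astype(np.float32)  # stand-in input in [0, 1]
grad = np.random.rand(8, 8).astype(np.float32)               # coarse gradient map

heat = cv2.resize(grad, (width, height))
heat = heat - np.min(heat)
heat = heat / np.max(heat)
heat = cv2.applyColorMap(np.uint8(255 * heat), cv2.COLORMAP_JET)  # uint8 BGR heatmap
heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)
heat = np.float32(heat) / 255

overlay = image + heat             # additive blend, as in the example
overlay = overlay / overlay.max()  # rescale back into [0, 1]
print(overlay.shape)
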
Example #4
    def on_epoch_end(self, data: Data) -> None:
        # Keep only the user-specified number of samples
        images = concat(self.images)[:self.n_samples or self.n_found]
        _, height, width = get_image_dims(images)
        activations = to_number(
            concat(self.activations)[:self.n_samples or self.n_found])
        if tf.is_tensor(images):
            activations = np.moveaxis(
                activations, source=-1,
                destination=1)  # Activations should be channel first
        args = {}
        labels = None if not self.labels else concat(
            self.labels)[:self.n_samples or self.n_found]
        if labels is not None:
            if len(labels.shape) > 1:
                labels = argmax(labels, axis=-1)
            if self.label_mapping:
                labels = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(labels))
                ])
            args[self.true_label_key] = labels
        preds = None if not self.preds else concat(
            self.preds)[:self.n_samples or self.n_found]
        if preds is not None:
            if len(preds.shape) > 1:
                preds = argmax(preds, axis=-1)
            if self.label_mapping:
                preds = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(preds))
                ])
            args[self.pred_label_key] = preds
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        n_components, batch_component_image = self._project_2d(activations)
        components = []  # component x image (batch x image)
        for component_idx in range(n_components):
            batch = []
            for base_image, component_image in zip(images,
                                                   batch_component_image):
                if len(component_image) > component_idx:
                    mask = component_image[component_idx]
                    mask = cv2.resize(mask, (width, height))
                    mask = mask - np.min(mask)
                    mask = mask / np.max(mask)
                    mask = cv2.cvtColor(
                        cv2.applyColorMap(np.uint8(255 * mask),
                                          cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                    mask = np.float32(mask) / 255
                    # switch to channel first for pytorch
                    if isinstance(base_image, torch.Tensor):
                        mask = np.moveaxis(mask, source=-1, destination=1)
                    new_image = base_image + mask
                    new_image = new_image / reduce_max(new_image)
                else:
                    # There's no component for this image, so display an empty image here
                    new_image = np.ones_like(base_image)
                batch.append(new_image)
            components.append(np.array(batch, dtype=np.float32))

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
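
The _project_2d helper used above is not shown in this listing. One common way to reduce a channel-first activation map to a few 2D component maps is a per-image SVD of the flattened feature map (the Eigen-CAM idea); the sketch below illustrates only that approach and is not the actual helper, whose behavior and signature may differ.

import numpy as np


def project_2d(activations: np.ndarray, n_components: int = 3):
    # activations: (batch, channels, height, width), channel-first as in the example
    _, channels, height, width = activations.shape
    per_image_components = []
    for act in activations:
        flat = act.reshape(channels, height * width)
        flat = flat - flat.mean(axis=1, keepdims=True)
        # Right singular vectors are spatial patterns ordered by explained variance
        _, _, vt = np.linalg.svd(flat, full_matrices=False)
        comps = vt[:n_components].reshape(-1, height, width)
        per_image_components.append(comps)
    return n_components, per_image_components


# Usage with dummy activations
acts = np.random.rand(2, 16, 7, 7).astype(np.float32)
n_comp, batch_component_image = project_2d(acts)
print(n_comp, batch_component_image[0].shape)
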