Example #1
    def on_batch_end(self, data: Data) -> None:
        y_true, y_pred = to_number(data[self.true_key]), to_number(data[self.pred_key])
        if y_true.shape[-1] > 1 and y_true.ndim > 1:
            y_true = np.argmax(y_true, axis=-1)
        if y_pred.shape[-1] > 1:
            y_pred = np.argmax(y_pred, axis=-1)
        else:
            y_pred = np.round(y_pred)
        assert y_pred.size == y_true.size
        self.correct += np.sum(y_pred.ravel() == y_true.ravel())
        self.total += len(y_pred.ravel())
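
The same accumulation logic can be checked on a toy batch with plain NumPy, outside of any trace (all values below are illustrative):

    import numpy as np

    # One-hot ground truth (batch of 3, 2 classes) and class probabilities
    y_true = np.array([[1, 0], [0, 1], [0, 1]])
    y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])

    y_true = np.argmax(y_true, axis=-1)  # -> [0, 1, 1]
    y_pred = np.argmax(y_pred, axis=-1)  # -> [0, 1, 0]

    correct = np.sum(y_pred.ravel() == y_true.ravel())  # 2
    total = len(y_pred.ravel())                         # 3
    print(correct / total)                              # 0.666...
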
Example #2
    def on_batch_end(self, data: Data) -> None:
        y_true, y_pred = to_number(data[self.true_key]), to_number(data[self.pred_key])
        self.binary_classification = y_pred.shape[-1] == 1
        if y_true.shape[-1] > 1 and y_true.ndim > 1:
            y_true = np.argmax(y_true, axis=-1)
        if y_pred.shape[-1] > 1:
            y_pred = np.argmax(y_pred, axis=-1)
        else:
            y_pred = np.round(y_pred)
        assert y_pred.size == y_true.size
        self.y_pred.extend(y_pred.ravel())
        self.y_true.extend(y_true.ravel())
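
Unlike Example #1, this variant stores every flattened prediction and label so an epoch-end hook can hand them to a metric library in one call. A minimal sketch of what such an epoch-end step might look like, assuming scikit-learn is available (the choice of f1_score here is illustrative, not necessarily what this trace computes):

    import numpy as np
    from sklearn.metrics import f1_score

    y_true, y_pred = [], []
    # ...extended batch by batch, as in on_batch_end above...
    y_true.extend(np.array([0, 1, 1]).ravel())
    y_pred.extend(np.array([0, 1, 0]).ravel())

    print(f1_score(y_true, y_pred, average='binary'))  # 0.666...
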
Example #3
    def on_batch_end(self, data: Data) -> None:
        y_true, y_pred = to_number(data[self.true_key]), to_number(
            data[self.pred_key])
        batch_size = y_true.shape[0]
        y_true, y_pred = y_true.reshape((batch_size, -1)), y_pred.reshape(
            (batch_size, -1))

        prediction_label = (y_pred >= self.threshold).astype(np.int32)

        intersection = np.sum(y_true * prediction_label, axis=-1)
        area_sum = np.sum(y_true, axis=-1) + np.sum(prediction_label, axis=-1)
        dice = (2. * intersection + self.smooth) / (area_sum + self.smooth)
        self.dice.extend(list(dice))
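
The per-sample Dice computation is easy to verify by hand on a tiny mask. A sketch with an assumed threshold of 0.5 and smoothing constant of 1e-6 (the trace reads both from `self`):

    import numpy as np

    threshold, smooth = 0.5, 1e-6
    y_true = np.array([[1., 1., 0., 0.]])     # (batch=1, 4 pixels)
    y_pred = np.array([[0.9, 0.4, 0.7, 0.1]])

    prediction_label = (y_pred >= threshold).astype(np.int32)  # [[1, 0, 1, 0]]
    intersection = np.sum(y_true * prediction_label, axis=-1)  # [1.]
    area_sum = np.sum(y_true, axis=-1) + np.sum(prediction_label, axis=-1)  # [4.]
    dice = (2. * intersection + smooth) / (area_sum + smooth)
    print(dice)  # ~[0.5]
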
Example #4
    def write_embeddings(
        self,
        mode: str,
        embeddings: Iterable[Tuple[str, Tensor, Optional[List[Any]],
                                   Optional[Tensor]]],
        step: int,
    ):
        """Write embeddings (like UMAP) to TensorBoard.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            embeddings: A collection of quadruplets like [("key", <features>, [<label1>, ...], <label_images>)].
                Features are expected to be batched, and if labels and/or label images are provided they should have the
                same batch dimension as the features.
            step: The current training step.
        """
        for key, features, labels, label_imgs in embeddings:
            flat = to_number(reshape(features, [features.shape[0], -1]))
            if not isinstance(label_imgs, (torch.Tensor, type(None))):
                label_imgs = to_tensor(label_imgs, 'torch')
                if len(label_imgs.shape) == 4:
                    label_imgs = permute(label_imgs, [0, 3, 1, 2])
            self.summary_writers[mode].add_embedding(mat=flat,
                                                     metadata=labels,
                                                     label_img=label_imgs,
                                                     tag=key,
                                                     global_step=step)
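
For context, the method ultimately reduces to one `add_embedding` call per quadruplet. A minimal standalone sketch against the TensorBoard writer (log directory and data are illustrative):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='logs/train')
    features = torch.randn(16, 3, 8, 8)             # a batch of 16 feature maps
    flat = features.reshape(features.shape[0], -1)  # flatten all but the batch dim
    labels = [str(i % 4) for i in range(16)]
    writer.add_embedding(mat=flat, metadata=labels, tag='example', global_step=0)
    writer.close()
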
Example #5
    def get_smoothed_masks(
            self,
            batch: Dict[str, Any],
            stdev_spread: float = .15,
            nsamples: int = 25,
            nintegration: Optional[int] = None,
            magnitude: bool = True) -> Dict[str, Union[Tensor, np.ndarray]]:
        """Generates smoothed greyscale saliency mask(s) from a given `batch` of data.

        Args:
            batch: An input batch of data.
            stdev_spread: Amount of noise to add to the input, as fraction of the total spread (x_max - x_min).
            nsamples: Number of samples to average across to get the smooth gradient.
            nintegration: Number of samples to compute when integrating (None to disable).
            magnitude: If true, computes the sum of squares of gradients instead of just the sum.

        Returns:
            Greyscale saliency mask(s) smoothed via the SmoothGrad method.
        """
        # Shallow copy batch since we're going to modify its contents later
        batch = {key: val for key, val in batch.items()}
        model_inputs = [batch[ins] for ins in self.model_inputs]
        stdevs = [
            to_number(stdev_spread *
                      (reduce_max(ins) - reduce_min(ins))).item()
            for ins in model_inputs
        ]

        # Adding noise to the image might cause the max likelihood class value to change, so need to keep track of
        # which class we're comparing to
        response = self._get_mask(batch)
        for gather_key, output_key in zip(self.gather_keys,
                                          self.model_outputs):
            batch[gather_key] = response[output_key]

        if magnitude:
            for key in self.outputs:
                response[key] = response[key] * response[key]

        for _ in range(nsamples - 1):
            noisy_batch = {key: batch[key] for key in self.gather_keys}
            for idx, input_name in enumerate(self.model_inputs):
                noise = random_normal_like(model_inputs[idx], std=stdevs[idx])
                x_plus_noise = model_inputs[idx] + noise
                noisy_batch[input_name] = x_plus_noise
            grads_and_preds = self._get_mask(
                noisy_batch
            ) if not nintegration else self._get_integrated_masks(
                noisy_batch, nsamples=nintegration)
            for name in self.outputs:
                grad = grads_and_preds[name]
                if magnitude:
                    response[name] += grad * grad
                else:
                    response[name] += grad
        for key in self.outputs:
            grad = response[key]
            response[key] = self._convert_for_visualization(grad / nsamples)
        return response
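
Stripped of the class machinery (`_get_mask`, gather keys, integration), the SmoothGrad idea is: average the (optionally squared) input gradients over several noisy copies of the input while holding the comparison class fixed. A reduced standalone sketch in PyTorch, purely illustrative:

    import torch

    def smoothgrad(model, x, stdev_spread=0.15, nsamples=25, magnitude=True):
        # Pin the comparison class using the clean input, since noise could
        # flip the max-likelihood class (mirroring the bookkeeping above)
        with torch.no_grad():
            target = model(x).argmax(dim=-1)
        stdev = stdev_spread * (x.max() - x.min()).item()
        total = torch.zeros_like(x)
        for _ in range(nsamples):
            noisy = (x + torch.randn_like(x) * stdev).requires_grad_(True)
            score = model(noisy).gather(-1, target.unsqueeze(-1)).sum()
            grad, = torch.autograd.grad(score, noisy)
            total += grad * grad if magnitude else grad
        return total / nsamples
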
Example #6
    def on_batch_end(self, data: Data) -> None:
        y_true, y_pred = to_number(data[self.true_key]), to_number(
            data[self.pred_key])
        if y_true.shape[-1] > 1 and y_true.ndim > 1:
            y_true = np.argmax(y_true, axis=-1)
        if y_pred.shape[-1] > 1:
            y_pred = np.argmax(y_pred, axis=-1)
        else:
            y_pred = np.round(y_pred)
        assert y_pred.size == y_true.size

        batch_confusion = confusion_matrix(y_true,
                                           y_pred,
                                           labels=list(
                                               range(0, self.num_classes)))

        if self.matrix is None:
            self.matrix = batch_confusion
        else:
            self.matrix += batch_confusion
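
Passing `labels=list(range(0, self.num_classes))` is what keeps every batch's matrix the same shape even when a batch is missing some classes, so the `+=` accumulation stays valid. A toy check with an assumed 3 classes:

    from sklearn.metrics import confusion_matrix

    num_classes, matrix = 3, None
    for y_true, y_pred in [([0, 1], [0, 2]), ([2, 2], [2, 1])]:
        batch_confusion = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
        matrix = batch_confusion if matrix is None else matrix + batch_confusion
    print(matrix)  # a 3x3 matrix accumulated over both "batches"
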
Example #7
    def write_scalars(self, mode: str, scalars: Iterable[Tuple[str, Any]],
                      step: int) -> None:
        """Write summaries of scalars to TensorBoard.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            scalars: A collection of pairs like [("key", val), ("key2", val2), ...].
            step: The current training step.
        """
        for key, val in scalars:
            self.summary_writers[mode].add_scalar(tag=key,
                                                  scalar_value=to_number(val),
                                                  global_step=step)
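
Usage-wise this reduces to one `add_scalar` call per pair. A minimal sketch (writer path and values are illustrative):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='logs/train')
    for key, val in [("loss", 0.42), ("accuracy", 0.91)]:
        writer.add_scalar(tag=key, scalar_value=val, global_step=100)
    writer.close()
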
Example #8
    def write_images(self, mode: str, images: Iterable[Tuple[str, Any]],
                     step: int) -> None:
        """Write images to TensorBoard.

        Args:
            mode: The current mode of execution ('train', 'eval', 'test', 'infer').
            images: A collection of pairs like [("key", image1), ("key2", image2), ...].
            step: The current training step.
        """
        for key, img in images:
            if isinstance(img, ImgData):
                img = img.paint_figure()
            if isinstance(img, plt.Figure):
                self.summary_writers[mode].add_figure(tag=key,
                                                      figure=img,
                                                      global_step=step)
            else:
                self.summary_writers[mode].add_images(
                    tag=key,
                    img_tensor=to_number(img),
                    global_step=step,
                    dataformats='NCHW'
                    if isinstance(img, torch.Tensor) else 'NHWC')
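
The `dataformats` switch exists because PyTorch tensors are conventionally channel-first (NCHW) while NumPy images are channel-last (NHWC). A quick illustration of both layouts (random data, illustrative):

    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='logs/train')
    torch_imgs = torch.rand(4, 3, 32, 32)      # NCHW: batch, channels, H, W
    numpy_imgs = np.random.rand(4, 32, 32, 3)  # NHWC: batch, H, W, channels
    writer.add_images(tag='t', img_tensor=torch_imgs.numpy(), global_step=0, dataformats='NCHW')
    writer.add_images(tag='n', img_tensor=numpy_imgs, global_step=0, dataformats='NHWC')
    writer.close()
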
Example #9
    def _print_message(self,
                       header: str,
                       data: Data,
                       log_epoch: bool = False) -> None:
        """Print a log message to the screen, and record the `data` into the `system` summary.

        Args:
            header: The prefix for the log message.
            data: A collection of data to be recorded.
            log_epoch: Whether epoch information should be included in the log message.
        """
        log_message = header
        if log_epoch:
            log_message += "epoch: {}; ".format(self.system.epoch_idx)
            self.system.write_summary('epoch', self.system.epoch_idx)
        for key, val in data.read_logs().items():
            val = to_number(val)
            self.system.write_summary(key, val)
            if val.size > 1:
                log_message += "\n{}:\n{};".format(
                    key, np.array2string(val, separator=','))
            else:
                log_message += "{}: {}; ".format(key, str(val))
        print(log_message)
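
The `val.size > 1` branch switches to `np.array2string` so multi-element values (a confusion matrix, for instance) are printed on their own lines instead of inline. A quick demo of the resulting format:

    import numpy as np

    val = np.array([[3, 1], [0, 4]])
    log_message = "epoch: 1; "
    log_message += "\n{}:\n{};".format("confusion", np.array2string(val, separator=','))
    print(log_message)
    # epoch: 1;
    # confusion:
    # [[3,1],
    #  [0,4]];
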
Example #10
    def on_batch_end(self, data: Data):
        # begin reading detections (det) and ground truth (gt)
        pred = to_number(
            data[self.pred_key])  # pred is [batch, nms_max_outputs, 7]
        pred = self._reshape_pred(pred)

        gt = to_number(
            data[self.true_key]
        )  # gt is np.array (batch, box, 5), box dimension is padded
        gt = self._reshape_gt(gt)

        ground_truth_bb = []
        for gt_item in gt:
            idx_in_batch, x1, y1, w, h, label = gt_item
            label = int(label)
            id_epoch = self._get_id_in_epoch(idx_in_batch)
            self.batch_image_ids.append(id_epoch)
            self.image_ids.append(id_epoch)
            tmp_dict = {
                'idx': id_epoch,
                'x1': x1,
                'y1': y1,
                'w': w,
                'h': h,
                'label': label
            }
            ground_truth_bb.append(tmp_dict)

        predicted_bb = []
        for pred_item in pred:
            idx_in_batch, x1, y1, w, h, label, score = pred_item
            label = int(label)
            id_epoch = self.ids_batch_to_epoch[idx_in_batch]
            self.image_ids.append(id_epoch)
            tmp_dict = {
                'idx': id_epoch,
                'x1': x1,
                'y1': y1,
                'w': w,
                'h': h,
                'label': label,
                'score': score
            }
            predicted_bb.append(tmp_dict)

        for dict_elem in ground_truth_bb:
            self.gt[dict_elem['idx'], dict_elem['label']].append(dict_elem)

        for dict_elem in predicted_bb:
            self.det[dict_elem['idx'], dict_elem['label']].append(dict_elem)
        # end reading detections and ground truth

        # compute iou matrix, matrix index is (img_id, cat_id), each element in matrix has shape (num_det, num_gt)
        self.ious = {(img_id, cat_id):
                     self.compute_iou(self.det[img_id, cat_id],
                                      self.gt[img_id, cat_id])
                     for img_id in self.batch_image_ids
                     for cat_id in self.categories}

        for cat_id in self.categories:
            for img_id in self.batch_image_ids:
                self.evalimgs[(cat_id,
                               img_id)] = self.evaluate_img(cat_id, img_id)
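
The IoU matrix above relies on `self.compute_iou` over boxes stored as (x1, y1, w, h). A standalone sketch of IoU for a single pair of boxes in that format (illustrative, not the trace's actual implementation):

    def iou_xywh(box_a, box_b):
        """IoU of two boxes given as (x1, y1, w, h)."""
        ax1, ay1, aw, ah = box_a
        bx1, by1, bw, bh = box_b
        ix1, iy1 = max(ax1, bx1), max(ay1, by1)
        ix2, iy2 = min(ax1 + aw, bx1 + bw), min(ay1 + ah, by1 + bh)
        inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
        union = aw * ah + bw * bh - inter
        return inter / union if union > 0 else 0.0

    print(iou_xywh((0, 0, 10, 10), (5, 5, 10, 10)))  # 25 / 175 ≈ 0.143
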
Example #11
def show_image(im: Union[np.ndarray, Tensor],
               axis: plt.Axes = None,
               fig: plt.Figure = None,
               title: Optional[str] = None,
               color_map: str = "inferno",
               stack_depth: int = 0) -> Optional[plt.Figure]:
    """Plots a given image onto an axis. The repeated invocation of this function will cause figure plot overlap.

    If `im` is 2D and the length of its second dimension is 4 or 5, it will be interpreted as bounding box data (x0,
    y0, w, h, <label>).

    ```python
    boxes = np.array([[0, 0, 10, 20, "apple"],
                      [10, 20, 30, 50, "dog"],
                      [40, 70, 200, 200, "cat"],
                      [0, 0, 0, 0, "not_shown"],
                      [0, 0, -10, -20, "not_shown2"]])

    img = np.zeros((150, 150))
    fig, axis = plt.subplots(1, 1)
    fe.util.show_image(img, fig=fig, axis=axis) # need to plot image first
    fe.util.show_image(boxes, fig=fig, axis=axis)
    ```

    Users can also directly plot text

    ```python
    fig, axis = plt.subplots(1, 1)
    fe.util.show_image("apple", fig=fig, axis=axis)
    ```

    Args:
        im: The image (width X height) / bounding box / text to display.
        axis: The matplotlib axis to plot on, or None for a new plot.
        fig: A reference to the figure to plot on, or None if new plot.
        title: A title for the image.
        color_map: Which colormap to use for greyscale images.
        stack_depth: Multiple images can be drawn onto the same axis. When stack depth is greater than zero, the `im`
            will be alpha blended on top of a given axis.

    Returns:
        The plotted figure. It will be the same object the user provided via the `fig` argument, if one was given.
    """
    if axis is None:
        fig, axis = plt.subplots(1, 1)
    axis.axis('off')
    # Compute width of axis for text font size
    bbox = axis.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width, height = bbox.width * fig.dpi, bbox.height * fig.dpi
    space = min(width, height)
    if not hasattr(im, 'shape') or len(im.shape) < 2:
        # text data
        im = to_number(im)
        if hasattr(im, 'shape') and len(im.shape) == 1:
            im = im[0]
        im = im.item()
        if isinstance(im, bytes):
            im = im.decode('utf8')
        text = "{}".format(im)
        axis.text(0.5,
                  0.5,
                  text,
                  ha='center',
                  transform=axis.transAxes,
                  va='center',
                  wrap=False,
                  family='monospace',
                  fontsize=min(45, space // len(text)))
    elif len(im.shape) == 2 and (im.shape[1] == 4 or im.shape[1] == 5):
        # Bounding Box Data. Should be (x0, y0, w, h, <label>)
        boxes = []
        im = to_number(im)
        color = ["m", "r", "c", "g", "y", "b"][stack_depth % 6]
        for box in im:
            # Unpack the box, which may or may not have a label
            x0 = int(box[0])
            y0 = int(box[1])
            # Use box_width/box_height so the axis-level `width` computed above
            # (reused later for the title font size) is not shadowed
            box_width = int(box[2])
            box_height = int(box[3])
            label = None if len(box) < 5 else str(box[4])

            # Don't draw empty or invalid boxes
            if box_width <= 0 or box_height <= 0:
                continue
            r = Rectangle((x0, y0),
                          width=box_width,
                          height=box_height,
                          fill=False,
                          edgecolor=color,
                          linewidth=3)
            boxes.append(r)
            if label:
                axis.text(r.get_x() + 3,
                          r.get_y() + 3,
                          label,
                          ha='left',
                          va='top',
                          color=color,
                          fontsize=max(8, min(14, box_width // len(label))),
                          fontweight='bold',
                          family='monospace')
        pc = PatchCollection(boxes, match_original=True)
        axis.add_collection(pc)
    else:
        if isinstance(im, torch.Tensor) and len(im.shape) > 2:
            # Move channel first to channel last
            channels = list(range(len(im.shape)))
            channels.append(channels.pop(0))
            im = im.permute(*channels)
        # image data
        im = to_number(im)
        if np.issubdtype(im.dtype, np.integer):
            # im is already in int format
            im = im.astype(np.uint8)
        elif np.max(im) <= 1 and np.min(im) >= 0:  # im is [0,1]
            im = (im * 255).astype(np.uint8)
        elif np.min(im) >= -1 and np.max(im) <= 1:  # im is [-1, 1]
            im = ((im + 1) * 127.5).astype(np.uint8)
        else:  # im is in some arbitrary range, probably due to the Normalize Op
            ma = abs(
                np.max(im,
                       axis=tuple([i for i in range(len(im.shape) - 1)])
                       if len(im.shape) > 2 else None))
            mi = abs(
                np.min(im,
                       axis=tuple([i for i in range(len(im.shape) - 1)])
                       if len(im.shape) > 2 else None))
            im = (((im + mi) / (ma + mi)) * 255).astype(np.uint8)
        # matplotlib doesn't support (x,y,1) images, so convert them to (x,y)
        if len(im.shape) == 3 and im.shape[2] == 1:
            im = np.reshape(im, (im.shape[0], im.shape[1]))
        alpha = 1 if stack_depth == 0 else 0.3
        if len(im.shape) == 2:
            axis.imshow(im, cmap=plt.get_cmap(name=color_map), alpha=alpha)
        else:
            axis.imshow(im, alpha=alpha)
    if title is not None:
        axis.set_title(title,
                       fontsize=min(20, 1 + width // len(title)),
                       family='monospace')
    return fig
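
A minimal usage sketch for the plain-image branch, assuming the `fastestimator` import convention used in the docstring (a float image in [0, 1] is rescaled to uint8 by the range-handling code above):

    import numpy as np
    import matplotlib.pyplot as plt
    import fastestimator as fe

    img = np.random.rand(32, 32).astype(np.float32)  # greyscale, values in [0, 1]
    fig, axis = plt.subplots(1, 1)
    fe.util.show_image(img, fig=fig, axis=axis, title="noise", color_map="inferno")
    plt.show()
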
Example #12
    def on_epoch_end(self, data: Data) -> None:
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {
                    key: val[0]
                    for key, val in self.samples[mode].items()
                }
            # even if you haven't found n_required samples, you're at end of epoch so no point trying to collect more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        masks = self.salnet.get_masks(self.samples[mode])
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(self.samples[mode],
                                                      nsamples=self.smoothing)
        if self.integrating:
            if isinstance(self.integrating, tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(
                self.samples[mode], nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(
                    self.samples[mode],
                    nsamples=n_smoothing,
                    nintegration=n_integration)

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = self.samples[mode][self.class_key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[self.class_key] = classes
        for key in self.model_outputs:
            classes = masks[key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[key] = classes
        sal = smint or integrated or smoothed or masks
        for key, val in self.samples[mode].items():
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two in
                # order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    args["{} {}".format(
                        key, outkey)] = (0.3 * (sal[outkey] *
                                                (val - min_val) + min_val) +
                                         0.3 * val + 0.4 * sal[outkey] * diff +
                                         min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
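
The overlay expression near the end blends the original image, the saliency mask, and their product (weights 0.3 / 0.3 / 0.4) so salient regions stay bright while the rest of the image is dimmed. Its effect is easy to inspect on toy arrays (values illustrative):

    import numpy as np

    val = np.array([0.0, 0.5, 1.0])   # original image values
    sal = np.array([0.0, 1.0, 0.5])   # saliency mask in [0, 1]
    min_val, diff = val.min(), val.max() - val.min()
    overlay = (0.3 * (sal * (val - min_val) + min_val)
               + 0.3 * val + 0.4 * sal * diff + min_val)
    print(overlay)  # [0.   0.7  0.65]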