Ejemplo n.º 1
0
    def __init__(
            self,
            ax: mpl.axes.Axes,
            layer: LayerData,
            cmaps: dict,
            vlims: dict,
            contours: dict,
            extent, interpolation, mask=None):
        """Set up the plot of a single layer on ``ax``.

        The axes, the layer's ``'meas'`` info entry, the contour arguments,
        the plot data, the extent and the mask are stored on the instance.
        An ``IMAGE`` layer is then drawn with ``ax.imshow`` (clipped to
        ``mask`` when one is given); a ``CONTOUR`` layer gets no image and
        no color limits. Contours are drawn last via ``_draw_contours``.

        Raises
        ------
        RuntimeError
            If ``layer.plot_type`` is neither ``IMAGE`` nor ``CONTOUR``.
        """
        # state shared by every plot type
        self.ax = ax
        self._meas = layer.y.info.get('meas')
        self._contours = layer.contour_plot_args(contours)
        self._data = self._data_from_ndvar(layer.y)
        self._extent = extent
        self._mask = mask

        plot_type = layer.plot_type
        if plot_type == PlotType.IMAGE:
            im_args = layer.im_plot_args(vlims, cmaps)
            self.im = ax.imshow(
                self._data,
                origin='lower',
                aspect=self._aspect,
                extent=extent,
                interpolation=interpolation,
                **im_args
            )
            if mask is not None:
                self.im.set_clip_path(mask)
            self._cmap = im_args['cmap']
            # read the limits back from the image rather than from vlims
            self.vmin, self.vmax = self.im.get_clim()
        elif plot_type == PlotType.CONTOUR:
            # contour-only layer: no image artist and no color limits
            self.im = self.vmin = self.vmax = None
        else:
            raise RuntimeError(f"layer of type {plot_type}")

        # draw flexible parts (the handle must exist before drawing)
        self._contour_h = None
        self._draw_contours()
Ejemplo n.º 2
0
    def __init__(
            self,
            ax: mpl.axes.Axes,
            layer: LayerData,
            cmaps: dict,
            vlims: dict,
            contours: dict,
            extent, interpolation, mask=None):
        """Set up the plot of a single layer on ``ax``.

        The axes, the layer's ``'meas'`` info entry, the contour arguments,
        the plot data, the extent and the mask are stored on the instance.
        An ``IMAGE`` layer is then drawn with ``ax.imshow`` (clipped to
        ``mask`` when one is given); a ``CONTOUR`` layer gets no image and
        no color limits. Contours are drawn last via ``_draw_contours``.

        Raises
        ------
        RuntimeError
            If ``layer.plot_type`` is neither ``IMAGE`` nor ``CONTOUR``.
        """
        # state shared by every plot type
        self.ax = ax
        self._meas = layer.y.info.get('meas')
        self._contours = layer.contour_plot_args(contours)
        self._data = self._data_from_ndvar(layer.y)
        self._extent = extent
        self._mask = mask

        plot_type = layer.plot_type
        if plot_type == PlotType.IMAGE:
            im_args = layer.im_plot_args(vlims, cmaps)
            self.im = ax.imshow(
                self._data,
                origin='lower',
                aspect=self._aspect,
                extent=extent,
                interpolation=interpolation,
                **im_args
            )
            if mask is not None:
                self.im.set_clip_path(mask)
            self._cmap = im_args['cmap']
            # read the limits back from the image rather than from vlims
            self.vmin, self.vmax = self.im.get_clim()
        elif plot_type == PlotType.CONTOUR:
            # contour-only layer: no image artist and no color limits
            self.im = self.vmin = self.vmax = None
        else:
            raise RuntimeError(f"layer of type {plot_type}")

        # draw flexible parts (the handle must exist before drawing)
        self._contour_h = None
        self._draw_contours()
Ejemplo n.º 3
0
def display_segmented_image(y: np.ndarray,
                            threshold: float = 0.5,
                            input_image: np.ndarray = None,
                            alpha_input_image: float = 0.2,
                            title: str = '',
                            ax: matplotlib.axes.Axes = None) -> None:
    """Display a segmented image.

    This function displays the image where each class is shown in a
    particular color. This is useful for getting a rapid view of the
    performance of the model on a few examples.

    Parameters:
        y: The array containing the prediction.
            Must be of shape (image_shape, num_classes)
        threshold: The threshold used on the predictions.
        input_image: If provided, its first channel is drawn on top of the
            segmentation with the binary colormap.
        alpha_input_image: If an input_image is provided, the transparency of
            the input_image.
        title: Title set on the axes.
        ax: Axes to draw on. If omitted, the current axes are used and the
            figure is shown with ``plt.show()``.
    """
    # `ax` is rebound just below, so whether the caller supplied one has to
    # be recorded now; testing `ax` afterwards could never be falsy.
    show_figure = ax is None
    if ax is None:
        ax = plt.gca()

    n_classes = y.shape[-1]
    # White RGB canvas; pixels above the threshold are painted per class.
    base_array = np.ones((y.shape[0], y.shape[1], 3))
    legend_handles = []

    for i in range(n_classes):
        # Retrieve a color (without the transparency value).
        colour = plt.cm.jet(i / n_classes)[:-1]
        # Later classes overwrite earlier ones where both exceed the threshold.
        base_array[y[..., i] > threshold] = colour
        legend_handles.append(mpatches.Patch(color=colour, label=str(i)))

    ax.imshow(base_array)
    ax.legend(handles=legend_handles, bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_title(title)

    if input_image is not None:
        ax.imshow(input_image[..., 0],
                  cmap=plt.cm.binary,
                  alpha=alpha_input_image)

    if show_figure:
        plt.show()
Ejemplo n.º 4
0
def show(image, ax: mpl.axes.Axes = None) -> "mpl.image.AxesImage":
    """Show an image.

    Args:
        image: Image data exposing a ``shape`` attribute. Accepted layouts
            are 2-D, 3-D with a single channel (the channel axis is dropped
            before drawing), and 3-D with three or more channels (passed to
            ``imshow`` unchanged).
        ax: Axis on which to show the image. When omitted, the image is
            drawn through ``plt.imshow``.

    Returns:
        The object returned by ``imshow`` (an image artist, not an axes).

    Raises:
        ValueError: If ``image`` has any other dimensionality, including a
            3-D array with exactly two channels.
    """
    target = plt if ax is None else ax
    shape = image.shape

    if len(shape) == 2:
        return target.imshow(image)
    if len(shape) == 3:
        if shape[2] >= 3:
            return target.imshow(image)
        if shape[2] == 1:
            # single-channel image: draw it as a plain 2-D array
            return target.imshow(image[:, :, 0])
    raise ValueError("Incorrect dimensions for image data.")
Ejemplo n.º 5
0
def show_4d_images(fig: matplotlib.figure.Figure,
                   ax: matplotlib.axes.Axes,
                   image: np.ndarray,
                   est_cor_idx: list = None,
                   img_name: str = None):
    """
    Shows a 4D image interactively, with sliders for the z and time axes.

    The slice ``image[:, :, z, t]`` is drawn on ``ax``; two sliders added
    at the bottom of ``fig`` select ``z`` (third axis) and ``t`` (fourth
    axis). Slider values are 1-based. ``plt.show()`` is called before
    returning.

    Args:
        fig (matplotlib.figure.Figure): Figure object
        ax (matplotlib.axes.Axes): Axes object
        image (np.ndarray): 4D image tensor to be shown
        est_cor_idx (list): List containing the percentage of corruption
                           for each time and z-axes. When given, the first
                           five characters of its string form are appended
                           to the title.
        img_name (str): Identifier shown in the title as ``"ID: <name>"``.
                        Omitted from the title when None.

    Returns:
        tuple: The ``(z_slider, time_slider)`` widgets. Keep a reference to
        them if the figure outlives this call, otherwise they may stop
        responding.
    """
    assert len(image.shape) == 4, "Image's tensor rank is {}, not 4".format(
        len(image.shape))

    # The title does not depend on the slider position, so build it once.
    title_parts = []
    if img_name is not None:
        title_parts.append("ID: {}".format(img_name))
    if est_cor_idx:
        title_parts.append(
            "Estimated Corruption: {:.5}".format(str(est_cor_idx)))
    title = " ".join(title_parts)

    # Add the slider axes to `fig` itself rather than to whichever figure
    # happens to be current in pyplot.
    axcolor = 'lightgoldenrodyellow'
    axz = fig.add_axes([0.15, 0.05, 0.65, 0.03], facecolor=axcolor)
    axt = fig.add_axes([0.15, 0.10, 0.65, 0.03], facecolor=axcolor)

    zpos = Slider(axz, 'Z Axis', 1, image.shape[-2], valfmt="%d")
    tpos = Slider(axt, 'Time Axis', 1, image.shape[-1], valfmt="%d")

    im = ax.imshow(image[:, :, 0, 0])
    ax.set_title(title)

    def update(val):
        # Slider values are 1-based; convert to 0-based array indices.
        pos_z = int(zpos.val)
        pos_t = int(tpos.val)
        im.set_data(image[:, :, pos_z - 1, pos_t - 1])
        fig.canvas.draw_idle()

    zpos.on_changed(update)
    tpos.on_changed(update)

    plt.show()
    return zpos, tpos
Ejemplo n.º 6
0
def _overlay_mask(ax, values, cmap):
    """Draw ``values`` on ``ax`` as a color-mapped, semi-transparent layer."""
    # Opacity: values are scaled against the range [0, 0.3] (saturating
    # above it) and then capped at 0.5 so the underlying image stays visible.
    alphas = Normalize(0, .3, clip=True)(values)
    alphas = np.clip(alphas, 0., 0.5)
    # Color: values are rescaled to their own min/max before the lookup.
    colors = cmap(Normalize()(values))
    colors[..., -1] = alphas
    ax.imshow(colors, cmap=cmap)  # interpolation='none'


def make_plot(ax: matplotlib.axes.Axes,
              img: np.ndarray,
              mask: np.ndarray,
              cmap_names: List[str] = None,
              classes: List[int] = None):
    """Draw ``img`` on ``ax`` and overlay ``mask`` on top of it.

    Args:
        ax: Axes to draw on.
        img: Background image, passed to ``ax.imshow`` unchanged.
        mask: Array to overlay. With ``classes`` it is compared against
            each class value; without, it is used as a map of values in
            which entries equal to 255 are zeroed first.
        cmap_names: Names of colormaps available as attributes of
            ``plt.cm``. Defaults to ``['rainbow']``. With ``classes`` the
            names are paired with the classes in order (pairing stops at
            the shorter of the two); otherwise only the first is used.
        classes: Class values to draw, one overlay per class. If empty or
            None, ``mask`` is drawn as a single overlay.
    """
    if cmap_names is None:
        cmap_names = ['rainbow']

    ax.imshow(img)
    if classes:  # For the overall segmentation map
        for class_, cmap_name in zip(classes, cmap_names):
            # Attribute lookup instead of eval(): same colormap, but no
            # arbitrary code execution from the name string.
            cmap = getattr(plt.cm, cmap_name)
            # Binary map in the mask's dtype: 1 where the class is present.
            class_mask = (mask == class_).astype(mask.dtype)
            _overlay_mask(ax, class_mask, cmap)
    else:  # For probability maps
        new_mask = mask.copy()
        new_mask[mask == 255] = 0
        _overlay_mask(ax, new_mask, getattr(plt.cm, cmap_names[0]))
Ejemplo n.º 7
0
def display_grayscale_array(array: np.ndarray,
                            title: str = '',
                            ax: matplotlib.axes.Axes = None) -> None:
    """Display the grayscale input image.

    Parameters:
        array: The image to display. A 3-dimensional array is reduced to
            its first channel (``array[..., 0]``) before plotting.
        title: If provided, this will be added as title of the plot.
        ax: Axes to draw on. If omitted, the current axes are used and the
            figure is shown with ``plt.show()``.
    """
    # `ax` is rebound just below, so whether the caller supplied one has to
    # be recorded now; testing `ax` afterwards could never be falsy.
    show_figure = ax is None
    if ax is None:
        ax = plt.gca()

    if len(array.shape) == 3:
        array = array[..., 0]

    ax.imshow(array, cmap=plt.cm.binary)
    ax.set_yticks([])
    ax.set_xticks([])

    if title:
        ax.set_title(title)

    if show_figure:
        plt.show()
Ejemplo n.º 8
0
 def plot(self, ax: mpl.axes.Axes):
     """Draw this object's padded array on ``ax`` with ``imshow``.

     The array is padded to the integer extent of the current axis limits
     and drawn with ``self.cmap``; opacity comes from the ``"alpha"`` entry
     of ``self.style`` and falls back to 0.5.
     """
     x_start, x_end = ax.get_xlim()
     y_start, y_end = ax.get_ylim()
     # NOTE(review): "height" is taken from the x-limits and "width" from
     # the y-limits. Whether that is a naming slip or matches the argument
     # order of pad_to cannot be told from here -- confirm against pad_to.
     img_height = int(abs(x_end - x_start))
     img_width = int(abs(y_end - y_start))
     padded = self.pad_to(img_width, img_height)
     ax.imshow(padded, cmap=self.cmap, alpha=self.style.get("alpha", 0.5))