def get_ax_size(ax: plt.Axes):
    """Return the rendered size of an axes as a ``(width, height)`` tuple.

    The window extent of ``ax`` is mapped through the inverse of the owning
    figure's ``dpi_scale_trans`` and every dimension is then multiplied by the
    figure's ``dpi``.

    Args:
        ax: The matplotlib axes to measure.

    Returns:
        Tuple ``(width, height)`` of the axes extent scaled by the figure dpi.
    """
    owner = ax.figure
    to_unscaled = owner.dpi_scale_trans.inverted()
    extent = ax.get_window_extent().transformed(to_unscaled)
    return extent.width * owner.dpi, extent.height * owner.dpi
Example #2
0
def _axis_set_watermark_img(ax: plt.Axes, file: str, size: float = 0.25):
    """Draw a faint watermark image into the top-right corner of an axis.

    Args:
        ax (plt.Axes): Matplotlib axis to render to.
        file (str): Image file; it is passed through ``str()`` and resolved
            with ``cbook.get_sample_data``.
        size (float): Height of the watermark in axes coordinates (fraction
            of the axis height).  Defaults to 0.25.
    """
    # Resolve the file name and read the pixel data
    image_path = cbook.get_sample_data(str(file), asfileobj=False)
    pixels = image.imread(image_path)

    # Height-to-width ratio of the axis window; it scales the horizontal
    # extent of the watermark
    window = ax.get_window_extent()
    aspect = abs(window.height / window.width)

    # extent is (left, right, bottom, top) in axes coordinates, anchored at
    # the top-right corner (1, 1)
    ax.imshow(pixels,
              extent=(1 - size * aspect, 1, 1 - size, 1),
              transform=ax.transAxes,
              interpolation='gaussian',
              alpha=0.05,
              resample=True,
              zorder=-200)
Example #3
0
def show_image(im: Union[np.ndarray, Tensor],
               axis: plt.Axes = None,
               fig: plt.Figure = None,
               title: Optional[str] = None,
               color_map: str = "inferno",
               stack_depth: int = 0) -> Optional[plt.Figure]:
    """Plots a given image onto an axis. The repeated invocation of this function will cause figure plot overlap.

    If `im` is 2D and the length of second dimension are 4 or 5, it will be viewed as bounding box data (x0, y0, w, h,
    <label>).

    ```python
    boxes = np.array([[0, 0, 10, 20, "apple"],
                      [10, 20, 30, 50, "dog"],
                      [40, 70, 200, 200, "cat"],
                      [0, 0, 0, 0, "not_shown"],
                      [0, 0, -10, -20, "not_shown2"]])

    img = np.zeros((150, 150))
    fig, axis = plt.subplots(1, 1)
    fe.util.show_image(img, fig=fig, axis=axis) # need to plot image first
    fe.util.show_image(boxes, fig=fig, axis=axis)
    ```

    Users can also directly plot text

    ```python
    fig, axis = plt.subplots(1, 1)
    fe.util.show_image("apple", fig=fig, axis=axis)
    ```

    Args:
        im: The image (width X height) / bounding box / text to display.
        axis: The matplotlib axis to plot on, or None for a new plot. When None, a new figure and axis are created and
            any `fig` argument is replaced by the new figure.
        fig: A reference to the figure to plot on, or None if new plot. When `axis` is given but `fig` is None, the
            figure owning `axis` is used.
        title: A title for the image.
        color_map: Which colormap to use for greyscale images.
        stack_depth: Multiple images can be drawn onto the same axis. When stack depth is greater than zero, the `im`
            will be alpha blended on top of a given axis.

    Returns:
        plotted figure. It will be the same object as user have provided in the argument (or the figure owning `axis`
        when only an axis was provided).
    """
    if axis is None:
        fig, axis = plt.subplots(1, 1)
    elif fig is None:
        # An axis without an explicit figure: use the figure that owns it instead of dereferencing None below
        fig = axis.figure
    axis.axis('off')
    # Compute the size of the axis, which drives the text / title font sizes
    bbox = axis.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    axis_width, axis_height = bbox.width * fig.dpi, bbox.height * fig.dpi
    space = min(axis_width, axis_height)
    if not hasattr(im, 'shape') or len(im.shape) < 2:
        # text data
        im = to_number(im)
        if hasattr(im, 'shape') and len(im.shape) == 1:
            im = im[0]
        im = im.item()
        if isinstance(im, bytes):
            im = im.decode('utf8')
        text = "{}".format(im)
        axis.text(0.5,
                  0.5,
                  text,
                  ha='center',
                  transform=axis.transAxes,
                  va='center',
                  wrap=False,
                  family='monospace',
                  # max(1, ...) keeps an empty string from dividing by zero
                  fontsize=min(45, space // max(1, len(text))))
    elif len(im.shape) == 2 and (im.shape[1] == 4 or im.shape[1] == 5):
        # Bounding Box Data. Should be (x0, y0, w, h, <label>)
        boxes = []
        im = to_number(im)
        color = ["m", "r", "c", "g", "y", "b"][stack_depth % 6]
        for box in im:
            # Unpack the box, which may or may not have a label. The box dimensions get their own names so that they
            # do not clobber the axis dimensions needed for the title font size below.
            x0 = float(box[0])
            y0 = float(box[1])
            box_width = float(box[2])
            box_height = float(box[3])
            label = None if len(box) < 5 else str(box[4])

            # Don't draw empty boxes, or invalid box
            if box_width <= 0 or box_height <= 0:
                continue
            r = Rectangle((x0, y0),
                          width=box_width,
                          height=box_height,
                          fill=False,
                          edgecolor=color,
                          linewidth=3)
            boxes.append(r)
            if label:
                axis.text(r.get_x() + 3,
                          r.get_y() + 3,
                          label,
                          ha='left',
                          va='top',
                          color=color,
                          fontsize=max(8, min(14, box_width // len(label))),
                          fontweight='bold',
                          family='monospace')
        pc = PatchCollection(boxes, match_original=True)
        axis.add_collection(pc)
    else:
        if isinstance(im, torch.Tensor) and len(im.shape) > 2:
            # Move channel first to channel last
            channels = list(range(len(im.shape)))
            channels.append(channels.pop(0))
            im = im.permute(*channels)
        # image data
        im = to_number(im)
        im_max = np.max(im)
        im_min = np.min(im)
        if np.issubdtype(im.dtype, np.integer):
            # im is already in int format
            im = im.astype(np.uint8)
        elif 0 <= im_min <= im_max <= 1:  # im is [0,1]
            im = (im * 255).astype(np.uint8)
        elif -0.5 <= im_min < 0 < im_max <= 0.5:  # im is [-0.5, 0.5]
            im = ((im + 0.5) * 255).astype(np.uint8)
        elif -1 <= im_min < 0 < im_max <= 1:  # im is [-1, 1]
            im = ((im + 1) * 127.5).astype(np.uint8)
        else:  # im is in some arbitrary range, probably due to the Normalize Op
            # For images with more than 2 dimensions reduce over every axis but the last one, otherwise reduce
            # over the whole array
            reduce_axes = tuple(range(len(im.shape) - 1)) if len(im.shape) > 2 else None
            ma = abs(np.max(im, axis=reduce_axes))
            mi = abs(np.min(im, axis=reduce_axes))
            im = (((im + mi) / (ma + mi)) * 255).astype(np.uint8)
        # matplotlib doesn't support (x,y,1) images, so convert them to (x,y)
        if len(im.shape) == 3 and im.shape[2] == 1:
            im = np.reshape(im, (im.shape[0], im.shape[1]))
        alpha = 1 if stack_depth == 0 else 0.3
        if len(im.shape) == 2:
            axis.imshow(im, cmap=plt.get_cmap(name=color_map), alpha=alpha)
        else:
            axis.imshow(im, alpha=alpha)
    if title is not None:
        axis.set_title(title,
                       # Scale with the axis width; max(1, ...) guards against an empty title
                       fontsize=min(20, 1 + axis_width // max(1, len(title))),
                       family='monospace')
    return fig
Example #4
0
def plot_single_rf(rf,
                   tlim: list or tuple or None = None,
                   ylim: list or tuple or None = None,
                   depth: np.ndarray or None = None,
                   ax: plt.Axes = None,
                   outputdir: str = None,
                   pre_fix: str = None,
                   post_fix: str = None,
                   format: str = 'pdf',
                   clean: bool = False,
                   std: np.ndarray = None,
                   flipxy: bool = False):
    """Creates plot of a single receiver function

    Parameters
    ----------
    rf : :class:`pyglimer.RFTrace`
        single receiver function trace
    tlim: list or tuple or None
        time axis limits in seconds if type=='time' or depth in km if
        type==depth (len(list)==2).
        If `None` the axis runs from 0 to the last sample.
        Default None.
    ylim: list or tuple or None
        amplitude axis limits. If `None` ± 1.1 absmax. Default None.
    depth: :class:`numpy.ndarray`
        1D array of depths. Currently not used by this function.
    ax : `matplotlib.pyplot.Axes`, optional
        Can define an axes to plot the RF into. Defaults to None.
        If None, new figure is created.
    outputdir : str, optional
        If set, saves the plot to the directory (file type given by
        ``format``).
        If None, plot will be shown instantly. Defaults to None.
    pre_fix : str, optional
        prepend filename
    post_fix : str, optional
        append to filename
    format : str, optional
        File extension and format handed to ``savefig``. Defaults to 'pdf'.
    clean: bool, optional
        If True, clears out all axes and plots RF only.
        Defaults to False.
    std: np.ndarray, optional
        **Only if self.type == stastack**. Plots the upper and lower
        limit of the standard deviation in the plot. Provide the std
        as a numpy array (can be easily computed from the output of
        :meth:`~pyglimer.rf.create.RFStream.bootstrap`)
    flipxy: bool, optional
        Plot Depth/Time on the Y-Axis and amplitude on the x-axis. Defaults
        to False.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`
    """
    set_mpl_params()

    # Get figure/axes dimensions
    if ax is None:
        if flipxy:
            height, width = 8, 3
        else:
            width, height = 10, 2.5
        plt.figure(figsize=(width, height))
        ax = plt.axes(zorder=9999999)
        axtmp = None
    else:
        # Measure the axes against the figure that actually owns it; the
        # "current" pyplot figure may be a different one.
        fig = ax.figure
        bbox = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())
        width, height = bbox.width, bbox.height
        axtmp = ax

    # The ratio ensures that the text
    # is perfectly distanced from top left/right corner
    ratio = width / height

    # Use times depending on phase and moveout correction
    ydata = rf.data
    if rf.stats.type == 'time':
        # Get times relative to the onset
        times = rf.times() - (rf.stats.onset - rf.stats.starttime)
        if rf.stats.phase[-1] == 'S':
            times = np.flip(times)
            ydata = np.flip(-rf.data)
    else:
        # NOTE(review): maxz and res are not defined in this function;
        # presumably module-level constants - confirm.
        z = np.hstack(((np.arange(-10, 0, .1)), np.arange(0, maxz + res, res)))
        times = z

    # Pick the drawing primitives once so that both orientations share
    # the same plotting code
    if flipxy:
        fill = ax.fill_betweenx
        set_time_lim, set_amp_lim = ax.set_ylim, ax.set_xlim
    else:
        fill = ax.fill_between
        set_time_lim, set_amp_lim = ax.set_xlim, ax.set_ylim

    def _coords(amplitude):
        """Order (time, amplitude) according to the chosen orientation."""
        return (amplitude, times) if flipxy else (times, amplitude)

    # Plot stuff into axes
    if std is not None:
        ax.plot(*_coords(ydata - std), 'k--', lw=0.75)
        ax.plot(*_coords(ydata + std), 'k--', lw=0.75)
    # Positive amplitudes in red, negative amplitudes in blue
    for mask, colour in ((ydata > 0, (0.9, 0.2, 0.2)),
                         (ydata < 0, (0.2, 0.2, 0.7))):
        fill(times,
             0,
             ydata,
             where=mask,
             interpolate=True,
             color=colour,
             alpha=.8)
    ax.plot(*_coords(ydata), 'k', lw=0.75)

    # Set limits
    if tlim is None:
        # don't really wanna see the stuff before
        set_time_lim(0, times[-1])
    else:
        set_time_lim(tlim)

    if ylim is None:
        absmax = 1.1 * np.max(np.abs(ydata))
        set_amp_lim([-absmax, absmax])
    else:
        set_amp_lim(ylim)

    if flipxy:
        ax.invert_yaxis()

    # Removes top/right axes spines. If you want the whole thing, comment
    # or remove
    remove_topright()

    # Plot RF only
    if clean:
        remove_all()
    else:
        if rf.stats.type == 'time':
            time_label = "Conversion Time [s]"
        else:
            time_label = "Conversion Depth [km]"
        if flipxy:
            ax.set_ylabel(time_label, rotation=90)
            ax.set_xlabel("A    ", rotation=0)
        else:
            ax.set_xlabel(time_label)
            ax.set_ylabel("A    ", rotation=0)

        # Start time in station stack does not make sense
        if rf.stats.type == 'stastack':
            text = rf.get_id()
        else:
            text = rf.stats.starttime.isoformat(sep=" ") + "\n" + rf.get_id()
        ax.text(0.995,
                1.0 - 0.005 * ratio,
                text,
                transform=ax.transAxes,
                horizontalalignment="right",
                verticalalignment="top")

    # Only use tight layout if not part of plot.
    if axtmp is None:
        plt.tight_layout()

    # Output the receiver function using
    # its station name and starttime

    if outputdir is not None:
        # Set pre and post fix
        if pre_fix is not None:
            pre_fix = pre_fix + "_"
        else:
            pre_fix = ""
        if post_fix is not None:
            post_fix = "_" + post_fix
        else:
            post_fix = ""

        # Get filename
        filename = os.path.join(
            outputdir, pre_fix + rf.get_id() + "_" +
            rf.stats.starttime.strftime('%Y%m%dT%H%M%S') + post_fix +
            f".{format}")
        plt.savefig(filename, format=format, transparent=True)
    else:
        plt.show()

    return ax
Example #5
0
def plot_single_rf(rf,
                   tlim: list or tuple or None = None,
                   ax: plt.Axes = None,
                   outputdir: str = None,
                   clean: bool = False):
    """Creates plot of a single receiver function

    Parameters
    ----------
    rf : :class:`pyglimer.RFTrace`
        single receiver function trace
    tlim: list or tuple or None
        x axis time limits in seconds if type=='time' or depth in km if
        type==depth (len(list)==2).
        If `None` the axis runs from 0 to the last sample.
        Default None.
    ax : `matplotlib.pyplot.Axes`, optional
        Can define an axes to plot the RF into. Defaults to None.
        If None, new figure is created.
    outputdir : str, optional
        If set, saves a pdf of the plot to the directory.
        If None, nothing is written to disk. Defaults to None.
    clean: bool
        If True, clears out all axes and plots RF only.
        Defaults to False.

    Returns
    -------
    ax : `matplotlib.pyplot.Axes`
    """
    set_mpl_params()

    # Get figure/axes dimensions
    if ax is None:
        width, height = 10, 2.5
        plt.figure(figsize=(width, height))
        # plt.axes() rather than plt.gca(): gca() no longer accepts
        # keyword arguments in current matplotlib releases
        ax = plt.axes(zorder=9999999)
        axtmp = None
    else:
        # Use the figure owning the supplied axes to measure its size
        fig = ax.figure
        bbox = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())
        width, height = bbox.width, bbox.height
        axtmp = ax

    # The ratio ensures that the text
    # is perfectly distanced from top left/right corner
    ratio = width / height

    ydata = rf.data
    if rf.stats.type == 'time':
        # Get times relative to the onset
        times = rf.times() - (rf.stats.onset - rf.stats.starttime)
        if rf.stats.phase == 'S':
            times = np.flip(times)
            ydata = np.flip(-rf.data)
    else:
        # NOTE(review): maxz and res are not defined in this function;
        # presumably module-level constants - confirm.
        z = np.hstack(((np.arange(-10, 0, .1)), np.arange(0, maxz + res, res)))
        times = z

    # Plot stuff into axes: positive amplitudes red, negative ones blue
    for mask, colour in ((ydata > 0, (0.9, 0.2, 0.2)),
                         (ydata < 0, (0.2, 0.2, 0.7))):
        ax.fill_between(times,
                        0,
                        ydata,
                        where=mask,
                        interpolate=True,
                        color=colour)
    ax.plot(times, ydata, 'k', lw=0.75)

    # Set limits
    if tlim is None:
        ax.set_xlim(0, times[-1])  # don't really wanna see the stuff before
    else:
        ax.set_xlim(tlim)

    # Removes top/right axes spines. If you want the whole thing, comment or
    # remove
    remove_topright()

    # Plot RF only
    if clean:
        remove_all()
    else:
        if rf.stats.type == 'time':
            ax.set_xlabel("Conversion Time [s]")
        else:
            ax.set_xlabel("Conversion Depth [km]")
        ax.set_ylabel("A    ", rotation=0)
        text = rf.stats.starttime.isoformat(sep=" ") + "\n" + rf.get_id()
        ax.text(0.995,
                1.0 - 0.005 * ratio,
                text,
                transform=ax.transAxes,
                horizontalalignment="right",
                verticalalignment="top")

    # Only use tight layout if not part of plot.
    if axtmp is None:
        plt.tight_layout()

    # Output the receiver function as pdf using
    # its station name and starttime
    if outputdir is not None:
        filename = os.path.join(
            outputdir,
            rf.get_id() + "_" +
            rf.stats.starttime.strftime('%Y%m%dT%H%M%S') + ".pdf")
        plt.savefig(filename, format="pdf")
    return ax