Ejemplo n.º 1
0
def get_axes_extent(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """Return the corners of an Axes' view window as a Polygon in *crs*.

    The x/y limits of *ax* are taken to be in *ax_crs*. Each of the four
    corners is transformed into *crs* (SphericalEarth by default). They are
    assembled in the order lower-left, lower-right, upper-right, upper-left.
    """
    (x_left, x_right), (y_bottom, y_top) = ax.get_xlim(), ax.get_ylim()
    corners = (
        (x_left, y_bottom),
        (x_right, y_bottom),
        (x_right, y_top),
        (x_left, y_top),
    )
    vertices = [ax_crs.transform(crs, x, y) for x, y in corners]
    return Polygon(vertices, crs=crs)
Ejemplo n.º 2
0
    def plot2axes(self, axes: plt.Axes, width: NumberLike=1, color: str='w'):
        """Draw this line segment (point1 -> point2) onto an existing axes.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            An MPL axes to plot to.
        width : number
            Line width, passed to ``Axes.plot`` as ``linewidth``.
        color : str
            The color of the line.
        """
        start, end = self.point1, self.point2
        xs = (start.x, end.x)
        ys = (start.y, end.y)
        axes.plot(xs, ys, linewidth=width, color=color)
Ejemplo n.º 3
0
 def _plot_flatness(self, direction: str, axis: plt.Axes=None):
     """Plot the flatness profile for one direction together with its markers.

     Draws ``self.flatness[direction.lower()]['profile']``. Red horizontal
     lines mark the profile max/min; green dash-dot vertical lines mark the
     profile left/right positions. A new figure is created when *axis* is None.
     """
     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     flatness = self.flatness[direction.lower()]
     axis.set_title(direction.capitalize() + " Flatness")
     axis.plot(flatness['profile'].values)
     _remove_ticklabels(axis)
     for level_key in ('profile max', 'profile min'):
         axis.axhline(flatness[level_key], color='r')
     for edge_key in ('profile left', 'profile right'):
         axis.axvline(flatness[edge_key], color='g', linestyle='-.')
def show_spiral(data, cmap='binary'):
    """Display *data* as an image filling a 6x6 inch figure, with no axes shown.

    Depends on matplotlib.pyplot.

    Parameters
    ----------
    data : array-like
        Image data passed to ``Axes.imshow`` (nearest-neighbour interpolation).
    cmap : str, optional
        Colormap installed with ``set_cmap`` before drawing (default 'binary').
        Note that ``set_cmap`` also changes pyplot's default colormap.
    """
    from matplotlib.pyplot import Axes, figure, set_cmap, show

    fig = figure()
    fig.set_size_inches(6, 6)
    # An axes spanning the whole figure with its frame hidden, so only the image is visible.
    ax = Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    set_cmap(cmap)
    ax.imshow(data, interpolation='nearest')
    show()
def show_Ulam_spiral(Ulam_data, cmap='jet'):
    """Display the Ulam spiral in a window. Depends on matplotlib.pyplot.

    The image fills a 6x6 inch figure with the axes hidden. Nothing is written
    to disk; call ``matplotlib.pyplot.savefig`` separately to save the figure.

    Parameters
    ----------
    Ulam_data : array-like
        Image data passed to ``Axes.imshow`` (nearest-neighbour interpolation).
    cmap : str, optional
        Colormap installed with ``set_cmap`` before drawing (default 'jet').
        Note that ``set_cmap`` also changes pyplot's default colormap.
    """
    from matplotlib.pyplot import Axes, figure, set_cmap, show

    fig = figure()
    fig.set_size_inches(6, 6)
    # An axes spanning the whole figure with its frame hidden, so only the image is visible.
    ax = Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    set_cmap(cmap)
    ax.imshow(Ulam_data, interpolation='nearest')
    show()
Ejemplo n.º 6
0
def plot_polygon(geom: Union[Polygon, Iterable[Polygon]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Polygon geometry, projected to the coordinate system `crs`

    Parameters
    ----------
    geom : Polygon
        Geometry to draw. Its coordinates are obtained with
        ``geom.get_coordinate_lists(crs=crs)``.
        NOTE(review): the annotation also admits an iterable of Polygons, but
        this body only handles a single geometry. Iterables are presumably
        dispatched by a wrapper elsewhere; confirm.
    *args, **kwargs
        Forwarded to ``Axes.fill``. Unless given, ``facecolor`` defaults to
        "none" and ``edgecolor`` to "black", which draws a hollow outline.
    ax : Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    crs : CRS, optional
        Coordinate system to project the geometry into before plotting.

    Returns
    -------
    The list of patches returned by ``Axes.fill``.
    """
    if ax is None:
        # Without this fallback the default value would crash on ax.fill below.
        import matplotlib.pyplot as plt
        ax = plt.gca()
    kwargs.setdefault("facecolor", "none")
    kwargs.setdefault("edgecolor", "black")
    x, y = geom.get_coordinate_lists(crs=crs)
    return ax.fill(x, y, *args, **kwargs)
Ejemplo n.º 7
0
def plot_multipoint(geom: Union[Multipoint, Iterable[Multipoint]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Multipoint geometry, projected to the coordinate system `crs`

    Parameters
    ----------
    geom : Multipoint
        Geometry to draw. Its coordinates are obtained with
        ``geom.get_coordinate_lists(crs=crs)``.
        NOTE(review): the annotation also admits an iterable, but this body
        handles a single geometry only; confirm iterables are dispatched
        elsewhere.
    *args, **kwargs
        Forwarded to ``Axes.plot``. Unless given, ``linestyle`` defaults to
        "none" and ``marker`` to ".", so only the points are drawn.
    ax : Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    crs : CRS, optional
        Coordinate system to project the geometry into before plotting.

    Returns
    -------
    The list of Line2D objects returned by ``Axes.plot``.
    """
    if ax is None:
        # Without this fallback the default value would crash on ax.plot below.
        import matplotlib.pyplot as plt
        ax = plt.gca()
    kwargs.setdefault("linestyle", "none")
    kwargs.setdefault("marker", ".")
    x, y = geom.get_coordinate_lists(crs=crs)
    return ax.plot(x, y, *args, **kwargs)
Ejemplo n.º 8
0
 def _plot_image(self, axis: plt.Axes=None, title: str=''):
     """Show ``self.array`` with red crosshair lines and the given title.

     The horizontal line is drawn at ``positions['vertical']`` times the number
     of rows. The vertical line is drawn at ``positions['horizontal']`` times
     the number of columns (presumably fractional positions; confirm). A new
     figure is created when *axis* is None.
     """
     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     n_rows, n_cols = self.array.shape[0], self.array.shape[1]
     axis.imshow(self.array, cmap=get_dicom_cmap())
     axis.axhline(self.positions['vertical'] * n_rows, color='r')  # y
     axis.axvline(self.positions['horizontal'] * n_cols, color='r')  # x
     _remove_ticklabels(axis)
     axis.set_title(title)
Ejemplo n.º 9
0
def plot_multiline(geom: Union[Multiline, Iterable[Multiline]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Multiline geometry, projected to the coordinate system `crs`

    Each member line of *geom* is drawn with its own ``Axes.plot`` call.

    Parameters
    ----------
    geom : Multiline
        Iterable of lines. The coordinates of each line are obtained with
        ``line.get_coordinate_lists(crs=crs)``.
    *args, **kwargs
        Forwarded to every ``Axes.plot`` call.
    ax : Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    crs : CRS, optional
        Coordinate system to project the geometry into before plotting.

    Returns
    -------
    list
        One ``Axes.plot`` result (a list of Line2D) per member line.
    """
    if ax is None:
        # Without this fallback the default value would crash on ax.plot below.
        import matplotlib.pyplot as plt
        ax = plt.gca()
    out = []
    for line in geom:
        x, y = line.get_coordinate_lists(crs=crs)
        out.append(ax.plot(x, y, *args, **kwargs))
    return out
Ejemplo n.º 10
0
def plot_multipolygon(geom: Union[Multipolygon, Iterable[Multipolygon]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Multipolygon geometry, projected to the coordinate system `crs`

    Each member polygon of *geom* is drawn with its own ``Axes.fill`` call.

    Parameters
    ----------
    geom : Multipolygon
        Iterable of polygons. The coordinates of each polygon are obtained with
        ``polygon.get_coordinate_lists(crs=crs)``.
    *args, **kwargs
        Forwarded to every ``Axes.fill`` call. Unless given, ``facecolor``
        defaults to "none" and ``edgecolor`` to "black" (hollow outlines).
    ax : Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    crs : CRS, optional
        Coordinate system to project the geometry into before plotting.

    Returns
    -------
    list
        One ``Axes.fill`` result (a list of patches) per member polygon.
    """
    if ax is None:
        # Without this fallback the default value would crash on ax.fill below.
        import matplotlib.pyplot as plt
        ax = plt.gca()
    kwargs.setdefault("facecolor", "none")
    kwargs.setdefault("edgecolor", "black")
    out = []
    for polygon in geom:
        x, y = polygon.get_coordinate_lists(crs=crs)
        out.append(ax.fill(x, y, *args, **kwargs))
    return out
Ejemplo n.º 11
0
    def plot2axes(self, axes: plt.Axes=None, edgecolor: str='black', fill: bool=False, plot_peaks: bool=True):
        """Draw the ROI's outer and inner boundary circles, and optionally its peaks.

        Both circles are centred on ``self.center``. Their radii are
        ``radius * (1 + width_ratio)`` and ``radius * (1 - width_ratio)``.
        When *axes* is None, a new figure is created and ``self.image_array``
        is shown on it first. With *plot_peaks*, autoscaling is disabled and
        the peaks are marked with 'x' markers in *edgecolor*.

        See Also
        --------
        :meth:`~pylinac.core.profile.CircleProfile.plot2axes` : Further parameter info.
        """
        if axes is None:
            _, axes = plt.subplots()
            axes.imshow(self.image_array)
        center = (self.center.x, self.center.y)
        for scale in (1 + self.width_ratio, 1 - self.width_ratio):
            axes.add_patch(mpl_Circle(center, edgecolor=edgecolor, radius=self.radius * scale,
                                      fill=fill))
        if not plot_peaks:
            return
        peak_xs = [peak.x for peak in self.peaks]
        peak_ys = [peak.y for peak in self.peaks]
        axes.autoscale(enable=False)
        axes.scatter(peak_xs, peak_ys, s=20, marker='x', c=edgecolor)
Ejemplo n.º 12
0
    def plot(self, ax: plt.Axes=None, show: bool=True, clear_fig: bool=False, **kwargs):
        """Plot the image.

        Parameters
        ----------
        ax : matplotlib.Axes instance
            The axis to plot the image to. If None, creates a new figure.
        show : bool
            Whether to actually show the image. Set to false when plotting multiple items.
        clear_fig : bool
            Whether to clear the prior items on the figure before plotting.
            Has no effect when *ax* is None, because the new figure is empty.
        **kwargs
            Passed to ``Axes.imshow``.

        Returns
        -------
        matplotlib.Axes
            The axes the image was drawn on.
        """
        if ax is None:
            # A brand-new figure has nothing to clear. Calling plt.clf() here would
            # remove the axes that was just created, so the image would never be displayed.
            _, ax = plt.subplots()
        elif clear_fig:
            # NOTE(review): plt.clf() clears pyplot's *current* figure. If *ax*
            # belongs to that figure, *ax* is detached as well; confirm intent.
            plt.clf()
        ax.imshow(self.array, cmap=get_dicom_cmap(), **kwargs)
        if show:
            plt.show()
        return ax
Ejemplo n.º 13
0
def plot_grid(grid: RegularGrid, ax: Axes=None, crs: CRS=None, band: Union[int, tuple]=-1, **kwargs):
    """ Plot a grid instance

    Parameters
    ----------
    grid : RegularGrid
        raster data to plot
    ax : Axes, optional
        Axes to plot to [default plt.gca()]
    crs : CRS, optional
        Currently not supported; must be None or equal to ``grid.crs``
    band : int or tuple, optional
        Band(s) to plot. If *grid* has three bands, by default the three are
        plotted in false colour as RGB channels. Otherwise, the first band is
        plotted by default. If *band* is a tuple, it must have three (RGB) or
        four (RGBA) integer elements.

    Returns
    -------
    The image returned by ``Axes.imshow``.

    Raises
    ------
    NotImplementedError
        If *crs* is given and differs from ``grid.crs``.
    ValueError
        If *band* is a sequence whose length is not 3 or 4.

    Notes
    -----
    Additional arguments are passed to `matplotlib.pyplot.imshow`
    """
    # Validate the CRS first, so an unsupported reprojection fails before any
    # work that depends on it (such as computing the extent).
    if crs is not None and crs != grid.crs:
        raise NotImplementedError("RegularGrid reprojection not supported")

    if ax is None:
        # Documented default: draw on the current axes.
        ax = gca()

    # imshow only accepts origin "upper" or "lower"; "lower" puts row 0 at the bottom.
    kwargs.setdefault("origin", "lower")
    kwargs.setdefault("extent", grid.get_extent(crs=crs))
    kwargs.setdefault("cmap", cm.binary_r)

    # compute the pixels that can actually be displayed
    # be slightly generous by using a factor of 0.75 to avoid choosing too low
    # of a resolution
    _, _, width, height = ax.bbox.bounds
    ny, nx = grid.size
    r = (max(int(0.75*ny//height), 1), max(int(0.75*nx//width), 1))

    if band == -1:
        # Default: false-colour RGB for three-band grids, otherwise the first band.
        band = (0, 1, 2) if len(grid.bands) == 3 else 0

    if isinstance(band, int):
        arr = grid[::r[0], ::r[1], band]
        arr = np.ma.masked_equal(arr, grid.nodata)
    else:
        if len(band) not in (3, 4):
            raise ValueError("bands must be RGB or RGBA (length 3 or 4)")
        arr = np.dstack([grid[::r[0], ::r[1], i] for i in band]).astype(np.float32)
        arr = np.ma.masked_equal(arr, grid.nodata)
        # Normalise the colour channels by their maximum; an alpha channel is left as-is.
        arr[:, :, :3] /= arr[:, :, :3].max()

    im = ax.imshow(arr, **kwargs)
    if ax == gca():
        # Make this pyplot's current image, as plt.imshow would.
        sci(im)
    return im
Ejemplo n.º 14
0
def get_axes_limits(ax: Axes, ax_crs: CRS, crs: CRS=SphericalEarth):
    """Return ``(xmin, xmax, ymin, ymax)`` of the window an Axes covers in *crs*.

    The axes limits are interpreted in *ax_crs*. Rather than reading the
    extremes off the corners, each one is searched for along the matching
    spine with ``scipy.optimize.fminbound``:
    - ymin along the bottom spine
    - ymax along the top spine
    - xmin along the left spine
    - xmax along the right spine
    """
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()

    def to_crs(x, y):
        return ax_crs.transform(crs, x, y)

    fminbound = scipy.optimize.fminbound

    # Bottom spine: minimise y.
    x_at = fminbound(lambda x: to_crs(x, y_lo)[1], x_lo, x_hi)
    ymin = to_crs(x_at, y_lo)[1]

    # Top spine: maximise y by minimising its negation.
    x_at = fminbound(lambda x: -to_crs(x, y_hi)[1], x_lo, x_hi)
    ymax = to_crs(x_at, y_hi)[1]

    # Left spine: minimise x.
    y_at = fminbound(lambda y: to_crs(x_lo, y)[0], y_lo, y_hi)
    xmin = to_crs(x_lo, y_at)[0]

    # Right spine: maximise x by minimising its negation.
    y_at = fminbound(lambda y: -to_crs(x_hi, y)[0], y_lo, y_hi)
    xmax = to_crs(x_hi, y_at)[0]

    return xmin, xmax, ymin, ymax
Ejemplo n.º 15
0
    def _plot_analyzed_subimage(self, subimage: str, show: bool=True, ax: plt.Axes=None):
        """Plot one piece of the VMAT analysis.

        Parameters
        ----------
        subimage : str
            Which piece to draw:
            - DMLC or OPEN: that image, with the analysis segments overlaid and
              the axis hidden.
            - PROFILE: the median DMLC and open profiles, with a legend and grid.
            Any other value draws nothing.
        show : bool
            Whether to call ``plt.show()`` at the end.
        ax : matplotlib Axes, None
            If None (default), creates a new figure to plot to, otherwise plots to the given axes.
        """
        plt.ioff()
        if ax is None:
            _, ax = plt.subplots()

        if subimage in (DMLC, OPEN):
            img = self.dmlc_image if subimage == DMLC else self.open_image
            ax.imshow(img, cmap=get_dicom_cmap())
            self._draw_segments(ax)
            # The following calls go through pyplot, so make ax the current axes first.
            plt.sca(ax)
            plt.axis('off')
            plt.tight_layout()
        elif subimage == PROFILE:
            dmlc_prof, open_prof = self._median_profiles((self.dmlc_image, self.open_image))
            for profile, legend_name in ((dmlc_prof, 'DMLC'), (open_prof, 'Open')):
                ax.plot(profile.values, label=legend_name)
            ax.autoscale(axis='x', tight=True)
            ax.legend(loc=8, fontsize='large')
            ax.grid()

        if show:
            plt.show()
Ejemplo n.º 16
0
 def add_guards_to_axes(self, axis: plt.Axes, color: str='g'):
     """Plot the left and right guard rails onto *axis*.

     The guards are evaluated at each index along image dimension 0 when the
     orientation is UP_DOWN, and along dimension 1 otherwise. For UP_DOWN the
     guard values are plotted as x against the index as y; otherwise the index
     is x and the guard value is y.
     """
     up_down = self.settings.orientation == UP_DOWN
     positions = np.arange(self.image.shape[0 if up_down else 1])
     left = self.left_guard(positions)
     right = self.right_guard(positions)
     for guard_values in (left, right):
         xy = (guard_values, positions) if up_down else (positions, guard_values)
         axis.plot(*xy, color=color)
Ejemplo n.º 17
0
    def plot_axis_images(self, axis: str=GANTRY, show: bool=True, ax: plt.Axes=None):
        """Plot all CAX/BB/EPID positions for the images of a given axis.

        The first matching image (reference or moving along *axis*) is drawn in
        full, and markers for the remaining images are overlaid on it:
        - For Couch: BB positions as red circles.
        - Otherwise: EPID positions as blue '+' and field CAX positions as green squares.

        Parameters
        ----------
        axis : {'Gantry', 'Collimator', 'Couch', 'Combo'}
            The images/markers from which accelerator axis to plot.
        show : bool
            Whether to actually show the images.
        ax : None, matplotlib.Axes
            The axis to plot to. If None, creates a new plot.
        """
        selected = [img for img in self.images if img.variable_axis in (axis, REFERENCE)]
        ax = selected[0].plot(show=False, ax=ax)
        overlays = selected[1:]

        if axis == COUCH:
            xs = [img.bb.x for img in overlays]
            ys = [img.bb.y for img in overlays]
            marker = 'ro'
        else:
            ax.plot([img.epid.x for img in overlays], [img.epid.y for img in overlays], 'b+', ms=8)
            xs = [img.field_cax.x for img in overlays]
            ys = [img.field_cax.y for img in overlays]
            marker = 'gs'
        ax.plot(xs, ys, marker, ms=8)

        ax.set_title(axis + ' wobble')
        ax.set_xlabel(axis + ' positions superimposed')
        iso_size = getattr(self, axis.lower() + '_iso_size')
        ax.set_ylabel(axis + f" iso size: {iso_size:3.2f}mm")
        if show:
            plt.show()
Ejemplo n.º 18
0
    def plot2axes(self, axes: plt.Axes, edgecolor: str='black', angle: float=0.0, fill: bool=False,
                  alpha: float=1, facecolor: str='g'):
        """Plot the Rectangle to the axes.

        The patch is anchored at ``self.bl_corner`` and has size
        ``self.width`` x ``self.height``.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            An MPL axes to plot to.
        edgecolor : str
            The color of the rectangle's edge.
        angle : float
            Rotation of the rectangle in degrees, passed to ``matplotlib.patches.Rectangle``.
        fill : bool
            Whether to fill the rectangle with color or leave hollow.
        alpha : float
            Opacity of the patch.
        facecolor : str
            The fill color; only visible when *fill* is True.
        """
        corner = (self.bl_corner.x, self.bl_corner.y)
        patch = mpl_Rectangle(corner,
                              width=self.width,
                              height=self.height,
                              angle=angle,
                              edgecolor=edgecolor,
                              alpha=alpha,
                              facecolor=facecolor,
                              fill=fill)
        axes.add_patch(patch)
Ejemplo n.º 19
0
def annotate(artist: Artist, label: str, where: str="over", ax: Axes=None, **kwargs):
    """ Add a Text object near *artist*.

    Parameters
    ----------
    artist : Artist
        Artist to label. Its anchor point comes from ``_position_over``,
        ``_position_below`` or ``_position_above``, matching *where*.
    label : str
        Text to draw.
    where : {"over", "below", "above"}
        Placement relative to *artist*. Unless given in *kwargs*, the vertical
        alignment ("va") defaults to "center", "top" or "bottom" respectively.
    ax : Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    **kwargs
        Passed to ``Axes.text``.

    Returns
    -------
    The Text instance created by ``ax.text``.

    Raises
    ------
    ValueError
        If *where* is not one of the accepted values.
    """
    if where == "over":
        x, y = _position_over(artist)
        kwargs.setdefault("va", "center")
    elif where == "below":
        x, y = _position_below(artist)
        kwargs.setdefault("va", "top")
    elif where == "above":
        x, y = _position_above(artist)
        kwargs.setdefault("va", "bottom")
    else:
        raise ValueError("invalid value for 'where'")
    if ax is None:
        # Without this fallback the default value would crash on ax.text below.
        import matplotlib.pyplot as plt
        ax = plt.gca()
    return ax.text(x, y, label, **kwargs)
Ejemplo n.º 20
0
def plot_at_coordinates(images: np.ndarray,
                        coordinates: np.ndarray,
                        zoom: float,
                        ax: plt.Axes = None):
    """Place each image as a thumbnail at its matching (x, y) coordinate.

    Parameters
    ----------
    images : np.ndarray
        3-D stack; ``images[i]`` is drawn at ``coordinates[i]``.
    coordinates : np.ndarray
        Array of shape (N, 2) holding the x/y positions.
    zoom : float
        Zoom factor passed to ``OffsetImage``.
    ax : plt.Axes, optional
        Target axes. Defaults to ``plt.gca()``.

    Returns
    -------
    list
        The artists returned by ``ax.add_artist``, in input order.
    """
    assert images.ndim == 3 and coordinates.ndim == 2
    assert images.shape[0] == coordinates.shape[0] and coordinates.shape[1] == 2
    target = plt.gca() if ax is None else ax

    def _add_thumbnail(picture, xy):
        thumbnail = OffsetImage(picture, zoom=zoom, cmap='coolwarm')
        return target.add_artist(AnnotationBbox(thumbnail, xy, frameon=False))

    return [_add_thumbnail(picture, (x, y)) for (x, y), picture in zip(coordinates, images)]
Ejemplo n.º 21
0
    def plot_attribute_reconstruction_from_svd(
            self,
            attribute_name,
            number_of_principal_components='all',
            axes: plt.Axes = None,
            marker_size=3):
        """Scatter one attribute's original values against their SVD reconstruction.

        The reconstruction uses *number_of_principal_components* components
        ('all' by default). When *axes* is None, a new figure with white
        backgrounds is created, and its title and axis labels are filled in.
        A caller-supplied axes is not labelled.
        """
        original_data = self.get_attribute(attribute_name)
        reconstruction = self.reconstruction_from_principal_components(
            number_of_principal_components=number_of_principal_components,
            attribute_names=[attribute_name])

        if axes is None:
            fig = plt.figure()
            axes = plt.axes()
            for canvas_part in (fig, axes):
                canvas_part.set_facecolor('white')
            title = attribute_name + f", {number_of_principal_components} principal components"
            axes.set_title(title)
            axes.set_xlabel('Original data')
            axes.set_ylabel('Reconstructed data')

        axes.scatter(original_data, reconstruction, s=marker_size)
Ejemplo n.º 22
0
def _radar(df: pd.DataFrame,
           ax: plt.Axes,
           label: Text,
           all_tags: Sequence[Text],
           color: Text,
           alpha: float = 0.2,
           edge_alpha: float = 0.85,
           zorder: int = 2,
           edge_style: Text = '-'):
    """Plot utility for generating the underlying radar plot.

    Averages ``score`` per ``tag`` in *df* and draws one closed polygon on the
    polar axes *ax*: an outline labelled *label*, plus a translucent fill.
    Any tag in *all_tags* without exactly one averaged row gets score 0 and a
    message is printed. Every score is floored at 0.05.
    """
    # Average only the score column, so non-numeric columns in df cannot break
    # the aggregation (pandas >= 2 raises on them).
    tmp = df.groupby('tag')['score'].mean().reset_index()

    values = []
    for curr_tag in all_tags:
        score = 0.
        selected = tmp[tmp['tag'] == curr_tag]
        if len(selected) == 1:
            # .iloc[0]: calling float() on a one-element Series is deprecated.
            score = float(selected['score'].iloc[0])
        else:
            print('{} bsuite scores found for tag {!r} with setting {!r}. '
                  'Replacing with zero.'.format(len(selected), curr_tag,
                                                label))
        values.append(score)
    values = np.maximum(values, 0.05)  # don't let radar collapse to 0.
    # Repeat the first point so the outline and fill close the polygon.
    values = np.concatenate((values, [values[0]]))

    angles = np.linspace(0, 2 * np.pi, len(all_tags), endpoint=False)
    angles = np.concatenate((angles, [angles[0]]))

    ax.plot(angles,
            values,
            '-',
            linewidth=5,
            label=label,
            c=color,
            alpha=edge_alpha,
            zorder=zorder,
            linestyle=edge_style)
    ax.fill(angles, values, alpha=alpha, color=color, zorder=zorder)
    # One grid line per tag: skip the duplicated closing angle, because
    # matplotlib requires exactly as many tick labels as tick locations.
    ax.set_thetagrids(angles[:-1] * 180 / np.pi,
                      [_tag_pretify(tag) for tag in all_tags],
                      fontsize=18)

    # To avoid text on top of gridlines, we flip horizontalalignment
    # based on label location
    text_angles = np.rad2deg(angles[:-1])
    for tick_label, angle in zip(ax.get_xticklabels(), text_angles):
        if 90 <= angle <= 270:
            tick_label.set_horizontalalignment('right')
        else:
            tick_label.set_horizontalalignment('left')
Ejemplo n.º 23
0
def generate_gsr_range_correct_means_plot(axes: plt.Axes,
                                          stimulus_data: pd.DataFrame):
    """
    Plots range corrected means over two minute windows.

    The data is resampled into two-minute means. Only the first 15 windows
    are drawn, as edge-aligned bars of width 2 on a fixed 0-1 y range.
    *axes* is also made pyplot's current axes.

    :param axes: the axes to plot on
    :param stimulus_data: the raw stimulus data
    :return: None
    """
    plt.sca(axes)

    two_minute_means = stimulus_data.resample("2min").mean()
    windowed_data = two_minute_means[:15]
    window_starts = convert_date_to_time(windowed_data.index)

    axes.set_title("Range-Corrected GSR Means Over Two-Minute Windows")
    axes.set_xlabel("Time (minutes)", fontsize="large")
    axes.set_ylabel("Range-Corrected GSR (dimensionless)")
    axes.set_ylim(0, 1)
    set_windowed_x_axis(axes)

    bar_style = dict(width=2, align="edge", edgecolor="black")
    axes.bar(window_starts, windowed_data[RANGE_CORRECT_EDA], **bar_style)
Ejemplo n.º 24
0
    def plot(self,
             ax: plt.Axes = None,
             show: bool = True,
             clear_fig: bool = False,
             **kwargs):
        """Plot the image.

        Parameters
        ----------
        ax : matplotlib.Axes instance
            The axis to plot the image to. If None, creates a new figure.
        show : bool
            Whether to actually show the image. Set to false when plotting multiple items.
        clear_fig : bool
            Whether to clear the prior items on the figure before plotting.
            Has no effect when *ax* is None, because the new figure is empty.
        **kwargs
            Passed to ``Axes.imshow``.

        Returns
        -------
        matplotlib.Axes
            The axes the image was drawn on.
        """
        if ax is None:
            # A brand-new figure has nothing to clear. Calling plt.clf() here would
            # remove the axes that was just created, so the image would never be displayed.
            _, ax = plt.subplots()
        elif clear_fig:
            # NOTE(review): plt.clf() clears pyplot's *current* figure. If *ax*
            # belongs to that figure, *ax* is detached as well; confirm intent.
            plt.clf()
        ax.imshow(self.array, cmap=get_dicom_cmap(), **kwargs)
        if show:
            plt.show()
        return ax
Ejemplo n.º 25
0
def draw_box(
    pyplot_axis: plt.Axes,
    vertices: np.ndarray,
    axes: Optional[Any] = None,
    color: Union[str, Tuple[float, float, float]] = "red",
) -> None:
    """Draw the 12 edges of a box from its 8 corner vertices.

    Rows of *vertices* are coordinate components and columns are corners.
    *axes* (resolved through ``_get_axes_or_default``) selects the rows to use.
    Corners 0-3 form the lower face and 4-7 the upper face (both described as
    parallel to the Z=0 plane). Corner ``i`` is joined to corner ``i + 4``.
    Every edge is a thin line (lw=0.5) in *color*.
    """
    selected = vertices[_get_axes_or_default(axes), :]
    lower_face = [[i, (i + 1) % 4] for i in range(4)]
    upper_face = [[4 + i, 4 + (i + 1) % 4] for i in range(4)]
    pillars = [[i, i + 4] for i in range(4)]
    for edge in lower_face + upper_face + pillars:
        pyplot_axis.plot(*selected[:, edge], c=color, lw=0.5)
Ejemplo n.º 26
0
    def show_contour(self, ax: plt.Axes = None):
        """Draw ``self.contour`` and ``self.image`` on one axes, with ticks removed.

        Parameters
        ----------
        ax : matplotlib.Axes
            Axes to use for plotting. A new figure is created when *ax* is falsy (e.g. None).

        Returns
        -------
        ax : matplotlib.Axes
            The axes that was drawn on.
        """
        if not ax:
            _, ax = plt.subplots()

        ax.set_title('Contours')
        self.contour.plot_mpl(ax=ax)

        ax.imshow(self.image)
        ax.axis('image')
        for clear_ticks in (ax.set_xticks, ax.set_yticks):
            clear_ticks([])

        return ax
Ejemplo n.º 27
0
def visualize_frequent_words(vectors_2d: np.ndarray,
                             dataset: DataSet,
                             k: int,
                             ax: plt.Axes = None) -> None:
    """Scatter the 2-D vectors of the *k* most frequent words and label each point.

    Word frequencies are counted over ``dataset.data``. If there are fewer than
    *k* distinct words, all of them are shown. When *ax* is None, a new 13x13
    inch figure is created, laid out and shown.

    Parameters
    ----------
    vectors_2d : np.ndarray
        Array indexed by word id; columns 0 and 1 are used as x and y.
    dataset : DataSet
        Supplies the word-id data and ``vocabulary.to_word`` for labels.
    k : int
        Number of most frequent words to plot.
    ax : plt.Axes, optional
        Axes to draw on.
    """
    word_ids, counts = np.unique(dataset.data, return_counts=True)

    if k < len(word_ids):
        # The first k positions hold the k largest counts (in no particular order).
        indices = np.argpartition(-counts, k)[:k]
    else:
        # argpartition rejects kth >= len(counts); there are at most k words anyway.
        indices = np.arange(len(word_ids))
    frequent_word_ids = word_ids[indices]

    if ax is None:
        fig, ax = plt.subplots(figsize=(13, 13))
    else:
        fig = None

    vectors_2d = vectors_2d[frequent_word_ids]

    ax.scatter(vectors_2d[:, 0], vectors_2d[:, 1], s=2, alpha=0.25)
    for i, word_id in enumerate(frequent_word_ids):
        ax.annotate(dataset.vocabulary.to_word(word_id),
                    (vectors_2d[i, 0], vectors_2d[i, 1]))

    if fig is not None:
        fig.tight_layout()
        fig.show()
Ejemplo n.º 28
0
def plot_curve(axes: plt.Axes,
               filename: str,
               label: str,
               mask: Sequence[float] = None,
               **kwargs: Any) -> List[lines.Line2D]:
    """
    Plot a curve in the axes object with the values in the given file.

    The file can include single line comments which start with # (leading
    whitespace is allowed); blank lines are skipped. Every other line should
    contain two floats, which are the x and the y value respectively. Points
    are plotted in increasing x order.

    Parameters
    ----------
    axes : matplotlib.pyplot.Axes
        The axes object.
    filename : str
        The filename.
    label : str
        The label of the plot.
    mask : Sequence[float]
        Inclusive [min, max] range of x values to plot; default is unbounded.
    kwargs : Any
        Additional kwargs which will be passed to the axes.plot method.

    Returns
    -------
    List[matplotlib.lines.Line2D]
        A list of objects representing the plotted data. It is empty (and a
        warning is issued) if the file cannot be found.
    """
    if mask is None:
        mask = [-float('inf'), float('inf')]
    data = []
    try:
        with open(filename, 'r') as file:
            for line in file:
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                split_str = stripped.split()
                x_value, y_value = float(split_str[0]), float(split_str[1])
                if mask[0] <= x_value <= mask[1]:
                    data.append((x_value, y_value))
    except FileNotFoundError:
        warnings.warn("Could not open the file {}.".format(filename))
        return []
    # Sort in place. The previous `sorted(data)` discarded its result, so
    # unordered files produced zig-zag lines.
    data.sort()
    x_values = [x for x, _ in data]
    y_values = [y for _, y in data]
    return axes.plot(x_values, y_values, label=label, **kwargs)
Ejemplo n.º 29
0
def center_axis(axes: plt.Axes, which='y'):
    """Make the limits of one or both axes symmetric about zero.

    Each selected axis gets limits ``(-m, m)``, where ``m`` is the largest
    absolute value of its current limits.

    Parameters
    ----------
    axes : plt.Axes
        Axes whose limits are adjusted in place.
    which : {'y', 'x', 'both'}
        Which axis to centre (default 'y').

    Raises
    ------
    ValueError
        If *which* is not 'x', 'y' or 'both'.
    """
    if which not in ('x', 'y', 'both'):
        raise ValueError(f"which must be 'x', 'y' or 'both', got {which!r}")
    # 'both' previously did nothing; it now centres both axes.
    if which in ('y', 'both'):
        max_abs = np.max(np.abs(axes.get_ylim()))
        axes.set_ylim(-max_abs, max_abs)
    if which in ('x', 'both'):
        max_abs = np.max(np.abs(axes.get_xlim()))
        axes.set_xlim(-max_abs, max_abs)
Ejemplo n.º 30
0
 def _default_after_subplot(self, ax: plt.Axes, group_name: str, x_label: str):
     """Finish one metrics chart: set its title and x label, then add a legend.

     Args:
         ax: matplotlib Axes to decorate
         group_name: name of metrics group (eg. Accuracy, Recall), used as the title
         x_label: label of x axis (eg. epoch, iteration, batch)
     """
     for setter, text in ((ax.set_title, group_name), (ax.set_xlabel, x_label)):
         setter(text)
     ax.legend(loc='center right')
    def __init__(
        self,
        ax: plt.Axes,
    ):
        """
        Connect mouse handlers to the canvas of *ax*'s figure.

        :param ax: axes whose figure canvas gets ``_on_press``,
            ``_on_release`` and ``_on_scroll_event`` connected to the
            button-press, button-release and scroll events. It is stored as
            ``self.ax`` and its canvas as ``self.canvas``.
        """
        canvas = ax.get_figure().canvas
        event_handlers = (
            ("button_press_event", self._on_press),
            ("button_release_event", self._on_release),
            ("scroll_event", self._on_scroll_event),
        )
        for event_name, handler in event_handlers:
            canvas.mpl_connect(event_name, handler)

        self.ax = ax
        self.canvas = canvas
        self.last_point: Optional[Tuple[float, float]] = None
        self.motion_handler: Optional[Callable] = None
Ejemplo n.º 32
0
    def _plot_ax_outliers(axes: plt.Axes, ax_data: pd.Series, extents: np.ndarray):
        """Plot the outliers of *ax_data* onto *axes*.

        Relies on ``plotter``, ``cutoff_lo``, ``cutoff_hi``, ``orient``,
        ``padding``, ``fmt`` and ``group_names`` from the enclosing scope.

        - KDE plots: values outside [cutoff_lo, cutoff_hi] are drawn with a
          group height of half the y-range, and *axes* is returned.
        - Otherwise: outliers are drawn per group (or once when there is at
          most one group name), and None is returned.
        """
        if plotter == sns.kdeplot:
            group = .5 * np.diff(axes.get_ylim())
            values = ax_data.values
            outside = np.logical_or(cutoff_lo > values, values > cutoff_hi)
            _plot_outliers(axes, values[outside], orient=orient, group=group, padding=padding,
                           plot_extents=extents, fmt=fmt)
            return axes

        if group_names and len(group_names) != 1:
            for group_idx, group_name in enumerate(group_names):
                _plot_group_outliers(ax_data, extents, group_idx=group_idx, group_name=group_name, axes=axes)
        else:
            _plot_group_outliers(ax_data, extents, axes=axes)
Ejemplo n.º 33
0
def graph_frame(hist: history.DB, db_res_case: int,
                contour_key: history.ContourKey, ax: plt.Axes):
    """Redraw *ax* with min/mean/max curves for one contour key of one result case.

    *ax* is cleared first. Two series are drawn: undilated (``yielded`` False,
    blue) and dilated (``yielded`` True, orange). Each series spans every x
    present for this contour key; x values without real data for that series
    use ``ColumnResult._nans_at`` placeholders. Each series is a min-max band
    plus a labelled mean line, and a legend is added.

    Returns True if any matching row was found for either series.
    """
    skeleton = history.ColumnResult._all_nones()._replace(result_case_num=db_res_case)
    matching = [
        cd for cd in hist.get_all_matching(skeleton) if cd.contour_key == contour_key
    ]

    ax.clear()
    graphed_something = False

    for yielded, base_col in ((False, 'tab:blue'), (True, 'tab:orange')):
        # NaN placeholders at every x, replaced below where real data exists.
        x_to_cd = {cd.x: history.ColumnResult._nans_at(cd.x) for cd in matching}
        real_rows = [cd for cd in matching if cd.yielded == yielded]
        for cd in real_rows:
            x_to_cd[cd.x] = cd
        graphed_something = graphed_something or bool(real_rows)

        ordered = [cd for _, cd in sorted(x_to_cd.items())]
        xs = [cd.x for cd in ordered]
        lows = [cd.minimum for cd in ordered]
        means = [cd.mean for cd in ordered]
        highs = [cd.maximum for cd in ordered]

        ax.fill_between(xs, lows, highs, color=base_col, alpha=0.2)
        yield_text = "Dilated" if yielded else "Undilated"
        ax.plot(xs, means, color=base_col, label=f"{contour_key.name} ({yield_text})")

    ax.legend()

    return graphed_something
Ejemplo n.º 34
0
    def attach_probability_current_to_axis(
        self,
        axis: plt.Axes,
        plot_limit: Optional[float] = None,
        distance_unit: u.Unit = "bohr_radius",
        rate_unit="per_asec",
    ):
        """Overlay the probability current as a quiver (arrow) plot on *axis*.

        The z/rho current components come from
        ``self.mesh.get_probability_current_density_vector_field()``. They are
        multiplied by ``delta_z`` / ``delta_rho``, subsampled every
        ``shape // 50`` points (at least 1) per mesh dimension, and divided by
        the largest sampled magnitude. The arrows are white at half opacity.

        Parameters
        ----------
        axis : plt.Axes
            Axes to draw on.
        plot_limit : float, optional
            Passed to ``self.mesh.get_mesh_slicer`` to select the plotted region.
        distance_unit : u.Unit
            Unit that the plotted z/rho positions are divided by.
        rate_unit
            Resolved with ``u.get_unit_value_and_latex``; its value is not used here.

        Returns
        -------
        The Quiver artist returned by ``axis.quiver``.
        """
        distance_unit_value, _ = u.get_unit_value_and_latex(distance_unit)
        # NOTE(review): rate_unit_value is never used below; kept so rate_unit is still resolved.
        rate_unit_value, _ = u.get_unit_value_and_latex(rate_unit)

        # The mesh returns densities. Scale out of place so that any arrays the
        # mesh may hold on to are not modified.
        current_mesh_z, current_mesh_rho = self.mesh.get_probability_current_density_vector_field()
        current_mesh_z = current_mesh_z * self.mesh.delta_z
        current_mesh_rho = current_mesh_rho * self.mesh.delta_rho

        # Aim for about 50 arrows per dimension. The step must be at least 1,
        # otherwise meshes with fewer than 50 points give a zero slice step (ValueError).
        skip_count = (
            max(1, int(self.mesh.z_mesh.shape[0] / 50)),
            max(1, int(self.mesh.z_mesh.shape[1] / 50)),
        )
        skip = (slice(None, None, skip_count[0]), slice(None, None, skip_count[1]))

        normalization = np.nanmax(
            np.sqrt((current_mesh_z**2) + (current_mesh_rho**2))[skip])
        if normalization == 0:
            normalization = 1

        sli = self.mesh.get_mesh_slicer(plot_limit)

        quiv = axis.quiver(
            self.mesh.z_mesh[sli][skip] / distance_unit_value,
            self.mesh.rho_mesh[sli][skip] / distance_unit_value,
            current_mesh_z[sli][skip] / normalization,
            current_mesh_rho[sli][skip] / normalization,
            pivot="middle",
            scale=10,
            units="width",
            scale_units="width",
            alpha=0.5,
            color="white",
        )

        return quiv
Ejemplo n.º 35
0
def get_bar_data(
    plot_axis: plt.Axes
) -> Tuple[str, List[str], List[float], List[str], List[float], List[float],
           List[Tuple[float, float, float, float]]]:
    """
    Extract the title, tick labels, view ranges and bar properties of a bar plot.

    Parameters
    ----------
    plot_axis : matplotlib.pyplot.Axes
        A matplotlib axis from which all of the information below is extracted.

    Returns
    -------
    plot_title : string
        Plot's title.
    plot_x_tick_names : List[string]
        Tick labels of the plot's x-axis.
    plot_x_range : List[Number]
        View interval of the plot's x-axis (from ``get_view_interval``).
    plot_y_tick_names : List[string]
        Tick labels of the plot's y-axis.
    plot_y_range : List[Number]
        View interval of the plot's y-axis (from ``get_view_interval``).
    plot_bar_width : List[Number]
        Width of every patch on the axis.
    plot_bar_colours : List[Tuple[float, float, float, float]]
        Face colour of every patch on the axis, as an (r, g, b, alpha) tuple.
    """
    assert isinstance(plot_axis, plt.Axes), 'Must be a matplotlib axis.'

    def _tick_texts(axis):
        return [tick.get_text() for tick in axis.get_ticklabels()]

    bars = plot_axis.patches
    return (
        plot_axis.get_title(),
        _tick_texts(plot_axis.xaxis),
        plot_axis.xaxis.get_view_interval(),
        _tick_texts(plot_axis.yaxis),
        plot_axis.yaxis.get_view_interval(),
        [bar.get_width() for bar in bars],
        [bar.get_facecolor() for bar in bars],
    )
Ejemplo n.º 36
0
def volcano_plot(filename: str,
                 *,
                 ax: plt.Axes = None,
                 lfc_col: str = DESEQ2_LOG2_CHANGE,
                 pad_col: str = DESEQ2_PADJ,
                 threshold: float = 0.05,
                 use_threshold: bool = True,
                 color_col: str = None,
                 normal_color: Union[str, tuple] = Palette().blue(),
                 sign_color: Union[str, tuple] = Palette().red(),
                 scatter_options: dict = None,
                 add_labels: bool = True):
    """Draw a volcano plot (log2 fold change vs. -log10 adjusted p-value) from a CSV file.

    Parameters
    ----------
    filename : str
        CSV file containing at least *lfc_col* and *pad_col*.
    ax : plt.Axes, optional
        Axes to draw on; a new figure is created when None.
    lfc_col, pad_col : str
        Names of the log fold change and adjusted p-value columns.
    threshold : float
        Points with adjusted p-value below this get *sign_color* (when
        *use_threshold* is True); all others get *normal_color*.
    color_col : str, optional
        Column of explicit colors; overrides the threshold colouring.
    scatter_options : dict, optional
        Extra ``Axes.scatter`` keyword arguments (marker defaults to ".").
    add_labels : bool
        Whether to label the x and y axes.

    Returns
    -------
    plt.Axes
        The axes the plot was drawn on.

    Raises
    ------
    KeyError
        If *lfc_col* or *pad_col* is missing from the file.
    """
    d = pd.read_csv(filename)
    # Sanity check
    if lfc_col not in d.columns or pad_col not in d.columns:
        raise KeyError(f"Please check your input data. Your input data "
                       f"should have columns for log fold change "
                       f"'{DESEQ2_LOG2_CHANGE}' "
                       f"and adjusted p value '{DESEQ2_PADJ}' column. If "
                       f"your column names are different, please provide "
                       f"them with argument 'lfc_col' and 'pad_col'"
                       f"")

    # Missing adjusted p-values are treated as 1 (not significant). Filling them
    # with 0, like the other columns, would give -log10(0) = inf and colour them
    # as significant.
    padj = d[pad_col].fillna(1)
    d = d.fillna(0)  # type: pd.DataFrame
    neg_log_padj = -np.log10(padj)

    if color_col is not None:
        # If explicit color column is given, use those colors
        point_colors = d[color_col]
    else:
        # Built as a list, so tuple colors work too (assigning a tuple to a
        # DataFrame column would be broadcast element-wise and fail).
        if use_threshold:
            significant = (padj < threshold).tolist()
        else:
            significant = [False] * len(d)
        point_colors = [sign_color if is_sig else normal_color for is_sig in significant]

    # If no axes is given, generate default one
    if ax is None:
        _, ax = plt.subplots()

    opts = {"marker": "."}
    if scatter_options is not None:
        opts = {**opts, **scatter_options}

    ax.scatter(d[lfc_col], neg_log_padj, color=point_colors, **opts)

    if add_labels:
        ax.set_ylabel("-Log$_{10}$ Adj P value")
        ax.set_xlabel("Log$_2$ Fold change")
    return ax
Ejemplo n.º 37
0
def image(
    data: SWIFTDataset, ax: plt.Axes, radial_bins: np.array, center: np.array
) -> None:
    """
    Draw a gas-density image of the window around ``center`` onto ``ax``.

    Particle positions and smoothing lengths are shifted so that the lower-left
    corner of the window (``center`` offset by ``image_bounds[0]``) sits at the
    origin, and divided by the window width. They are then deposited with
    ``scatter_parallel`` at ``image_res`` resolution. The result is shown with
    ``image_cmap`` on a logarithmic colour scale and labelled "Gas Density".
    Axis ticks are removed.

    ``radial_bins`` is accepted but not used.
    """
    width = image_bounds[1] - image_bounds[0]
    corner_x = center[0] + image_bounds[0]
    corner_y = center[1] + image_bounds[0]

    # Map the imaging window onto the unit square used by scatter_parallel.
    positions = data.gas.coordinates
    x_unit = (positions[:, 0] - corner_x) / width
    y_unit = (positions[:, 1] - corner_y) / width
    h_unit = data.gas.smoothing_lengths / width

    # Coordinates are passed as (y, x); presumably so the returned array is
    # indexed [row=y, col=x] as imshow expects -- TODO confirm.
    density_map = scatter_parallel(
        y_unit, x_unit, data.gas.masses, h_unit, image_res
    )

    ax.imshow(density_map, cmap=image_cmap, norm=LogNorm(), origin="lower")
    ax.text(
        0.025, 0.975, "Gas Density",
        color=image_textcolor, transform=ax.transAxes, ha="left", va="top",
    )

    ax.set_xticks([])
    ax.set_yticks([])
Ejemplo n.º 38
0
def traj_colormap(ax: plt.Axes,
                  traj: trajectory.PosePath3D,
                  array: ListOrArray,
                  plot_mode: PlotMode,
                  min_map: float,
                  max_map: float,
                  title: str = "") -> None:
    """
    Color-map a path/trajectory in xyz coordinates according to an array of
    values, and attach a colorbar describing the mapping.

    :param ax: plot axis
    :param traj: trajectory.PosePath3D or trajectory.PoseTrajectory3D object
    :param array: Nx1 array of values used for color mapping
    :param plot_mode: PlotMode
    :param min_map: lower bound value for color mapping (values are clipped)
    :param max_map: upper bound value for color mapping (values are clipped)
    :param title: plot title, set on ``ax``; a legend is also drawn when given
    """
    pos = traj.positions_xyz
    norm = mpl.colors.Normalize(vmin=min_map, vmax=max_map, clip=True)
    mapper = cm.ScalarMappable(
        norm=norm,
        cmap=SETTINGS.plot_trajectory_cmap)  # cm.*_r is reversed cmap
    mapper.set_array(array)
    colors = [mapper.to_rgba(a) for a in array]
    line_collection = colored_line_collection(pos, colors, plot_mode)
    ax.add_collection(line_collection)
    ax.autoscale_view(True, True, True)
    if plot_mode == PlotMode.xyz:
        ax.set_zlim(np.amin(pos[:, 2]), np.amax(pos[:, 2]))
        if SETTINGS.plot_xyz_realistic:
            set_aspect_equal_3d(ax)

    # Colorbar ticks at both bounds and at the midpoint.
    mid_map = max_map - (max_map - min_map) / 2
    tick_values = [min_map, mid_map, max_map]
    # Attach the colorbar to the figure that owns ``ax`` (not whatever figure
    # is current) and name the Axes to take space from: a standalone
    # ScalarMappable is not associated with any Axes.
    cbar = ax.figure.colorbar(mapper, ax=ax, ticks=tick_values)
    cbar.ax.set_yticklabels(["{0:0.3f}".format(v) for v in tick_values])
    if title:
        ax.legend(frameon=True)
        # Title the given axes; plt.title would target the current axes.
        ax.set_title(title)
Ejemplo n.º 39
0
def plot_text(axes: plt.Axes, text: str):
    """Plot text on an axes whose axis decorations are hidden.

    Args:
        axes (plt.Axes): an axes object
        text (str): the text; drawn left-aligned at data coordinates (0, 0)
            using the module-level ``FONT`` font dictionary.
    """

    axes.axis('off')
    # grid() expects a boolean; a string such as 'off' is not treated as
    # False by current Matplotlib.
    axes.grid(False)
    axes.text(x=0, y=0, s=text, horizontalalignment='left', fontdict=FONT)
Ejemplo n.º 40
0
def plot_grid(grid: RegularGrid, ax: Axes = None, crs: CRS = None, **kwargs):
    """Display a RegularGrid as an image and return the AxesImage.

    Parameters
    ----------
    grid : RegularGrid
        Grid to draw. Cells equal to ``grid.nodata`` are shown as NaN.
    ax : Axes, optional
        Axes to draw to (default: current Axes).
    crs : CRS, optional
        Coordinate system used for the default image ``extent``.
    **kwargs
        Passed to ``ax.imshow``. Defaults: ``origin="lower"``,
        ``extent=grid.get_extent(crs=crs)``, ``cmap=cm.binary_r``.

    When ``ax`` is the current Axes, the image is also made the current
    image (``sci``).
    """
    if ax is None:
        ax = gca()

    # imshow accepts only "upper" or "lower"; "lower" puts row 0 at the bottom.
    kwargs.setdefault("origin", "lower")
    kwargs.setdefault("extent", grid.get_extent(crs=crs))
    kwargs.setdefault("cmap", cm.binary_r)

    # compute the pixels that can actually be displayed: decimate so that
    # roughly no more samples than display pixels (x0.75) are drawn.
    _, _, width, height = ax.bbox.bounds
    ny, nx = grid.size
    r = (max(int(0.75 * ny // height), 1), max(int(0.75 * nx // width), 1))
    arr = grid[::r[0], ::r[1]]

    if not np.isnan(grid.nodata):
        # Mask on a float copy: NaN cannot be stored in an integer array, and
        # the slice may be a view onto the grid's own data (not verified here).
        arr = np.array(arr, dtype=float)
        arr[arr == grid.nodata] = np.nan
    im = ax.imshow(arr, **kwargs)
    if ax is gca():
        sci(im)
    return im
Ejemplo n.º 41
0
    def scatter(self, ax: plt.Axes) -> PathCollection:
        """Render this series as unconnected x/y markers on ``ax``.

        Every point gets the same marker area: ten times ``self.size``.

        :param ax: Plot Axes.
        :returns: whatever ``ax.scatter`` returns
        """
        marker_areas = np.ones_like(self.x) * (self.size * 10)
        return ax.scatter(
            self.x,
            self.y,
            label=self.label,
            alpha=self.alpha,
            marker=self.marker,
            s=marker_areas,
        )
Ejemplo n.º 42
0
    def line(self, ax: plt.Axes, show_markers: bool) -> List:
        """Render this series as a connected x/y line on ``ax``.

        :param ax: Plot Axes.
        :param show_markers: When True, use ``self.marker`` at each point;
            otherwise pass an empty marker string.
        :returns: the list returned by ``ax.plot``
        """
        marker = self.marker if show_markers else ""
        line_style = dict(
            label=self.label,
            alpha=self.alpha,
            marker=marker,
            markersize=self.size,
        )
        return ax.plot(self.x, self.y, **line_style)
Ejemplo n.º 43
0
def plot_2B(ax: plt.Axes):
    """Draw the averaged Experiment-1 confusion matrix on ``ax``.

    The matrix returned by ``DataExp1(pid).plot_confusion_matrix()`` is summed
    over every ``pid`` in ``DataExp1.pids`` and divided by their count. It is
    then drawn with ``plot_confusion_matrix`` using ``$structure$`` labels on
    both axes. Finally, the figure owning ``ax`` is laid out tightly.
    """
    n_structures = len(Exp1.structures)
    # Named so it does not shadow the commonly imported ``matplotlib.cm``.
    mean_cm = np.zeros((n_structures, n_structures))
    for pid in DataExp1.pids:
        mean_cm += DataExp1(pid).plot_confusion_matrix()
    mean_cm /= len(DataExp1.pids)

    ticklabels = [f'${s}$' for s in Exp1.structures]
    plot_confusion_matrix(mean_cm, ticklabels, ticklabels, ax)
    ax.set_title('Human avg.')
    ax.set_xlabel('Choice')
    ax.set_ylabel('True Structure')
    # Lay out the figure that owns ``ax``; plt.tight_layout() would act on
    # whatever figure happens to be current.
    ax.figure.tight_layout()
Ejemplo n.º 44
0
 def p_bones(
     axes_: Axes,
     xy_pt1: np.ndarray,
     xy_pt2: np.ndarray,
     dxy: np.ndarray,
     p_f: float = 0.9,
     color_: str = blue_,
     n_bones: int = 5,
     do_primary: bool = True,
 ) -> None:
     """Draw a "bone" glyph: a segment from xy_pt1 to xy_pt2 crossed by ticks.

     ``n_bones`` cross ticks are spaced along the segment, starting at
     ``xy_pt1`` and stopping short of ``xy_pt2``. Each tick spans
     ``+/- dxy`` scaled by 4 (primary) or 3 (secondary), and the first tick
     is drawn thicker. Secondary glyphs use alpha 0.7.

     For a primary glyph, a tilde-p label is placed at fraction ``p_f`` of
     the way from ``xy_pt2`` to ``xy_pt1``, shifted by -0.1 in x and y.
     NOTE(review): that label uses ``transAxes`` while the lines use data
     coordinates -- confirm the two coordinate systems coincide here.
     """
     line_alpha = 1 if do_primary else 0.7
     tick_scale = 4 if do_primary else 3

     # Main segment: solid dashes, width 3.
     axes_.plot(
         [xy_pt1[0], xy_pt2[0]],
         [xy_pt1[1], xy_pt2[1]],
         dashes=[1, 0],
         lw=3,
         color=color_,
         alpha=line_alpha,
     )

     # Cross ticks at fractions 1, 1 - 1/n, ... (weight on xy_pt1).
     fractions = np.linspace(1, 0, n_bones, endpoint=False)
     for index, frac in enumerate(fractions):
         mid_x = xy_pt1[0] * frac + xy_pt2[0] * (1 - frac)
         mid_y = xy_pt1[1] * frac + xy_pt2[1] * (1 - frac)
         half_x = dxy[0] * tick_scale
         half_y = dxy[1] * tick_scale
         axes_.plot(
             [mid_x - half_x, mid_x + half_x],
             [mid_y - half_y, mid_y + half_y],
             lw=5 if index == 0 else 2.5,
             color=color_,
             alpha=line_alpha,
         )

     if not do_primary:
         return

     label_x = xy_pt1[0] * p_f + xy_pt2[0] * (1 - p_f) - 0.1
     label_y = xy_pt1[1] * p_f + xy_pt2[1] * (1 - p_f) - 0.1
     axes_.text(
         label_x,
         label_y,
         r"$\mathbf{\widetilde{p}}$",
         color=color_,
         fontsize=18,
         rotation=0,
         transform=axes_.transAxes,
         horizontalalignment="right",
         verticalalignment="bottom",
     )
Ejemplo n.º 45
0
Archivo: plot.py Proyecto: whynot-s/evo
def set_aspect_equal_3d(ax: plt.Axes) -> None:
    """
    Give a 3D axes an equal aspect ratio by turning its view box into a cube.

    Every axis is recentred on the midpoint of its current limits and given
    the same half-width: the largest distance from any current limit to its
    axis midpoint. Adapted from https://stackoverflow.com/a/35126679

    :param ax: matplotlib 3D axes object
    """
    from numpy import mean

    current_limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    midpoints = [mean(limits) for limits in current_limits]

    half_width = max(
        abs(bound - midpoint)
        for limits, midpoint in zip(current_limits, midpoints)
        for bound in limits
    )

    x_mid, y_mid, z_mid = midpoints
    ax.set_xlim3d([x_mid - half_width, x_mid + half_width])
    ax.set_ylim3d([y_mid - half_width, y_mid + half_width])
    ax.set_zlim3d([z_mid - half_width, z_mid + half_width])
Ejemplo n.º 46
0
def render_poly(poly: shapely.geometry.Polygon, ax: plt.Axes, kw=None):
    """Draw one shapely polygon onto ``ax`` as a descartes patch.

    Patch styling starts from ``style_config.poly``; entries in ``kw``
    override those defaults.

    Args:
        poly (shapely.geometry.Polygon): Poly or multipoly to render.
        ax (plt.Axes): Matplotlib axis to render to.
        kw (dict): Extra/overriding patch kwargs. Defaults to None (no
            overrides).

    Returns:
        patch: the value returned by ax.add_patch
    """
    # TODO: maybe if done in batch we can speed this up?
    overrides = kw or {}
    patch_style = {**style_config.poly, **overrides}
    patch = descartes.PolygonPatch(poly, **patch_style)
    return ax.add_patch(patch)
Ejemplo n.º 47
0
def hexplot(
        ax: plt.Axes,
        grid: np.ndarray,
        data: np.ndarray,
        hex_size: float=11.5,
        cmap: str='viridis'
) -> plt.Axes:
    """
    Draw one coloured hexagon per grid position (handy for SOM maps).

    Parameters
    ----------
    ax : Axes to plot on.
    grid : Array of (x, y) positions in data coordinates, one per hexagon.
    data : Array of len(grid) values; mapped through ``cmap`` for the fill.
    hex_size : Radius in points; the marker area is 2 * pi * hex_size ** 2.
    cmap : Name of the colormap used for colouring.

    Returns
    -------
    ax : The same Axes, with limits padded by 0.75 around the grid and the
         axis decorations turned off.
    """
    hexagons = RegularPolyCollection(
        numsides=6,
        sizes=(2 * np.pi * hex_size ** 2,),
        edgecolors=(0, 0, 0, 0),
        transOffset=ax.transData,
        offsets=grid,
        array=data,
        cmap=plt.get_cmap(cmap)
    )
    ax.add_collection(hexagons, autolim=True)

    pad = 0.75
    xs = grid[:, 0]
    ys = grid[:, 1]
    ax.set_xlim(xs.min() - pad, xs.max() + pad)
    ax.set_ylim(ys.min() - pad, ys.max() + pad)
    ax.axis('off')

    return ax
Ejemplo n.º 48
0
    def __draw_track__(self, rhythm_track: Track, axes: plt.Axes, **kw):
        """Draw ``rhythm_track`` as one bar per onset on ``axes``.

        Each bar starts at its onset position (``align="edge"``) and is as
        wide and as tall as the corresponding inter-onset interval. The bar
        edge uses ``kw['color']``; the face uses the same colour at alpha 0.18.

        Returns the value of ``axes.bar``.
        """
        extractor = self.get_feature_extractor
        iois = extractor("ioi_vector").process(rhythm_track)
        onsets = extractor("onset_positions").process(rhythm_track)

        base_color = kw['color']
        # noinspection SpellCheckingInspection
        bar_style = dict(
            edgecolor=base_color,
            facecolor=colors.to_rgba(base_color, 0.18),
            linewidth=2.0,
        )
        return axes.bar(onsets, iois, width=iois, align="edge", **bar_style)
Ejemplo n.º 49
0
    def attach_mesh_to_axis(
        self,
        axis: plt.Axes,
        mesh: "meshes.ScalarMesh",
        distance_unit: u.Unit = "bohr_radius",
        norm=si.vis.AbsoluteRenormalize(),
        plot_limit=None,
        slicer="get_mesh_slicer",
        **kwargs,
    ):
        """Plot ``mesh`` against ``self.mesh.z_mesh`` as a single line.

        The method named by ``slicer`` is looked up on ``self.mesh`` and called
        with ``plot_limit``. The resulting slice selects both the z coordinates
        (divided by the numeric value of ``distance_unit``) and the values of
        ``mesh``, which are passed through ``norm``. Extra ``kwargs`` go to
        ``axis.plot``.

        NOTE(review): the slice comes from ``self.mesh`` but is also applied to
        ``mesh``, so the two are assumed to share a shape. The default ``norm``
        instance is created once at definition time and shared between calls.

        Returns the single artist produced by ``axis.plot``.
        """
        unit_value, _ = u.get_unit_value_and_latex(distance_unit)
        make_slice = getattr(self.mesh, slicer)
        selection = make_slice(plot_limit)

        z_values = self.mesh.z_mesh[selection] / unit_value
        mesh_values = norm(mesh[selection])
        artists = axis.plot(z_values, mesh_values, **kwargs)
        (line,) = artists
        return line
Ejemplo n.º 50
0
 def make_plot(
         self,
         axes: pyplot.Axes = None,
         vmin: float = None,
         vmax: float = None,
         show: bool = False,
         title: str = None,
         # figsize=rcParams["figure.figsize"],
         extent: (float, float, float, float) = None,
         cbar_label: str = None,
         **kwargs
 ):
     """Plot ``self.values`` on ``axes`` with a horizontal colorbar beneath.

     ``vmin``/``vmax`` default to the minimum/maximum of ``self.values``.
     Keywords from ``fig.get_topobathy_kwargs`` are merged into ``kwargs``.
     Its ``'col_val'`` entry is discarded and ``'levels'`` is used as the
     contour levels.

     Filled contours are drawn with ``self.tricontourf``, unless
     ``vmin == vmax``, in which case ``self.tripcolor`` is used. Mesh faces
     are then drawn with ``self.quadface``. The axes get a 'scaled' aspect
     and the optional ``extent`` and ``title``. A colorbar is added with
     ticks at ``vmin`` and ``vmax``, labelled to two decimals, plus an
     optional ``cbar_label``.

     When ``axes`` is None a new figure/axes is created. ``pyplot.show()`` is
     called when ``show`` is True. Returns the axes drawn on.
     """
     if axes is None:
         # ``axes`` is dereferenced below, so create one rather than crash.
         _, axes = pyplot.subplots()
     if vmin is None:
         vmin = np.min(self.values)
     if vmax is None:
         vmax = np.max(self.values)
     kwargs.update(**fig.get_topobathy_kwargs(self.values, vmin, vmax))
     kwargs.pop('col_val')
     levels = kwargs.pop('levels')
     if vmin != vmax:
         self.tricontourf(
             axes=axes,
             levels=levels,
             vmin=vmin,
             vmax=vmax,
             **kwargs
         )
     else:
         # A constant field cannot be contoured; draw flat colours instead.
         self.tripcolor(axes=axes, **kwargs)
     self.quadface(axes=axes, **kwargs)
     axes.axis('scaled')
     if extent is not None:
         axes.axis(extent)
     if title is not None:
         axes.set_title(title)
     mappable = ScalarMappable(cmap=kwargs['cmap'])
     mappable.set_array([])
     mappable.set_clim(vmin, vmax)
     divider = make_axes_locatable(axes)
     cax = divider.append_axes("bottom", size="2%", pad=0.5)
     # Use the figure owning ``axes`` (this method otherwise only relies on
     # the ``pyplot`` name, not ``plt``).
     cbar = axes.figure.colorbar(
         mappable,
         cax=cax,
         orientation='horizontal'
     )
     cbar.set_ticks([vmin, vmax])
     cbar.set_ticklabels([np.around(vmin, 2), np.around(vmax, 2)])
     if cbar_label is not None:
         cbar.set_label(cbar_label)
     if show:
         pyplot.show()
     return axes
Ejemplo n.º 51
0
def make_dual_axis(ax: plt.Axes = None, axis='x', unit='nm', minor_ticks=True):
    """Add a twin axis whose tick labels show 1e7 / value.

    1e7 / v converts between nanometres and wavenumbers in 1/cm (both ways),
    so the twin axis displays the reciprocal quantity of ``ax``. Major ticks
    use the coarsest spacing from (1000, 500, 200, 100, 50) that gives more
    than 4 ticks, falling back to the finest. Minor ticks use the matching
    spacing from (200, 100, 50, 25, 10).

    Parameters
    ----------
    ax : Axes, optional
        Axes to decorate; defaults to the current axes.
    axis : {'x', 'y'}
        Which axis receives the twin.
    unit : {'nm', 'cm'}
        Chooses the twin-axis label: 'nm' -> 'Wavelengths [nm]',
        'cm' -> 'Wavenumber [1/cm]'. Any other value leaves it unlabelled.
    minor_ticks : bool
        If True, minor ticks are switched on for ``ax``.

    Returns
    -------
    Axes
        The twin axes that was created.

    Raises
    ------
    ValueError
        If ``axis`` is neither 'x' nor 'y'.
    """
    if ax is None:
        ax = plt.gca()

    if axis == 'x':
        pseudo_ax = ax.twiny()
        limits = ax.get_xlim()
        pseudo_ax.set_xlim(limits)
        sub_axis = pseudo_ax.xaxis
    elif axis == 'y':
        pseudo_ax = ax.twinx()
        limits = ax.get_ylim()
        pseudo_ax.set_ylim(limits)
        sub_axis = pseudo_ax.yaxis
    else:
        raise ValueError('axis must be either x or y.')

    # Reciprocal range, sorted so tick generation works for both normal and
    # inverted axes (np.arange with a positive step needs start < stop).
    lo, hi = sorted(1e7 / np.array(limits))

    def conv(x, y):
        return '%.0f' % (1e7 / x)

    sub_axis.set_major_formatter(plt.FuncFormatter(conv))

    major = [1000, 500, 200, 100, 50]
    minor = [200, 100, 50, 25, 10]
    # Tick arrays are computed on every pass, so if no spacing qualifies the
    # finest spacing is used instead of leaving them undefined.
    for major_step, minor_step in zip(major, minor):
        first, last = math.ceil(lo / major_step), math.ceil(hi / major_step)
        ticks = np.arange(first * major_step, last * major_step, major_step)

        first_minor = math.floor(lo / minor_step)
        last_minor = math.floor(hi / minor_step)
        min_ticks = np.arange(first_minor * minor_step,
                              last_minor * minor_step, minor_step)
        if abs(last - first) > 4:
            break

    sub_axis.set_ticks(1e7 / ticks)
    sub_axis.set_ticks(1e7 / min_ticks, minor=True)
    if minor_ticks:
        ax.minorticks_on()
    # Compare strings by value; Axis.set_label_text sets the visible axis
    # label (Axis.set_label only sets the artist's legend label).
    if unit == 'nm':
        sub_axis.set_label_text('Wavelengths [nm]')
    elif unit == 'cm':
        sub_axis.set_label_text('Wavenumber [1/cm]')
    return pseudo_ax
Ejemplo n.º 52
0
    def plot2axes(self, axes: plt.Axes=None, edgecolor: str='black', fill: bool=False, plot_peaks: bool=True):
        """Draw this circle, and optionally its detected peaks, on an axes.

        Parameters
        ----------
        axes : matplotlib.Axes, None
            Target axes. If None, a new figure is created and
            ``self.image_array`` is shown in it first.
        edgecolor : str
            Matplotlib colour for the circle outline and the peak markers.
        fill : bool
            Passed to the circle patch as its ``fill`` flag.
        plot_peaks : bool
            If True, mark every point of ``self.peaks`` with an 'x'.
            Autoscaling is turned off before the markers are added.
        """
        if axes is None:
            _, axes = plt.subplots()
            axes.imshow(self.image_array)

        outline = mpl_Circle((self.center.x, self.center.y), edgecolor=edgecolor,
                             radius=self.radius, fill=fill)
        axes.add_patch(outline)

        if not plot_peaks:
            return

        peak_xs = [peak.x for peak in self.peaks]
        peak_ys = [peak.y for peak in self.peaks]
        axes.autoscale(enable=False)
        axes.scatter(peak_xs, peak_ys, s=40, marker='x', c=edgecolor)
Ejemplo n.º 53
0
 def _plot_symmetry(self, direction: str, axis: plt.Axes=None):
     """Plot a profile with its symmetry markers and mirrored copy.

     ``direction`` selects ``self.symmetry[direction.lower()]``. The profile
     is plotted with green dash-dot lines at its left/right edges and a
     magenta line at its FWXM centre. Unless the entry's ``'array'`` is the
     scalar 0, the second half of that array is plotted on a twin y-axis
     labelled "Symmetry (%)". Tick labels are removed, and the profile
     reflected about the FWXM centre is overlaid.

     Interactive mode is switched off. A new figure is created when ``axis``
     is None.
     """
     import numbers

     plt.ioff()
     if axis is None:
         _, axis = plt.subplots()
     data = self.symmetry[direction.lower()]
     profile = data['profile']
     axis.set_title(direction.capitalize() + " Symmetry")
     axis.plot(profile.values)
     # plot lines
     cax_idx = profile.fwxm_center()
     axis.axvline(data['profile left'], color='g', linestyle='-.')
     axis.axvline(data['profile right'], color='g', linestyle='-.')
     axis.axvline(cax_idx, color='m', linestyle='-.')
     # plot symmetry array. A scalar 0 means "no array". Test for it
     # explicitly: ``not (arr == 0)`` raises for multi-element NumPy arrays.
     sym_array = data['array']
     if not (isinstance(sym_array, numbers.Number) and sym_array == 0):
         twin_axis = axis.twinx()
         twin_axis.plot(range(cax_idx, data['profile right']),
                        sym_array[int(round(len(sym_array) / 2)):])
         twin_axis.set_ylabel("Symmetry (%)")
     _remove_ticklabels(axis)
     # plot profile mirror, shifted so it reflects about the FWXM centre
     central_idx = int(round(profile.values.size / 2))
     offset = cax_idx - central_idx
     mirror_vals = profile.values[::-1]
     axis.plot(profile._indices + 2 * offset, mirror_vals)
Ejemplo n.º 54
0
def plot_heatmap(
    dataset: netCDF4.Dataset,
    ax: plt.Axes = None,
    color: str = "charge",
    residues: list = None,
    zerobased: bool = False,
):
    """Plot the states, or the charges, of residues as colored blocks.

    Parameters
    ----------
    dataset - netCDF4.Dataset containing Protons information.
    ax - matplotlib Axes object; the current axes is used when None.
    color - 'charge' or 'state': color by charge, or by titration state.
        Any other value raises ValueError.
    residues - int or list of ints, residue columns to plot (all when None).
    zerobased - bool default False - use zero based labeling for states.

    Returns
    -------
    ax - plt.Axes
    """
    if ax is None:
        ax = plt.gca()

    label_offset = 0 if zerobased else 1

    if color == "charge":
        vmin = -2
        vmax = 2
        center = 0
        cmap = sns.diverging_palette(
            25, 244, l=60, s=95, sep=80, center="light", as_cmap=True
        )
    elif color == "state":
        vmin = 0 + label_offset
        vmax = label_offset + np.amax(dataset["Protons/Titration/state"][:, :])
        center = None
        cmap = "Accent"
    else:
        raise ValueError("color argument should be 'charge', or 'state'.")

    # One colorbar tick per integer value, with bin edges halfway between.
    ticks = np.arange(vmin, vmax + 1)
    boundaries = np.arange(vmin - 0.5, vmax + 1.5)
    cbar_kws = {"ticks": ticks, "boundaries": boundaries, "label": color.title()}

    if residues is None:
        if color == "charge":
            to_plot = charge_taut_trace(dataset)[0][:, :]
        else:
            to_plot = dataset["Protons/Titration/state"][:, :] + label_offset
    else:
        if isinstance(residues, int):
            residues = [residues]
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        residues = np.asarray(residues).astype(int)
        if color == "charge":
            to_plot = charge_taut_trace(dataset)[0][:, residues]
        else:
            to_plot = dataset["Protons/Titration/state"][:, residues] + label_offset

    # seaborn reads an int here as "label every n-th row/column". Clamp to 1
    # so short traces do not produce a zero or negative step.
    xtick_every = max(1, int(np.floor(to_plot.shape[0] / 7)) - 1)
    ytick_every = max(1, int(np.floor(to_plot.shape[1] / 4)) - 1)

    ax = sns.heatmap(
        to_plot.T,
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        center=center,
        xticklabels=xtick_every,
        yticklabels=ytick_every,
        cmap=cmap,
        cbar_kws=cbar_kws,
        edgecolor="None",
        snap=True,
    )

    # Thin white line per residue row; to_plot.T has one row per residue.
    for residue in range(to_plot.T.shape[0]):
        ax.axhline(residue, lw=0.4, c="w")

    ax.set_ylabel("Residue")
    ax.set_xlabel("Update")
    return ax
Ejemplo n.º 55
0
def plot_tautomer_heatmap(
    dataset: netCDF4.Dataset,
    ax: plt.Axes = None,
    residues: list = None,
    zerobased: bool = False,
):
    """Plot residue charges on a blue-red (negative-positive) scale, with a
    faint grey overlay whose shade depends on the tautomer.

    Parameters
    ----------
    dataset - netCDF4 dataset containing protons data
    ax - matplotlib Axes object; the current axes is used when None.
    residues - int or list of ints, residue columns to plot (all when None).
    zerobased - bool default False - use zero based labeling for states.

    Returns
    -------
    plt.Axes

    """
    if ax is None:
        ax = plt.gca()

    label_offset = 0 if zerobased else 1

    # Charges: diverging palette with one colorbar tick per integer charge.
    vmin = -2
    vmax = 2
    cmap = sns.diverging_palette(
        25, 244, l=60, s=95, sep=80, center="light", as_cmap=True
    )
    ticks = np.arange(vmin, vmax + 1)
    boundaries = np.arange(vmin - 0.5, vmax + 1.5)
    cbar_kws = {"ticks": ticks, "boundaries": boundaries, "label": "Charge"}

    # Tautomers: grey scale spanning every state label in the dataset.
    taut_vmin = 0 + label_offset
    taut_vmax = label_offset + np.amax(dataset["Protons/Titration/state"][:, :])
    taut_cmap = "Greys"

    if residues is None:
        to_plot, taut_to_plot = charge_taut_trace(dataset)
    else:
        if isinstance(residues, int):
            residues = [residues]
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        residues = np.asarray(residues).astype(int)
        charges, tauts = charge_taut_trace(dataset)
        to_plot = charges[:, residues]
        taut_to_plot = tauts[:, residues]

    mesh = ax.pcolor(to_plot.T, cmap=cmap, vmin=vmin, vmax=vmax, snap=True, alpha=1.0)
    # Colorbar on the figure that owns ``ax`` rather than the current figure.
    ax.figure.colorbar(mesh, ax=ax, **cbar_kws)
    # Semi-transparent tautomer overlay on top of the charge mesh.
    ax.pcolor(
        taut_to_plot.T,
        cmap=taut_cmap,
        vmin=taut_vmin,
        vmax=taut_vmax,
        alpha=0.1,
        snap=True,
    )

    # Thin white line per residue row; to_plot.T has one row per residue.
    for residue in range(to_plot.T.shape[0]):
        ax.axhline(residue, lw=0.4, c="w")

    ax.set_ylabel("Residue")
    ax.set_xlabel("Update")

    return ax
Ejemplo n.º 56
0
def plot_line(geom: Union[Line, Iterable[Line]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Line geometry, projected to the coordinate system `crs`.

    ``ax`` defaults to the current Axes. Extra positional and keyword
    arguments are forwarded to ``ax.plot``, whose return value is returned.

    NOTE(review): despite the ``Iterable[Line]`` annotation, ``geom`` must
    itself provide ``get_coordinate_lists``.
    """
    if ax is None:
        # Fall back to the current Axes instead of dereferencing None.
        from matplotlib.pyplot import gca
        ax = gca()
    x, y = geom.get_coordinate_lists(crs=crs)
    return ax.plot(x, y, *args, **kwargs)
Ejemplo n.º 57
0
def plot_point(geom: Union[Point, Iterable[Point]], *args,
        ax: Axes=None, crs: CRS=None, **kwargs):
    """ Plot a Point geometry, projected to the coordinate system `crs`.

    ``ax`` defaults to the current Axes. The marker defaults to "." unless
    ``marker`` is given. Remaining arguments are forwarded to ``ax.plot``,
    whose return value is returned.

    NOTE(review): despite the ``Iterable[Point]`` annotation, ``geom`` must
    itself provide ``get_vertex`` returning exactly (x, y).
    """
    if ax is None:
        # Fall back to the current Axes instead of dereferencing None.
        from matplotlib.pyplot import gca
        ax = gca()
    kwargs.setdefault("marker", ".")
    x, y = geom.get_vertex(crs=crs)
    return ax.plot(x, y, *args, **kwargs)
Ejemplo n.º 58
0
def _remove_ticklabels(axis: plt.Axes):
    """Blank the tick labels of both the y- and x-axis of ``axis``."""
    for sub_axis in (axis.get_yaxis(), axis.get_xaxis()):
        sub_axis.set_ticklabels([])
Ejemplo n.º 59
0
def label_ticks(xs: Iterable[float], ys: Iterable[float], ax: Axes=None,
        map_crs: CRS=Cartesian, graticule_crs: CRS=SphericalEarth,
        textargs=None, tickargs=None,
        xformatter=None, yformatter=None):
    """ Label graticule lines, returning a list of Text objects.

    For every easting in ``xs`` and northing in ``ys`` whose value lies
    between the graticule coordinates of a spine's two corners, the crossing
    point on that spine is found with ``froot``. A tick marker is plotted
    there and a text label added. The Axes' own ticks are then removed.

    Parameters
    ----------
    xs : Iterable[float],
    ys : Iterable[float]
        Easting and northing components of labels, in `graticule_crs`.
        Each is consumed once into a list, so generators are accepted.
    ax : Axes, optional
        Axes to draw to (default current Axes)
    map_crs : karta.crs.CRS, optional
        CRS giving the display projection (default Cartesian)
    graticule_crs : karta.crs.CRS, optional
        CRS giving the graticule/label projection (default SphericalEarth)
    textargs : dict, optional
        Keyword arguments to pass to plt.text
    tickargs : dict, optional
        Keyword arguments to pass to plt.plot
    xformatter : callable, optional
        function that given an easting/longitude returns a label
    yformatter : callable, optional
        function that given a northing/latitude returns a label
    """
    if ax is None:
        from matplotlib.pyplot import gca
        ax = gca()

    # Each sequence is scanned once per spine, so materialise iterators.
    xs = list(xs)
    ys = list(ys)

    if textargs is None:
        textargs = dict()

    if tickargs is None:
        tickargs = dict(marker="+", mew=2, ms=14, mfc="k", mec="k", ls="none")

    if xformatter is None:
        xformatter = lambda x: "{0} E".format(x)

    if yformatter is None:
        yformatter = lambda y: "{0} N".format(y)

    # Find tick locations. Corners of the view in graticule coordinates,
    # ordered lower-left, lower-right, upper-right, upper-left (as built by
    # get_axes_extent).
    bbox = get_axes_extent(ax, map_crs, graticule_crs)

    ticks = dict(xticks=[], yticks=[])

    xmin, xmax = sorted(ax.get_xlim())
    ymin, ymax = sorted(ax.get_ylim())

    # The lambdas below are evaluated immediately by froot inside each loop
    # iteration, so closing over the loop variable is safe.

    # bottom spine (between lower-left and lower-right corners)
    for x in xs:
        if isbetween(x, bbox[0][0], bbox[1][0]):
            ticks["xticks"].append((froot(lambda xt:
                                          map_crs.transform(graticule_crs, xt, ymin)[0]-x,
                                          xmin, xmax), ymin, xformatter(x)))

    for y in ys:
        if isbetween(y, bbox[0][1], bbox[1][1]):
            ticks["yticks"].append((froot(lambda xt:
                                          map_crs.transform(graticule_crs, xt, ymin)[1]-y,
                                          xmin, xmax), ymin, yformatter(y)))

    # top spine (between upper-right and upper-left corners)
    for x in xs:
        if isbetween(x, bbox[2][0], bbox[3][0]):
            ticks["xticks"].append((froot(lambda xt:
                                          map_crs.transform(graticule_crs, xt, ymax)[0]-x,
                                          xmin, xmax), ymax, xformatter(x)))

    for y in ys:
        if isbetween(y, bbox[2][1], bbox[3][1]):
            ticks["yticks"].append((froot(lambda xt:
                                          map_crs.transform(graticule_crs, xt, ymax)[1]-y,
                                          xmin, xmax), ymax, yformatter(y)))

    # left spine (between lower-left and upper-left corners)
    for x in xs:
        if isbetween(x, bbox[0][0], bbox[3][0]):
            ticks["xticks"].append((xmin,
                                    froot(lambda yt:
                                          map_crs.transform(graticule_crs, xmin, yt)[0]-x,
                                          ymin, ymax), xformatter(x)))

    for y in ys:
        if isbetween(y, bbox[0][1], bbox[3][1]):
            ticks["yticks"].append((xmin,
                                    froot(lambda yt:
                                          map_crs.transform(graticule_crs, xmin, yt)[1]-y,
                                          ymin, ymax), yformatter(y)))

    # right spine (between lower-right and upper-right corners)
    for x in xs:
        if isbetween(x, bbox[1][0], bbox[2][0]):
            ticks["xticks"].append((xmax,
                                    froot(lambda yt:
                                          map_crs.transform(graticule_crs, xmax, yt)[0]-x,
                                          ymin, ymax), xformatter(x)))

    for y in ys:
        if isbetween(y, bbox[1][1], bbox[2][1]):
            ticks["yticks"].append((xmax,
                                    froot(lambda yt:
                                          map_crs.transform(graticule_crs, xmax, yt)[1]-y,
                                          ymin, ymax), yformatter(y)))

    # Update map
    txts = []
    for pt in ticks["xticks"]:
        ax.plot(pt[0], pt[1], **tickargs)
        txts.append(ax.text(pt[0], pt[1], pt[2], **textargs))

    for pt in ticks["yticks"]:
        ax.plot(pt[0], pt[1], **tickargs)
        txts.append(ax.text(pt[0], pt[1], pt[2], **textargs))

    ax.set_xticks([])
    ax.set_yticks([])
    return txts
Ejemplo n.º 60
0
    def _plot_deviation(self, item: str, ax: plt.Axes=None, show: bool=True):
        """Helper: plot BB (or EPID) displacement against one axis' angle.

        Parameters
        ----------
        item : {'gantry', 'epid', 'collimator', 'couch'}
            The axis to plot. For EPID, the gantry images and angles are used
            together with the ``epid_*_offset`` attributes. Otherwise the
            ``bb_*_offset`` attributes are used.
        ax : None, matplotlib.Axes
            Axes to draw on; a new subplot is created when None.
        show : bool
            Whether to call ``plt.show()`` at the end.

        Images whose ``variable_axis`` is ``item`` or REFERENCE are included.
        In/Out (z), Left/Right (x), Up/Down (y; skipped for couch and
        collimator) and their RMS are plotted against angle.
        """
        title = f'Relative {item} displacement'
        if item == EPID:
            attr, item = 'epid', GANTRY
        else:
            attr = 'bb'

        # get axis images, angles, and shifts
        imgs = [img for img in self.images if img.variable_axis in (item, REFERENCE)]
        angle_name = '{}_angle'.format(item.lower())
        angles = [getattr(img, angle_name) for img in imgs]
        z_sag = np.array([getattr(img, attr + '_z_offset') for img in imgs])
        y_sag = np.array([getattr(img, attr + '_y_offset') for img in imgs])
        x_sag = np.array([getattr(img, attr + '_x_offset') for img in imgs])
        rms = np.sqrt(x_sag**2+y_sag**2+z_sag**2)

        if ax is None:
            ax = plt.subplot(111)

        # (values, format string, legend label, line style), in drawing order
        series = [
            (z_sag, 'bo', 'In/Out', '-.'),
            (x_sag, 'm^', 'Left/Right', '-.'),
        ]
        if item not in (COUCH, COLLIMATOR):
            series.append((y_sag, 'r*', 'Up/Down', '-.'))
        series.append((rms, 'g+', 'RMS', '-'))
        for values, fmt, label, style in series:
            ax.plot(angles, values, fmt, label=label, ls=style)

        ax.set_title(title)
        ax.set_ylabel('mm')
        ax.set_xlabel(f"{item} angle")
        ax.set_xticks(np.arange(0, 361, 45))
        ax.set_xlim([-15, 375])
        ax.grid(True)
        ax.legend(numpoints=1)
        if show:
            plt.show()