Exemplo n.º 1
0
def draw_elements(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Draw the elements of a lattice onto a matplotlib axes.

    Consecutive equal entries of ``lattice.sequence`` are merged (via
    ``groupby``) and drawn as one coloured rectangle spanning their combined
    length, clipped to the current x-limits of ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to draw onto. Its y-limits are extended to make room for
        the element strip.
    lattice : Lattice
        The lattice whose ``sequence`` is drawn.
    labels : bool, optional
        If True, dipoles and quadrupoles are annotated with their name,
        alternating above and below the strip.
    location : str, optional
        ``"top"`` places the strip above the current y-range; any other
        value places it below and additionally draws a horizontal baseline.
    """
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    rect_height = 0.05 * (y_max - y_min)
    if location == "top":
        # The strip is centred on the new upper y-limit.
        y0 = y_max = y_max + rect_height
    else:
        y0 = y_min - rect_height
        y_min -= 3 * rect_height
        # Draw on the axes that was passed in; ``plt.hlines`` would draw on
        # pyplot's *current* axes, which is not necessarily ``ax``.
        ax.hlines(y0, x_min, x_max, color="black", linewidth=1)
    ax.set_ylim(y_min, y_max)

    sign = -1
    start = end = 0
    for element, group in groupby(lattice.sequence):
        start = end
        end += element.length * sum(1 for _ in group)
        if end <= x_min:
            # Entirely left of the visible range.
            continue
        elif start >= x_max:
            # Everything from here on is right of the visible range.
            break

        try:
            color = ELEMENT_COLOR[type(element)]
        except KeyError:
            # Element types without a registered colour are not drawn.
            continue

        y0_local = y0
        if isinstance(element, Dipole) and element.angle < 0:
            # Dipoles with a negative angle are drawn slightly raised.
            y0_local += rect_height / 4

        ax.add_patch(
            plt.Rectangle(
                (max(start, x_min), y0_local - 0.5 * rect_height),
                min(end, x_max) - max(start, x_min),
                rect_height,
                facecolor=color,
                clip_on=False,
                zorder=10,
            ))
        if labels and type(element) in {Dipole, Quadrupole}:
            # Alternate the label side so neighbouring names do not overlap.
            sign = -sign
            ax.annotate(
                element.name,
                xy=(0.5 * (start + end), y0 + sign * rect_height),
                fontsize=FONT_SIZE,
                ha="center",
                va="bottom" if sign > 0 else "top",
                annotation_clip=False,
                zorder=11,
            )
Exemplo n.º 2
0
def label_on_bar(ax: matplotlib.axes.Axes,
                 rects: matplotlib.container.BarContainer, sample_sizes: List,
                 means: List, sems: List) -> None:
    """Write sample size, mean and SEM onto each bar.

    For every entry of ``rects`` the first patch is used to position a
    three-line text (``N``, ``Mean``, ``SEM``) at one tenth of the bar width
    from its left edge and at half of its height. ``means`` and ``sems`` are
    formatted with three decimals; all three lists are indexed in step with
    ``rects``.
    """
    for idx, container in enumerate(rects):
        first_patch = container.patches[0]
        bar_height = first_patch.get_height()

        count = sample_sizes[idx]
        mean_text = format(means[idx], ".3f")
        sem_text = format(sems[idx], ".3f")
        caption = f"N={count}\nMean={mean_text}\nSEM={sem_text}"

        x_pos = first_patch.get_x() + first_patch.get_width() / 10
        ax.annotate(caption, xy=(x_pos, bar_height / 2))
Exemplo n.º 3
0
def annotate_bars(ax: matplotlib.axes.Axes, size: Union[str, int] = "small") -> None:
    """Label every bar of ``ax`` with its height and its share of the total.

    The total is the sum of all patch heights, with NaN heights counted as
    zero. Each patch gets a centred annotation of the form
    ``"<height> (<share as percent>)"`` placed 10 points above its top.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose ``patches`` are annotated. Nothing happens if it has none.
    size : str or int, optional
        Font size passed through to ``ax.annotate``.
    """
    if not ax.patches:
        return

    total = sum(
        0 if math.isnan(height) else height
        for height in (p.get_height() for p in ax.patches)
    )

    for p in ax.patches:
        height = p.get_height()
        # When all bars are empty the total is 0; report a 0% share instead
        # of failing with a division by zero.
        share = height / total if total else 0.0
        text = "{} ({:.1%})".format(height, share)
        ax.annotate(
            text,
            (p.get_x() + p.get_width() / 2.0, height + 0.05),
            ha="center",
            va="center",
            xytext=(0, 10),
            textcoords="offset points",
            size=size,
        )
Exemplo n.º 4
0
def draw_sub_lattices(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    labels: bool = True,
    location: str = "top",
):
    """Mark the boundaries of a lattice's children on a matplotlib axes.

    The x-ticks of ``ax`` are set to the cumulative start/end positions of
    ``lattice.children`` that fall inside the current x-limits, and a dashed
    grid is enabled. Optionally the name of every child that is itself a
    ``Lattice`` is written centred over its visible part.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to draw onto.
    lattice : Lattice
        The lattice whose ``children`` are marked.
    labels : bool, optional
        If True, annotate each visible sub-lattice with its name.
    location : str, optional
        ``"top"`` places the names just below the upper y-limit; any other
        value extends the lower y-limit and places them there.
    """
    x_min, x_max = ax.get_xlim()
    lengths = [0.0, *(obj.length for obj in lattice.children)]
    position_list = np.add.accumulate(lengths)
    i_min = np.searchsorted(position_list, x_min)
    i_max = np.searchsorted(position_list, x_max, side="right")
    ticks = position_list[i_min:i_max]
    ax.set_xticks(ticks)
    ax.grid(color=Color.LIGHT_GRAY, linestyle="--", linewidth=1)

    if labels:
        y_min, y_max = ax.get_ylim()
        height = 0.08 * (y_max - y_min)
        if location == "top":
            y0 = y_max - height
        else:
            y0, y_min = y_min - height / 3, y_min - height

        ax.set_ylim(y_min, y_max)
        start = end = 0
        for obj in lattice.children:
            # Advance the interval for *every* child, including the ones
            # skipped below; otherwise ``start`` would lag behind and the
            # next label would be centred over the wrong span.
            start = end
            end += obj.length
            if not isinstance(obj, Lattice) or start >= x_max or end <= x_min:
                continue

            # Centre the name over the visible part of the sub-lattice.
            x0 = 0.5 * (max(start, x_min) + min(end, x_max))
            ax.annotate(
                obj.name,
                xy=(x0, y0),
                fontsize=FONT_SIZE + 2,
                fontstyle="oblique",
                va="center",
                ha="center",
                clip_on=True,
                zorder=102,
            )
Exemplo n.º 5
0
def plot_correlations(
    fig: matplotlib.figure.Figure,
    ax: matplotlib.axes.Axes,
    r2: float,
    slope: float,
    y_inter: float,
    corr_vals: np.ndarray,
    vis_vals: np.ndarray,
    scale_factor: Union[float, int],
    corr_bname: str,
    vis_bname: str,
    odir: Union[Path, str],
):
    """
    Plot the correlations between NIR band and the visible bands for
    the Hedley et al. (2005) sunglint correction method

    The plot is a 2D density histogram of the two bands with the linear
    regression line on top. It is saved as
    ``Correlation_<corr_bname>_vs_<vis_bname>.png`` in ``odir``; afterwards
    every artist added here is removed so ``fig``/``ax`` can be reused.

    Parameters
    ----------
    fig : matplotlib.figure object
        Reusing a matplotlib.figure object to avoid the creation of many
        fig instances

    ax : matplotlib.axes._subplots object
        Reusing the axes object

    r2 : float
        The correlation coefficient squared of the linear regression
        between NIR and a VIS band

    slope : float
        The slope/gradient of the linear regression between NIR and
        a VIS band

    y_inter : float
        The intercept of the linear regression between NIR and a
        VIS band

    corr_vals : numpy.ndarray
        1D array containing the NIR values from the ROI

    vis_vals : numpy.ndarray
        1D array containing the VIS values from the ROI

    scale_factor : float, int or None
        The scale factor used to convert integers to reflectances
        that range [0...1]. It is only shown in the axis labels, and
        only when it is greater than 1.

    corr_bname : str
        The NIR band number

    vis_bname : str
        The VIS band number

    odir : str or Path
        Directory where the correlation plots are saved

    """
    # clear previous plot
    ax.clear()

    # ----------------------------------- #
    #   Create a unique cmap for hist2d   #
    # ----------------------------------- #
    ncolours = 256

    # get the jet colormap
    colour_array = plt.get_cmap("jet")(range(ncolours))  # 256 x 4

    # We want only the first few colours to have low alpha
    # as they would represent low density [meshgrid] bins
    # which we are not interested in, and hence would want
    # them to appear as a white colour (alpha ~ 0)
    num_alpha = 25
    colour_array[0:num_alpha, -1] = np.linspace(0.0, 1.0, num_alpha)
    colour_array[num_alpha:, -1] = 1

    # create a colormap object
    cmap = LinearSegmentedColormap.from_list(name="jet_alpha",
                                             colors=colour_array)

    # ----------------------------------- #
    #  Plot density using np.histogram2d  #
    # ----------------------------------- #
    # Linear interpolation is np.percentile's default; the explicit
    # ``interpolation=`` keyword is deprecated in newer NumPy releases,
    # so it is left out to stay compatible across versions.
    xbin_low, xbin_high = np.percentile(corr_vals, (1, 99))
    ybin_low, ybin_high = np.percentile(vis_vals, (1, 99))

    # One bin per integer step; keep at least one bin per axis because
    # np.histogram2d rejects a bin count of zero (percentile span < 1).
    nbins = [
        max(1, int(xbin_high - xbin_low)),
        max(1, int(ybin_high - ybin_low)),
    ]

    bin_range = [[int(xbin_low), int(xbin_high)],
                 [int(ybin_low), int(ybin_high)]]

    hist2d, xedges, yedges = np.histogram2d(x=corr_vals,
                                            y=vis_vals,
                                            bins=nbins,
                                            range=bin_range)

    # normalised hist to range [0...1] then rotate and flip
    hist2d = np.flipud(np.rot90(hist2d / hist2d.max()))

    # Mask zeros
    hist_masked = np.ma.masked_where(hist2d == 0, hist2d)

    # use pcolormesh to plot the hist2D
    qm = ax.pcolormesh(xedges, yedges, hist_masked, cmap=cmap)

    # create a colour bar axes within ax
    cbaxes = inset_axes(
        ax,
        width="3%",
        height="30%",
        bbox_to_anchor=(0.37, 0.03, 1, 1),
        loc="lower center",
        bbox_transform=ax.transAxes,
    )

    # Add a colour bar inside the axes
    fig.colorbar(
        cm.ScalarMappable(cmap=cmap),
        cax=cbaxes,
        ticks=[0.0, 1],
        orientation="vertical",
        label="Point Density",
    )

    # ----------------------------------- #
    #     Plot linear regression line     #
    # ----------------------------------- #
    x_range = np.array([xbin_low, xbin_high])
    (ln, ) = ax.plot(
        x_range,
        slope * (x_range) + y_inter,
        color="k",
        linestyle="-",
        label="linear regr.",
    )

    # ----------------------------------- #
    #          Format the figure          #
    # ----------------------------------- #
    # add legend (top left)
    lgnd = ax.legend(loc=2, fontsize=10)

    # add annotation; the mathtext prefix is kept out of ``format`` because
    # its braces would otherwise be read as replacement fields
    ann_str = r"$r^{2}$" + (" = {0:0.2f}\n"
                            "slope = {1:0.2f}\n"
                            "y-inter = {2:0.2f}").format(r2, slope, y_inter)
    ann = ax.annotate(ann_str,
                      xy=(0.02, 0.76),
                      xycoords="axes fraction",
                      fontsize=10)

    # Add labels to figure
    xlabel = f"Reflectance ({corr_bname})"
    ylabel = f"Reflectance ({vis_bname})"

    if scale_factor is not None:
        if scale_factor > 1:
            xlabel += " " + r"$\times$" + " {0}".format(int(scale_factor))
            ylabel += " " + r"$\times$" + " {0}".format(int(scale_factor))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Save figure
    png_file = os.path.join(
        odir, "Correlation_{0}_vs_{1}.png".format(corr_bname, vis_bname))

    fig.savefig(png_file,
                format="png",
                bbox_inches="tight",
                pad_inches=0.1,
                dpi=300)

    # delete all lines and annotations from figure,
    # so it can be reused in the next iteration
    qm.remove()
    ln.remove()
    ann.remove()
    lgnd.remove()
    # The inset colour-bar axes belongs to the figure and is not affected by
    # ax.clear(); remove it as well so repeated calls do not stack them up.
    cbaxes.remove()
Exemplo n.º 6
0
def floor_plan(
    ax: mpl.axes.Axes,
    lattice: Lattice,
    *,
    start_angle: float = 0,
    labels: bool = True,
):
    """Draw a 2D floor plan of a lattice onto a matplotlib axes.

    The function walks ``lattice.sequence`` (consecutive equal elements are
    merged via ``groupby``), keeping track of the current position and
    direction. Drifts are drawn as thin black straight lines, dipoles as
    thick coloured arcs that change the direction, and all other elements as
    thick coloured straight lines. Finally the axis limits are set to the
    bounding box of the element end points plus a small margin.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to draw onto; its aspect ratio is set to "equal".
    lattice : Lattice
        The lattice whose ``sequence`` is drawn.
    start_angle : float, optional
        Initial direction in radians (it is fed to ``np.sin``/``np.cos`` and
        converted to degrees for the arcs).
    labels : bool, optional
        If True, dipoles and quadrupoles are annotated with their name.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax`` that was passed in.
    """
    ax.set_aspect("equal")
    # NOTE(review): ``Path`` is used with MOVETO/LINETO codes, so this is
    # presumably matplotlib.path.Path rather than pathlib.Path -- confirm.
    codes = Path.MOVETO, Path.LINETO
    current_angle = start_angle
    start = np.zeros(2)
    end = np.zeros(2)
    # Bounding box of all element end points; starts at the origin.
    x_min = y_min = 0
    x_max = y_max = 0
    # Flipped for every label so names alternate between both sides.
    sign = 1
    for element, group in groupby(lattice.sequence):
        # ``end`` is mutated in place below, hence the copy.
        start = end.copy()
        length = element.length * sum(1 for _ in group)
        if isinstance(element, Drift):
            color = Color.BLACK
            line_width = 1
        else:
            # No fallback here: an element type missing from ELEMENT_COLOR
            # makes this lookup fail.
            color = ELEMENT_COLOR[type(element)]
            line_width = 6

        # TODO: refactor current angle
        angle = 0
        if isinstance(element, Dipole):
            # Total bend angle of the merged dipole and its signed radius.
            # NOTE(review): a dipole with k0 == 0 gives angle == 0 and makes
            # the division below undefined -- confirm that cannot occur.
            angle = element.k0 * length
            radius = length / angle
            # Displacement along the arc in the element's own frame, then
            # rotated into the global frame by the current direction.
            vec = radius * np.array([np.sin(angle), 1 - np.cos(angle)])
            sin = np.sin(current_angle)
            cos = np.cos(current_angle)
            rot = np.array([[cos, -sin], [sin, cos]])
            end += rot @ vec

            # The arc's centre lies perpendicular to the current direction,
            # one (signed) radius away from the start point.
            angle_center = current_angle + 0.5 * np.pi
            center = start + radius * np.array(
                [np.cos(angle_center),
                 np.sin(angle_center)])
            diameter = 2 * radius
            # The arc is rotated by -90 degrees, apparently so that
            # theta == current_angle points from the centre to ``start``.
            arc_angle = -90
            theta1 = current_angle * 180 / np.pi
            theta2 = (current_angle + angle) * 180 / np.pi
            if angle < 0:
                # Keep theta1 <= theta2 for bends in the negative direction.
                theta1, theta2 = theta2, theta1

            line = patches.Arc(
                center,
                width=diameter,
                height=diameter,
                angle=arc_angle,
                theta1=theta1,
                theta2=theta2,
                color=color,
                linewidth=line_width,
            )
            current_angle += angle
        else:
            # Straight element: advance along the current direction.
            end += length * np.array(
                [np.cos(current_angle),
                 np.sin(current_angle)])
            line = patches.PathPatch(Path((start, end), codes),
                                     color=color,
                                     linewidth=line_width)

        # Only end points are tracked; the bulge of an arc between its end
        # points is not taken into account for the limits.
        x_min = min(x_min, end[0])
        y_min = min(y_min, end[1])
        x_max = max(x_max, end[0])
        y_max = max(y_max, end[1])

        ax.add_patch(line)  # TODO: currently splitted elements get drawn twice

        if labels and isinstance(element, (Dipole, Quadrupole)):
            # ``current_angle`` already includes this element's bend, so
            # subtracting half of it gives the direction at the element's
            # middle; the label is offset perpendicular to that direction.
            angle_center = (current_angle - 0.5 * angle) + 0.5 * np.pi
            sign = -sign
            center = 0.5 * (start + end) + 0.5 * sign * np.array(
                [np.cos(angle_center),
                 np.sin(angle_center)])
            ax.annotate(
                element.name,
                xy=center,
                fontsize=6,
                ha="center",
                va="center",
                # rotation=(current_angle * 180 / np.pi -90) % 180,
                annotation_clip=False,
                zorder=11,
            )

    # Pad the bounding box by 1% of its larger side.
    margin = 0.01 * max((x_max - x_min), (y_max - y_min))
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)
    return ax
Exemplo n.º 7
0
def annotate(X: Union[np.ndarray, Series, List, Tuple],
             Y: Union[np.ndarray, Series, List, Tuple],
             T: Union[np.ndarray, Series, List, Tuple],
             subset: Optional[Union[np.ndarray, Series, List, Tuple]] = None,
             ax: Optional[mpl.axes.Axes] = None,
             word_shorten: Optional[int] = None,
             **annotate_kws):
    """Annotates a matplotlib plot with text.

    Offsets are pre-determined according to the scale of the plot: when the
    x-values are floating point, every label is shifted right by 1/30 of the
    x-range and down by 1/30 of the y-range. The inputs themselves are not
    modified.

    Parameters
    ----------
    X : list/tuple/np.ndarray/pd.Series (1d)
        The x-positions of the text.
    Y : list/tuple/np.ndarray/pd.Series (1d)
        The y-positions of the text.
    T : list/tuple/np.ndarray/pd.Series (1d)
        The text array.
    subset : list/tuple/np.ndarray/pd.Series (1d)
        An array of indices to select a subset from.
    ax : matplotlib.axes.Axes, optional, default=None
        If None, creates one.
    word_shorten : int, optional
        If not None, shortens annotated strings to be more concise and displayable

    Other Parameters
    ----------------
    annotate_kws : dict
        Other keywords to pass to `ax.annotate`

    Returns
    -------
    ax : matplotlib.axes.Axes
        The same matplotlib plot, or the one generated
    """
    instance_check((X, Y), (list, tuple, np.ndarray, Series))
    instance_check(T, (list, tuple, np.ndarray, Series, Index))
    instance_check(subset,
                   (type(None), list, tuple, np.ndarray, Series, Index))
    instance_check(ax, (type(None), mpl.axes.Axes))
    arrays_equal_size(X, Y, T)
    # convert to numpy; X and Y are copied because they are offset below.
    _X = as_flattened_numpy(X).copy()
    _Y = as_flattened_numpy(Y).copy()
    _T = as_flattened_numpy(T)

    if word_shorten:
        _T = shorten(_T, newl=word_shorten)

    # Skip the offset for empty input: max()/min() of an empty array raise.
    if _X.dtype.kind == "f" and _X.size > 0:
        _X += (_X.max() - _X.min()) / 30.0
        # Applied out of place on purpose: an in-place ``+=`` with a float
        # offset fails when Y holds integers, whereas this promotes to float.
        _Y = _Y - (_Y.max() - _Y.min()) / 30.0

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    if subset is None:
        for x, y, t in it.zip_longest(_X, _Y, _T):
            ax.annotate(t, xy=(x, y), **annotate_kws)
    else:
        for i in subset:
            ax.annotate(_T[i], xy=(_X[i], _Y[i]), **annotate_kws)

    return ax