Exemple #1
0
def plot_locs(
    ax: Axes,
    bp: int,
    slen: int,
    plocs: np.ndarray,
    galaxy_probs: np.ndarray,
    m: str = "x",
    s: float = 20,
    lw: float = 1,
    alpha: float = 1,
    annotate: bool = False,
    cmap: str = "bwr",
) -> None:
    """Scatter locations on ``ax``, coloured by their galaxy probability.

    Args:
        ax: Axes to draw on.
        bp: Padding added to every coordinate; points closer than ``bp`` to
            either edge of the ``[0, slen]`` box are skipped.
        slen: Upper bound of the box used for that interior check
            (presumably the image side length).
        plocs: Array of shape ``(n_samples, 2)``; column 1 is plotted on the
            x axis and column 0 on the y axis.
        galaxy_probs: Array of shape ``(n_samples,)``; each value is mapped
            through ``cmap`` to choose the marker colour.
        m, s, lw, alpha: Marker, size, line width and opacity for ``ax.scatter``.
        annotate: If True, write each probability (2 decimals) next to its point.
        cmap: Name of a registered matplotlib colormap.
    """
    n_samples, xy = plocs.shape
    assert galaxy_probs.shape == (n_samples,) and xy == 2

    # The colormap does not change per point, so resolve it once.
    # ``mpl.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9.
    if hasattr(mpl, "colormaps"):
        colormap = mpl.colormaps[cmap]
    else:
        colormap = mpl.cm.get_cmap(cmap)

    x = plocs[:, 1] - 0.5 + bp
    y = plocs[:, 0] - 0.5 + bp
    for xi, yi, prob in zip(x, y, galaxy_probs):
        # Only draw points strictly inside the padded box.
        if not (bp < xi < slen - bp and bp < yi < slen - bp):
            continue
        color = colormap(prob)
        ax.scatter(xi, yi, color=color, marker=m, s=s, lw=lw, alpha=alpha)
        if annotate:
            ax.annotate(f"{prob:.2f}", (xi, yi), color=color, fontsize=8)
Exemple #2
0
    def __plot_bunches(self,
                       fig: plt.Figure,
                       ax: plt.Axes,
                       point: FiniteMetricVertex,
                       name: str = "u") -> None:
        """
        Redraw ``ax`` showing every vertex, the chosen point with its witness
        rings, and the members of its bunch B(point); then show ``fig``.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on (cleared first).
        :param point: The vertex whose bunch is highlighted.
        :param name: Label used for the vertex, its witnesses and its bunch.
        """
        ax.cla()

        # Every vertex as a small black dot.
        xs = [vertex.i for vertex in self.vertices]
        ys = [vertex.j for vertex in self.vertices]
        ax.scatter(xs, ys, s=4, color="black", marker=".", label="Points")

        # The selected vertex: red star plus a text label.
        ax.scatter([point.i], [point.j], s=12, color="red", marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Pin the current autoscaled limits so the circles added below do not
        # enlarge the view.
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # Label each witness for levels >= 1 and draw a circle around the
        # point whose radius is the entry's second element.
        witnesses = [self.p[point][level] for level in range(self.k)]
        for level in range(1, self.k):
            entry = witnesses[level]
            if entry is None:
                continue
            ax.annotate(f"p_{level}({name})", (entry[0].i, entry[0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            ring = plt.Circle((point.i, point.j), entry[1], fill=False)
            ax.add_patch(ring)

        # Highlight the members of the bunch.
        bunch = list(self.B[point])
        ax.scatter([member.i for member in bunch],
                   [member.j for member in bunch],
                   s=12, color="lime", marker="*", label=f"B({name})")

        ax.set_title(f"Bunch B({name})")
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        plt.tight_layout()
        fig.show()
Exemple #3
0
def label_pupils(pupil_data: pd.DataFrame, axes: plt.Axes, color: str,
                 xy_text: tuple):
    """
    Label the largest and smallest pupil readings, then draw the scatter plot.

    ``generate_pupil_edge_colors`` supplies the row of the maximum (element 1),
    the row of the minimum (element 0) and the per-dot edge colours
    (element 2). Each extreme is annotated with its "Raw" value in mm.

    :param pupil_data: dataframe with "Time", "Category", "Raw" and
        "Normalized" columns
    :param axes: the axes to plot on
    :param color: the fill color of the scatter plot dots
    :param xy_text: the label offset, in points
    :return: None
    """
    edge_data = generate_pupil_edge_colors(pupil_data["Normalized"], color)
    raw = pupil_data["Raw"]

    # Maximum first, then minimum.
    for extreme, row in ((raw.max(), edge_data[1]), (raw.min(), edge_data[0])):
        anchor = (pupil_data["Time"][row], pupil_data["Category"][row])
        axes.annotate(f'{extreme:.2f} mm',
                      anchor,
                      textcoords="offset points",
                      ha='center',
                      va='center',
                      xytext=xy_text)

    # noinspection PyTypeChecker
    axes.scatter(pupil_data["Time"],
                 pupil_data["Category"],
                 s=pupil_data["Normalized"],
                 edgecolors=edge_data[2],
                 color=color)
Exemple #4
0
    def _watermark_axis(self, ax: plt.Axes):
        """Stamp faint "Qiskit Metal" text on the axes, plus the logo if present.

        The text is placed in the lower-right corner in axes-fraction
        coordinates, behind everything else (zorder -100). The logo is read
        from ``self.gui.path_imgs / 'metal_logo.png'``; if that file is missing
        an error is logged instead.

        Args:
            ax (plt.Axes): axes to watermark
        """
        ax.annotate('Qiskit Metal',
                    xy=(0.98, 0.02),
                    xycoords='axes fraction',
                    fontsize=15,
                    color='gray',
                    ha='right',
                    va='bottom',
                    alpha=0.18,
                    zorder=-100)

        logo_path = self.gui.path_imgs / 'metal_logo.png'
        if not logo_path.is_file():
            self.logger.error(f'Error could not load {logo_path} for watermark.')
            return
        _axis_set_watermark_img(ax, logo_path, size=0.15)
def draw_he_annotate(ax: plt.Axes, heat_time, he_points):
    """Mark selected points with their formatted time and a black arrow.

    :param ax: axes to draw on
    :param heat_time: its first element is indexed by the entries of
        ``he_points[0]`` to get each point's x value (a time)
    :param he_points: pair whose first element holds indices into
        ``heat_time[0]`` and whose second holds the matching y values
    """
    for n, index in enumerate(he_points[0]):
        t = heat_time[0][index]
        label = convert_time(t)
        height = he_points[1][n]
        # The text sits 5 y-units above the point; the arrow points back to it.
        ax.annotate(label,
                    xy=(t, height),
                    xytext=(t, height + 5),
                    arrowprops=dict(facecolor='black', shrink=0.05))
Exemple #6
0
def plot_sol_points(fig: plt.Figure,
                    ax: plt.Axes,
                    R: PlanarTriangle,
                    P: Triangle,
                    points: List[int] = None) -> None:
    """Scatter and label vertices q_i of ``R.min_max_traversal_triangle(P)``.

    :param fig: unused; kept so the signature matches the other plot helpers
    :param ax: axes to draw on
    :param R: triangle whose ``min_max_traversal_triangle`` is evaluated
    :param P: argument passed to ``R.min_max_traversal_triangle``
    :param points: indices (each 0, 1 or 2) of the solution vertices to draw;
        defaults to all three. ``None`` replaces the former shared mutable
        list default.
    """
    if points is None:
        points = [0, 1, 2]
    sol = R.min_max_traversal_triangle(P)
    # Hand-tuned per-vertex offsets so each label clears its marker.
    label_offsets = np.array([[-0.05, -0.01], [-0.02, 0.02], [0.015, -0.0075]])
    for point in points:
        ax.scatter(*sol.points[point], c="green")
        ax.annotate(f"$q_{point}$", sol.points[point] + label_offsets[point])
Exemple #7
0
def truncated_countplot(
    x   : pd.Series,
    val : Any = 'mode',
    ax  : plt.Axes = None
    ) -> plt.Axes:
    """
    Truncated count plot to visualize more values when one dominates.

    The y-axis is capped at 1.4x the largest count among the *other* values,
    so the dominant bar is cut off. An arrow marks that bar, and a caption
    reports its true count and share of the total.

    Arguments:
        x :
            Data Series
        val :
            Value to truncate in count plot. 'mode' will truncate the data mode;
            None draws a plain count plot without truncation.
        ax :
            matplotlib Axes object to draw plot onto; a new one is created
            when omitted.
    Returns:
        ax :
            Returns the Axes object with the plot drawn onto it
    """
    # Setup Axes
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlabel(x.name)
    ax.set_ylabel('Counts')

    if val is None:
        sns.countplot(x=x, ax=ax)
        # Honour the documented return value (this path used to return None).
        return ax

    if val == 'mode':
        val = x.mode().iloc[0]

    # Plot and truncate
    splot = sns.countplot(x=x, ax=ax)
    other_counts = x[x != val].value_counts()  # sorted descending
    if other_counts.empty:
        # Only one distinct value: there is nothing to truncate against.
        return ax
    ymax = other_counts.iloc[0] * 1.4
    ax.set_ylim(0, ymax)

    # Annotate truncated bin: find the bar whose tick label matches val.
    tick_texts = [tick.get_text() for tick in ax.get_xticklabels()]
    val_ibin = tick_texts.index(str(val))
    val_bin = splot.patches[val_ibin]
    xloc = val_bin.get_x() + 0.5 * val_bin.get_width()
    # lw must be numeric; a string breaks matplotlib's dash scaling.
    ax.annotate('', xy=(xloc, 0), xytext=(xloc, ymax), xycoords='data',
                arrowprops=dict(arrowstyle='<-', color='black', lw=4))
    val_count = (x == val).sum()
    val_perc = val_count / len(x)
    ax.annotate(f'{val} (count={val_count}; {val_perc:.0%} of total)',
                xy=(0.5, 0), xytext=(0.5, 0.9), xycoords='axes fraction',
                ha='center')

    return ax
Exemple #8
0
def plot_triv(fig: plt.Figure,
              ax: plt.Axes,
              R: PlanarTriangle,
              P: Triangle,
              i: int,
              annotate: bool = False) -> None:
    """Draw triangle ``get_triv(R, P, i)`` in blue and label it ``$t_i$``.

    The triangle's third vertex gets a blue dot. The label is placed at that
    vertex plus a per-index offset, so ``i`` is expected to be 0, 1 or 2.

    NOTE(review): ``annotate`` is currently ignored -- the label is always
    drawn. Confirm whether it should gate the ``ax.annotate`` call.
    """
    triangle = get_triv(R, P, i)
    ax.scatter(*triangle.points[2], color="blue")
    triangle.plot(fig, ax, "blue")

    label_shift = np.array([[-0.05, 0], [0.015, 0], [0.015, 0]])[i]
    ax.annotate(f"$t_{i}$", triangle.points[2] + label_shift)
Exemple #9
0
    def __plot_p_i(self,
                   fig: plt.Figure,
                   ax: plt.Axes,
                   point: FiniteMetricVertex,
                   name: str = "u") -> None:
        """
        Redraw ``ax`` with the vertex sets A_i, the chosen point, and its
        witnesses p_i together with their rings; then show ``fig``.

        :param fig: The matplotlib figure to plot on.
        :param ax: The matplotlib axes to plot on (cleared first).
        :param point: The vertex whose witnesses/rings we wish to plot.
        :param name: The name to use to label the vertex and its witnesses.
        """
        ax.cla()

        # One scatter (with its own automatic colour) per set A_i.
        for idx, members in enumerate(self.A):
            ax.scatter([vertex.i for vertex in members],
                       [vertex.j for vertex in members],
                       s=8, marker="o", label=f"A_{idx}")

        # The selected vertex: red star plus a text label.
        ax.scatter([point.i], [point.j], s=12, color="red", marker="*",
                   label=name)
        ax.annotate(name, (point.i, point.j), color="red")

        # Pin the current autoscaled limits so the rings added below do not
        # enlarge the view.
        ax.set_xlim(*ax.get_xlim())
        ax.set_ylim(*ax.get_ylim())

        # Label each witness for levels >= 1 and draw a ring around the point
        # whose radius is the entry's second element.
        witnesses = [self.p[point][level] for level in range(self.k)]
        for level in range(1, self.k):
            entry = witnesses[level]
            if entry is None:
                continue
            ax.annotate(f"p_{level}({name})", (entry[0].i, entry[0].j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")
            ring = plt.Circle((point.i, point.j), entry[1], fill=False)
            ax.add_patch(ring)

        ax.set_title(f"Witnesses p_i({name}) and rings.")
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        plt.tight_layout()
        fig.show()
Exemple #10
0
def annotate_lines(
    ax_: Axes,
    num_lines: int = 1,  # None for all
    index_decimator: callable = default_index_decimator,
    color: str = "k",  # None for auto color
    xycoords: Tuple[str, str] = (
        "data",
        # 'axes fraction',
        "data",
    ),  # TODO: NOT DONE! Where to place annotation, use 'axes fraction' for along axes'
    ha: str = "left",
    va: str = "center",
    **kwargs,
) -> None:
    """
    Label selected points of the lines on ``ax_`` with their y value (2 d.p.).

    :param ax_: Axes whose ``lines`` are annotated, in drawing order.
    :param num_lines: annotate only the first ``num_lines`` lines; a falsy
        value (None or 0) annotates all of them.
    :param index_decimator: callable receiving a line's y data and returning
        the indices of the points to label; a falsy value labels every point.
    :param color: text colour; a falsy value uses each line's own colour.
    :param xycoords: coordinate system of the annotated point, forwarded to
        ``ax_.annotate``.
    :param ha: horizontal alignment of the text.
    :param va: vertical alignment of the text.
    :param kwargs: forwarded unchanged to ``ax_.annotate``.
    """
    import numpy as np  # local: only needed to index list-valued line data

    lines = ax_.lines
    if not num_lines:
        num_lines = len(lines)
    for line, _ in zip(lines, range(num_lines)):
        y_data = line.get_ydata()
        if index_decimator:
            mag_y = index_decimator(y_data)
        else:
            mag_y = list(range(len(y_data)))

        # get_xdata()/get_ydata() may hand back the original Python lists,
        # which cannot be indexed by an index list; asarray is a no-op for
        # arrays.
        x = np.asarray(line.get_xdata())
        y = np.asarray(y_data)

        # Resolve the colour per line. Previously the first line's colour was
        # stored back into ``color`` and then reused for every later line.
        line_color = color if color else line.get_color()

        for x_, y_ in zip(x[mag_y], y[mag_y]):
            ax_.annotate(
                f"{y_:.2f}",
                xy=(x_, y_),
                xycoords=xycoords,
                ha=ha,
                va=va,
                color=line_color,
                **kwargs,
            )
Exemple #11
0
    def display_values(self, ax: plt.Axes) -> None:
        """Write each point's value beside it on ``ax``.

        The text is the y value alone when ``self.share_x`` is truthy, and
        ``"x, y"`` otherwise. It is drawn in ``self.color``, offset from the
        point by ``(self.size, -self.size // 2)`` points.

        :param ax: Plot Axes.
        """
        offset = (self.size, -self.size // 2)
        for px, py in zip(self.x, self.y):
            label = py if self.share_x else f"{px}, {py}"
            ax.annotate(label,
                        xy=(px, py),
                        color=self.color,
                        xytext=offset,
                        textcoords="offset points")
Exemple #12
0
def visualize_countries(model: Word2Vec,
                        vocabulary: Vocabulary,
                        ax: plt.Axes = None):
    """Plot 2-D projections of ten countries and their capitals.

    Points come from ``project_to_2d_by_pca(model, vocabulary)``:
    - countries are drawn in blue and capitals in orange (alpha 0.7);
    - every word is annotated;
    - a faint arrow runs from each country to its capital.

    When ``ax`` is omitted, a new figure is created and shown at the end.
    """
    countries = [
        'u.s.', 'u.k.', 'italy', 'korea', 'china', 'germany', 'japan',
        'france', 'russia', 'egypt'
    ]
    capitals = [
        'washington', 'london', 'rome', 'seoul', 'beijing', 'berlin', 'tokyo',
        'paris', 'moscow', 'cairo'
    ]

    vectors_2d = project_to_2d_by_pca(model, vocabulary)

    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    # Countries first, then capitals: scatter each group and label its words.
    for words, colour in ((countries, 'blue'), (capitals, 'orange')):
        group = vectors_2d[[vocabulary.to_id(word) for word in words]]
        ax.scatter(group[:, 0], group[:, 1], c=colour, alpha=0.7)
        for word, row in zip(words, group):
            ax.annotate(word, (row[0], row[1]))

    # Arrow from each country to its capital.
    for country, capital in zip(countries, capitals):
        start = vectors_2d[vocabulary.to_id(country)]
        delta = vectors_2d[vocabulary.to_id(capital)] - start
        ax.arrow(start[0], start[1], delta[0], delta[1], alpha=0.5)

    if fig is not None:
        fig.show()
Exemple #13
0
def plot_base(fig: plt.Figure,
              ax: plt.Axes,
              R: PlanarTriangle,
              P_ref: PlanarTriangle,
              annotate_points: bool = True,
              annotate_sides: bool = True,
              scale: bool = True) -> None:
    """Draw triangle ``R`` and reference triangle ``P_ref`` with LaTeX labels.

    :param fig: figure passed through to the triangles' ``plot`` methods.
    :param ax: axes to draw on.
    :param R: triangle drawn with ``R.plot``; its vertices are scattered in black.
    :param P_ref: triangle drawn in blue.
    :param annotate_points: label R's vertices r_0..r_2 and P_ref's vertices
        (also scattering the latter in blue). P_ref's p_i labels are permuted
        by ``np.argsort(P_ref.angles)``.
    :param annotate_sides: label P_ref's edges at their midpoints with
        sqrt(3), 1 and 2.
    :param scale: if True, the side labels are divided by 3 + sqrt(3) and use
        offsets tuned for those longer labels.
    """
    R.plot(fig, ax)
    ax.scatter(*(R.points.T), c="black")
    P_ref.plot(fig, ax, color="blue")

    if annotate_points:
        # Vertex 0's label is shifted down-left, the others up-right.
        ax.annotate("$r_0$", R.points[0] - 0.035)
        ax.annotate("$r_1$", R.points[1] + 0.015)
        ax.annotate("$r_2$", R.points[2] + 0.015)

        labels = np.array([f"$p_{i}$"
                           for i in range(3)])[np.argsort(P_ref.angles)]
        ax.scatter(*(P_ref.points.T), c="blue")
        ax.annotate(labels[0], P_ref.points[0] - 0.035)
        ax.annotate(labels[1], P_ref.points[1] + 0.015)
        ax.annotate(labels[2], P_ref.points[2] + 0.015)

    if annotate_sides:
        # Edge k joins vertex k to vertex k+1 (mod 3); its label sits at the
        # midpoint plus a hand-tuned offset.
        midpoints = (P_ref.points + np.roll(P_ref.points, -1, axis=0)) / 2
        # Raw strings: "\s" is not a valid escape sequence and triggers a
        # DeprecationWarning (SyntaxWarning since Python 3.12).
        labels = [r"\sqrt{3}", "1", "2"]
        if scale:
            offsets = np.array([
                [-0.05, -0.05],
                [0.01, 0.0225],
                [-0.075, 0.025],
            ])
            labels = [
                r"$\frac{" + label + r"}{3 + \sqrt{3}}$" for label in labels
            ]
        else:
            offsets = np.array([
                [-0.05, -0.05],
                [0.01, 0.0225],
                [-0.01, 0.025],
            ])
            labels = [f"${label}$" for label in labels]

        for i in range(3):
            ax.annotate(labels[i], midpoints[i] + offsets[i])
Exemple #14
0
def add_annotation(ax: plt.Axes, **kwargs) -> None:
    """Add a text annotation to a plot.

    Recognised keyword arguments (all optional):
        text: the string to draw (default "").
        location: the (x, y) point passed to ``ax.annotate`` (default (0, 0)).
        fontsize: font size (default 16).
        horizontal_alignment: default "left".
        vertical_alignment: default "center".
    Any other keyword argument is ignored.
    """
    fallbacks = (
        ("fontsize", 16),
        ("text", ""),
        ("location", (0, 0)),
        ("horizontal_alignment", "left"),
        ("vertical_alignment", "center"),
    )
    opts = {key: kwargs.get(key, default) for key, default in fallbacks}

    ax.annotate(
        opts["text"],
        opts["location"],
        horizontalalignment=opts["horizontal_alignment"],
        verticalalignment=opts["vertical_alignment"],
        fontsize=opts["fontsize"],
    )
Exemple #15
0
def __draw_bboxes(axe: plt.Axes, bboxes: list[BBox], cmap_colors: list[str], class_names: list[str], show_bbox_label: bool):
    """Outline every bounding box on ``axe`` and tag it with a label.

    The colour for both the outline and the label background is
    ``cmap_colors[bbox.cls]``. The label contains:
    - the class name, when ``show_bbox_label`` is set;
    - the score rounded to 3 decimals in parentheses, when the score is truthy.

    The label is written in bold white text 5 units right of and 5 units
    above the box's ``upper_left_point`` (in y-up axes).
    """
    for box in bboxes:
        box_color = cmap_colors[box.cls]
        name_part = f'{class_names[box.cls]}' if show_bbox_label else ''
        score_part = f' ({round(box.score, 3)})' if box.score else ''

        outline = mpatches.Rectangle(box.upper_left_point,
                                     box.width, box.height,
                                     linewidth=2,
                                     edgecolor=box_color,
                                     facecolor='none')
        axe.add_patch(outline)

        anchor = (box.upper_left_point[0] + 5, box.upper_left_point[1] - 5)
        axe.annotate((name_part + score_part).strip(),
                     anchor,
                     color='w',
                     weight='bold',
                     fontsize=10,
                     ha='left', va='bottom',
                     bbox=dict(facecolor=box_color, edgecolor='none', pad=1.5))
def plot_im(ax: plt.Axes, image: torch.Tensor, true_class: str):
    """Plots given image on given axes. Sets title of image to true_class.

    Steps applied to the tensor before ``imshow``:
    - pass it through ``un_normalize``;
    - move it to the CPU;
    - min-max rescale it to [0, 1];
    - transpose so the first axis becomes the last (presumably CHW -> HWC).

    The class name is also written as text at data position (112, 112),
    presumably the centre of a 224x224 image (TODO confirm). Grid and axis
    decorations are hidden.

    Args:
        ax (plt.Axes): Axes on which to draw.
        image (torch.Tensor): Tensor image which has to be drawn on axes.
        true_class (str): Name of image class.
    """
    image = un_normalize(image)
    image_cpu: np.ndarray = image.cpu().detach().numpy()
    im_max = np.max(image_cpu)
    im_min = np.min(image_cpu)
    value_range = im_max - im_min
    if value_range == 0:
        # Constant image: min-max scaling would be 0/0 and yield all NaNs.
        image_cpu = np.zeros_like(image_cpu, dtype=float)
    else:
        image_cpu = (image_cpu - im_min) / value_range
    image_cpu = np.transpose(image_cpu, [1, 2, 0])
    ax.set_title(true_class)
    ax.annotate(true_class, xy=(112, 112))
    ax.grid(False)
    ax.set_axis_off()
    ax.imshow(image_cpu)
Exemple #17
0
def _reward_draw_spline(
    x: int,
    y: int,
    action: int,
    optimal: bool,
    reward: float,
    from_dest: bool,
    mappable: matplotlib.cm.ScalarMappable,
    annot_padding: float,
    ax: plt.Axes,
) -> Tuple[np.ndarray, Tuple[float, ...], str]:
    """Write the reward for one (cell, action) and return its drawing data.

    The cell is ``(x, y)`` and the step is ``ACTION_DELTA[action]``. When
    ``from_dest`` is set, the anchor moves to the destination cell and the
    direction is reversed. The rounded reward is written at the anchor's
    centre, nudged by ``annot_padding`` along the direction for non-zero
    moves, and shown in bold when ``optimal`` is set.

    Returns:
        ``(vert, color, hatch_color)``:
        - ``vert`` is the anchor plus ``OFFSETS[direction]``;
        - ``color`` is ``mappable.to_rgba(reward)``;
        - ``hatch_color`` is ".5" on light colours (relative luminance
          > 0.408) and "w" otherwise.
    """
    origin = np.array([x, y])
    delta = np.array(ACTION_DELTA[action])
    if from_dest:
        origin, delta = origin + delta, -delta
    vert = origin + OFFSETS[tuple(delta)]
    color = mappable.to_rgba(reward)

    label = f"{reward:.0f}"
    is_light = sns.utils.relative_luminance(color) > 0.408
    text_color = ".15" if is_light else "w"
    hatch_color = ".5" if is_light else "w"

    anchor = origin + 0.5
    if tuple(delta) != (0, 0):
        anchor = anchor + annot_padding * delta

    ax.annotate(
        label,
        xy=anchor,
        ha="center",
        va="center_baseline",
        color=text_color,
        fontweight="bold" if optimal else None,
    )

    return vert, color, hatch_color
Exemple #18
0
def getting_amplitude_and_dt(ax: plt.Axes, x: np.ndarray, cold: np.ndarray,
                             hot: np.ndarray) -> plt.Axes:
    """Adds hot and cold trace to axes with a straight line before and after transition to emphasise amplitude etc

    Linear fits are made with ``lm.models.LinearModel`` (presumably lmfit),
    with NaNs omitted:
    - one to the cold trace before ``mean(x) - 0.30``;
    - one to the cold trace after ``mean(x) + 0.30``.

    Both fits are drawn dotted over the full x range and joined by a
    double-headed arrow at x = 0.20 labelled "dI/dN". The x limits are fixed
    to (-0.5, 0.5).

    Returns:
        The same ``ax``.
    """
    ax.set_title("Hot and cold part of transition")
    ax.set_xlabel('Sweep Gate (mV)')
    ax.set_ylabel('I (nA)')

    for trace, colour, name in ((cold, 'blue', 'Cold'), (hot, 'red', 'Hot')):
        ax.plot(x, trace, color=colour, label=name, linewidth=1)

    # Indices bounding the transition region around the centre of x.
    half_width = 0.30
    centre = np.mean(x)
    start_idx = U.get_data_index(x, centre - half_width, is_sorted=True)
    end_idx = U.get_data_index(x, centre + half_width, is_sorted=True)

    model = lm.models.LinearModel()
    top_fit = model.fit(cold[:start_idx], x=x[:start_idx], nan_policy='omit')
    bottom_fit = model.fit(cold[end_idx:], x=x[end_idx:], nan_policy='omit')

    for fit in (top_fit, bottom_fit):
        ax.plot(x, fit.eval(x=x), linestyle=':', color='black')

    # Double-headed arrow between the two fitted lines at a fixed x.
    arrow_x = 0.20
    y_low = bottom_fit.eval(x=arrow_x)
    y_high = top_fit.eval(x=arrow_x)
    ax.annotate(text='',
                xy=(arrow_x, y_low),
                xytext=(arrow_x, y_high),
                arrowprops=dict(arrowstyle='<|-|>', lw=1))
    ax.text(x=arrow_x + 0.02, y=(y_high + y_low) / 2, s='dI/dN')

    ax.set_xlim(-0.5, 0.5)
    ax.legend(loc='center left')

    # TODO: add horizontal lines showing theta -- from fit results, or only
    # as a rough indication?
    return ax
Exemple #19
0
def visualize_frequent_words(vectors_2d: np.ndarray,
                             dataset: DataSet,
                             k: int,
                             ax: plt.Axes = None) -> None:
    """Scatter and label the 2-D vectors of the ``k`` most frequent words.

    :param vectors_2d: array indexed by word id; columns 0 and 1 are plotted.
    :param dataset: provides ``data`` (word ids, counted with ``np.unique``)
        and ``vocabulary.to_word`` for the labels.
    :param k: number of most frequent words to show. Values >= the number of
        distinct words show every word instead of raising.
    :param ax: axes to draw on; when omitted a 13x13 figure is created,
        laid out and shown.
    """
    word_ids, counts = np.unique(dataset.data, return_counts=True)

    # np.argpartition requires kth < len(counts); when k covers every distinct
    # word, simply take them all.
    if k >= len(counts):
        top_indices = np.arange(len(counts))
    else:
        top_indices = np.argpartition(-counts, k)[:k]
    frequent_word_ids = word_ids[top_indices]

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(13, 13))

    points = vectors_2d[frequent_word_ids]

    ax.scatter(points[:, 0], points[:, 1], s=2, alpha=0.25)
    for i, word_id in enumerate(frequent_word_ids):
        ax.annotate(dataset.vocabulary.to_word(word_id),
                    (points[i, 0], points[i, 1]))

    if fig is not None:
        fig.tight_layout()
        fig.show()
Exemple #20
0
def add_contacts_to_plot(qc_frame: pd.DataFrame, axis: pyplot.Axes) -> None:
    """Draw a dashed line and a label for each contact column present.

    The columns "OWC", "GOC" and "GWC" are checked in that order. For each
    one present in ``qc_frame``, its first value is used (the contact is
    assumed constant across rows). A dashed black horizontal line is drawn at
    that y value and labelled ``"<NAME>=<value>"`` at x = 0.
    """
    for contact in ("OWC", "GOC", "GWC"):
        if contact not in qc_frame:
            continue
        level = qc_frame[contact].values[0]
        axis.axhline(level, color="black", linestyle="--", linewidth=1)
        axis.annotate(f"{contact}={level:g}", (0, level))
def hexplot_TEST(
        node_data,
        decimals: int = 0,
        var_lim: Iterable = None,
        c: object = None,
        lc: str = 'k',
        #edge_data: Dict=None, edge_c='r',  # EDGE DATA NOT IMPLEMENTED YET
        ax: plt.Axes = None,
        scale_factor: float = 0.015):
    """
    Plot a map of the hexagonal lattice of the megaphragma compound eye.

    Each ommatidium is drawn as a hexagon marker at ``om_to_hex(om)`` and
    labelled in white:
    - the fill is the node's 'colour', or ``c`` when absent;
    - the label is the node's 'label', or the ommatidium name when absent.

    Non-dict input (e.g. a 1-column dataframe/series indexed by om) is first
    converted with ``__from_series``.

    TODO: fix dataframe input options, type of cmap as an arg, use plt.scatter
    instead of all the networkx functions?

    :param node_data: Dict, {om: {'label': str or number,
                                  'outline': matplotlib line spec, e.g. '-',
                                  'colour':  (rgba)}}
        Numeric labels are rounded to ``decimals``; 'outline' is not used yet.
    :param decimals: rounding for numeric labels (also passed to
        ``__from_series``).
    :param var_lim: passed to ``__from_series`` for non-dict input.
    :param c: default node colour (used when a node has no 'colour');
        defaults to (0.2, 0.2, 0.2, 1).
    :param lc: currently unused.
    :param ax: plt.Axes; if None, plots to the current Axes.
    :param scale_factor: currently unused.
    :return: the Axes that was drawn on.
    """
    if c is None:  # default color
        c = (0.2, 0.2, 0.2, 1)

    if not isinstance(node_data, dict):
        node_data = __from_series(node_data,
                                  c=c,
                                  var_lim=var_lim,
                                  decimals=decimals)

    if ax is None:
        ax = plt.gca()

    # Sort by the string form of each key (as before) but remember the
    # original key, so non-string keys no longer raise KeyError on lookup.
    key_by_name = {str(om): om for om in node_data.keys()}
    for om in sorted(key_by_name):
        node = node_data[key_by_name[om]]
        xy = om_to_hex(om)  # 2D figure coords of this om

        # Read the label from this node's dict (it was previously read from
        # the outer mapping, which never holds a 'label' key per node).
        raw_label = node.get('label')
        if raw_label is None:
            label = om
        elif isinstance(raw_label, (int, float)):
            label = str(round(raw_label, decimals))
        else:
            label = raw_label

        # ``is None`` (not ``== None``) so array-valued colours don't raise.
        fill_c = node.get('colour')
        if fill_c is None:
            fill_c = c

        ax.scatter(xy[0], xy[1], marker='H', color=fill_c, s=100)
        ax.annotate(label, xy, fontsize=8, color='w', ha='center', va='center')

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_aspect('equal')

    return ax
Exemple #22
0
def truncated_hist(
    x   : pd.Series,
    val : Any = 'mode',
    ax  : plt.Axes = None
    ) -> plt.Axes:
    """
    Truncated histogram to visualize more values when one dominates.

    The histogram is drawn without ``val``. If ``val`` lies outside the
    remaining data, the x-axis is widened so that ``val`` fits. An arrow at
    ``val`` then carries a caption with its true count and share of the total.

    Arguments:
        x :
            Data Series
        val :
            Value to truncate in histogram. 'mode' will truncate the data mode;
            None draws a plain histogram.
        ax :
            matplotlib Axes object to draw plot onto; a new one is created
            when omitted.
    Returns:
        ax :
            Returns the Axes object with the plot drawn onto it
    """
    # Setup Axes
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlabel(x.name)
    ax.set_ylabel('Counts')

    if val is None:
        ax.hist(x, bins='auto')
        # Honour the documented return value (this path used to return None).
        return ax

    if val == 'mode':
        val = x.mode().iloc[0]

    # Plot without selected value
    sel_vals = x[x != val]
    if sel_vals.empty:
        # Every entry equals val: nothing is left to truncate against.
        ax.hist(x, bins='auto')
        return ax
    bin_vals, _, _ = ax.hist(sel_vals, bins='auto')
    y_lo, y_hi = ax.get_ylim()
    ax.set_ylim(y_lo, y_hi * 1.1)

    # Expand x-axis to include removed value and then annotate with the value's count
    x_lo, x_hi = ax.get_xlim()
    sel_min, sel_max = sel_vals.min(), sel_vals.max()
    if val < sel_min:  # Lower xmin, reusing the right-hand margin
        buff = abs(x_hi - sel_max)
        ax.set_xlim(val - buff, x_hi)
        horizontal_alignment = 'left'
        arrow_relpos = (0, 0)
    elif val > sel_max:  # Increase xmax, reusing the left-hand margin
        buff = abs(x_lo - sel_min)
        ax.set_xlim(x_lo, val + buff)
        horizontal_alignment = 'right'
        arrow_relpos = (1, 1)
    else:  # No change to x range
        horizontal_alignment = 'center'
        arrow_relpos = (0.5, 0.5)
    val_count = (x == val).sum()
    val_perc = val_count / len(x)
    # lw must be numeric; a string breaks matplotlib's dash scaling.
    ax.annotate(f'{val} (count={val_count}; {val_perc:.0%} of total)',
                xy=(val, 0), xytext=(val, bin_vals.max() * 1.1), xycoords='data',
                ha=horizontal_alignment,
                arrowprops=dict(arrowstyle='<-', color='black', lw=4,
                                relpos=arrow_relpos))

    return ax
Exemple #23
0
    def __plot_query_state(self,
                           fig: plt.Figure,
                           ax: plt.Axes,
                           u: FiniteMetricVertex,
                           v: FiniteMetricVertex,
                           w: FiniteMetricVertex,
                           i: int,
                           final: bool = False) -> None:
        """
        Plot a single frame/plot that depicts the current state of the ADO
        query. Clears and overwrites any previous contents of the plot.

        The frame shows:
        - all vertices;
        - u, v and w highlighted;
        - u's witnesses p_j(u) (only levels >= 1 are labelled; None entries
          are skipped);
        - the bunch B(v);
        - a line from u to w, plus a line from w to v on the final frame.

        :param fig: A Matplotlib figure object representing the figure being
        modified/displayed.
        :param ax: A Matplotlib Axes object representing the subplot being
        modified/displayed.
        :param u: The current value of u in the ADO query algorithm.
        :param v: The current value of v in the ADO query algorithm.
        :param w: The current value of w in the ADO query algorithm.
        :param i: The iteration of the ADO query algorithm.
        :param final: Whether or not this is the final query state/final
        iteration of the algorithm.
        """
        ax.cla()

        # Plot all the points in the graph
        ax.scatter([vertex.i for vertex in self.vertices],
                   [vertex.j for vertex in self.vertices],
                   s=4,
                   color="black",
                   marker=".",
                   label="Points")

        # Plot u, v, w with special symbols/colors
        ax.scatter([u.i], [u.j], s=12, color="red", marker="*", label="u")
        ax.annotate("u", (u.i, u.j), color="red")
        ax.scatter([v.i], [v.j], s=12, color="green", marker="*", label="v")
        ax.annotate("v", (v.i, v.j), color="green")
        ax.scatter([w.i], [w.j], s=5, color="orange", marker="p", label="w")
        ax.annotate("w", (w.i, w.j),
                    color="orange",
                    xytext=(-15, -10),
                    textcoords="offset pixels")

        # For the current u, mark and label its p_i(u)s. Entries can be None
        # (the bunch/ring plotting methods check for this), which previously
        # crashed on ``None[0]``; skip them instead.
        p_i_u = [self.p[u][level] for level in range(self.k)]
        witnesses = [(level, entry[0]) for level, entry in enumerate(p_i_u)
                     if entry is not None]
        ax.scatter([wit.i for _, wit in witnesses],
                   [wit.j for _, wit in witnesses],
                   s=4,
                   color="violet",
                   marker="o",
                   label="p_i(u)")
        for level, wit in witnesses:
            if level == 0:
                continue  # p_0(u) is plotted but not labelled
            ax.annotate(f"p_{level}(u)", (wit.i, wit.j),
                        xytext=(5, 5),
                        textcoords="offset pixels",
                        color="violet")

        # For the current v, highlight its bunch B(v) in a different color
        bunch_v = list(self.B[v])
        ax.scatter([member.i for member in bunch_v],
                   [member.j for member in bunch_v],
                   s=4,
                   color="lime",
                   marker="*",
                   label="B(v)")

        # Draw line from u to current w
        ax.add_line(Line2D([u.i, w.i], [u.j, w.j], color="pink"))

        # For the final plot, draw a line from w to current v as well
        if final:
            ax.add_line(Line2D([w.i, v.i], [w.j, v.j], color="palegreen"))

        title = f"Iteration {i} (final)" if final else f"Iteration {i}"

        ax.set_title(title)
        ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
        plt.tight_layout()
        fig.show()
Exemple #24
0
def annot_bars(
    ax: plt.Axes,
    dist: float = 0.15,
    color: str = "k",
    compact: bool = False,
    orient: str = "h",
    format_spec: str = "{x:.2f}",
    fontsize: int = 12,
    alpha: float = 0.5,
    drop_last: int = 0,
    **kwargs,
) -> None:
    """Annotate a bar graph with the bar values.

    The bounds of the value axis (x for "h", y for "v") are scaled by
    ``1 + 2 * dist`` to make room for the labels. Does nothing if ``ax`` has
    no patches.

    Parameters
    ----------
    ax : Axes
        Axes object to annotate; every entry of ``ax.patches`` is a bar.
    dist : float, optional
        Distance from bar ends as a fraction of the largest absolute bar
        value. Defaults to 0.15.
    color : str, optional
        Text color. Defaults to "k".
    compact : bool, optional
        Annotate inside the bars instead of beyond their ends.
        Defaults to False.
    orient : str, optional
        Bar orientation, "h" or "v" (case-insensitive). Defaults to "h".
    format_spec : str, optional
        ``str.format`` template receiving the bar value (width for "h",
        height for "v") as ``x``. Defaults to "{x:.2f}".
    fontsize : int, optional
        Font size. Defaults to 12.
    alpha : float, optional
        Opacity of text. Defaults to 0.5.
    drop_last : int, optional
        Number of bars to ignore on tail end. Defaults to 0.
    **kwargs
        Forwarded to ``ax.annotate``.

    Raises
    ------
    ValueError
        If ``orient`` is neither "h" nor "v".
    """
    orientation = orient.lower()
    if orientation not in ("h", "v"):
        raise ValueError("`orient` must be 'h' or 'v'")
    horizontal = orientation == "h"

    bars = list(ax.patches)
    if not bars:
        return  # nothing to annotate (the max below would fail on empty input)

    if not compact:
        dist = -dist

    # Widen the axis that carries the bar values. Vertical bars previously
    # widened the x axis instead.
    scale = 1 + abs(2 * dist)
    if horizontal:
        ax.set_xbound(*(np.array(ax.get_xbound()) * scale))
    else:
        ax.set_ybound(*(np.array(ax.get_ybound()) * scale))

    def bar_value(bar):
        # The quantity a bar encodes: width when horizontal, height when vertical.
        return bar.get_width() if horizontal else bar.get_height()

    max_bar = np.abs([bar_value(b) for b in bars]).max()
    offset = dist * max_bar
    for bar in bars[:-drop_last or len(bars)]:
        value = bar_value(bar)
        # Shift towards the bar's base (compact) or past its end, on the same
        # side of zero as the bar.
        pos = value + offset if value < 0 else value - offset
        if horizontal:
            x, y = pos, bar.get_y() + bar.get_height() / 2
        else:
            x, y = bar.get_x() + bar.get_width() / 2, pos

        ax.annotate(
            format_spec.format(x=value),
            (x, y),
            ha="center",
            va="center",
            c=color,
            fontsize=fontsize,
            alpha=alpha,
            **kwargs,
        )
Exemple #25
0
def plot_timestamped(
        timestamps: List[str],
        values,
        plabels=None,
        ratio=None,
        timezone=None,
        mavgs=((5, 'blue'), (14, 'green')),
        ytick_size=None,
        ylimits=None,
        figure:plt.Figure=None,
        axes:plt.Axes=None,
        noplot:bool=False,
        **rest
) -> plt.Figure:
    """Plot a time series together with its moving averages.

    Entries whose value is ``None`` are dropped before plotting, together with
    their timestamp and point label.

    Parameters
    ----------
    timestamps : list of str
        Timestamps parsed with ``parse_timestamp``. Aware datetimes are
        converted to UTC; naive ones are assumed to already be UTC. After
        filtering they must be increasing (checked by ``assert_increasing``).
    values
        Values paired with ``timestamps``; ``None`` marks a missing value.
    plabels : iterable, optional
        Per-point annotation texts, aligned with the *unfiltered* input.
    ratio : optional
        Passed to ``Figure.set_size_inches`` when given.
    timezone : optional
        Not supported yet; any value other than ``None`` raises.
    mavgs : iterable of (window, color) pairs
        Moving averages to overlay. ``window`` is passed as ``timedelta(window)``,
        i.e. it is a number of days.
    ytick_size : optional
        Spacing of the major y ticks (``MultipleLocator``).
    ylimits : optional
        Passed to ``Axes.set_ylim``.
    figure, axes : optional
        Figure/axes to draw on. A new figure and a single subplot are created
        when they are omitted.
    noplot : bool
        If True, the raw series is not drawn. Point labels and moving averages
        are still drawn.
    **rest
        Extra keyword arguments for ``Axes.plot``. For the raw series, ``color``
        defaults to red. For moving averages, the color from ``mavgs`` always
        wins.

    Returns
    -------
    Figure
        The figure that was drawn on.

    Raises
    ------
    NotImplementedError
        If ``timezone`` is given (a subclass of ``RuntimeError``).
    ValueError
        If no non-``None`` values remain after filtering.
    """
    if timezone is not None:
        raise NotImplementedError("timezone conversion is not supported yet")

    # Drop missing values, remembering each survivor's original index so the
    # point labels can be kept aligned with the points actually plotted.
    kept = [(i, t, v) for i, (t, v) in enumerate(zip(timestamps, values)) if v is not None]
    if not kept:
        raise ValueError("plot_timestamped: no non-None values to plot")
    kept_idx, timestamps, values = lzip(*kept)
    # TODO report of filtered values?

    tz = pytz.utc
    tss: List[datetime] = [
        d.astimezone(tz) if d.tzinfo is not None else d.replace(tzinfo=tz)
        for d in map(parse_timestamp, timestamps)
    ]

    assert_increasing(tss)

    mavgsc = [(mavg(tss, values, timedelta(m)), c) for m, c in mavgs]

    fig = figure if figure is not None else plt.figure()

    if ratio is not None:
        fig.set_size_inches(ratio)

    if axes is None:
        axes = fig.add_subplot(1, 1, 1)

    if ytick_size is not None:
        axes.yaxis.set_major_locator(MultipleLocator(ytick_size))

    if ylimits is not None:
        axes.set_ylim(ylimits)

    if not noplot:
        # Caller-supplied kwargs (including an explicit color) take precedence.
        axes.plot(tss, values, **{'color': 'red', **rest})

    if plabels:
        all_labels = list(plabels)
        # Keep only the labels of surviving points (kept_idx is ascending).
        aligned = [all_labels[i] for i in kept_idx if i < len(all_labels)]
        for t, v, label in zip(tss, values, aligned):
            axes.annotate(label, xy=(t, v))

    for mv, c in mavgsc:
        # The moving-average color always overrides any color in **rest.
        axes.plot([pt[0] for pt in mv], [pt[1] for pt in mv], **{**rest, 'color': c})
    return fig
def plot_triv(fig: plt.Figure,
              ax: plt.Axes,
              A: Triangle,
              B: Triangle,
              dist: bool = True):
    """Plot the two triangles returned by ``get_triv(A, B)`` with labelled vertices.

    The first triangle is drawn in black with vertex labels alpha_0..alpha_2.
    The second is drawn in blue with vertex labels beta_0..beta_2. Their third
    vertices are additionally labelled t and t'. The points (0, 0) and (1, 0)
    are marked and labelled.

    Parameters
    ----------
    fig, ax : Figure, Axes
        Where to draw.
    A, B : Triangle
        Input triangles, passed to ``get_triv``.
    dist : bool, optional
        If True (default), draw a dashed black segment between the two third
        vertices and mark both with a dot.
    """
    _A, _B = get_triv(A, B)

    # Draw each triangle exactly once.
    _A.plot(fig, ax, color="black")
    _B.plot(fig, ax, color="blue")

    ax.scatter([0, 1], [0, 0], c="black")
    ax.annotate("$(0, 0)$", [-0.01, -0.02])
    ax.annotate("$(1, 0)$", [1, -0.02])

    # Hand-tuned per-vertex offsets (data units) for the vertex labels.
    _A_offsets = np.array([
        [0.07, 0.02],
        [-0.05, 0.03],
        [-0.02, -0.04],
    ])
    _B_offsets = np.array([
        [0.05, 0.0125],
        [-0.03, 0.01],
        [-0.02, -0.04],
    ])
    ax.annotate("$t$", _A.points[2] + np.array([-0.03, -0.005]))
    # Raw string: "\p" is not a valid escape sequence in a normal literal.
    ax.annotate(r"$t^\prime$", _B.points[2] + np.array([-0.025, 0.01]))
    for i in range(3):
        ax.annotate(f"$\\alpha_{i}$", _A.points[i] + _A_offsets[i])
        ax.annotate(f"$\\beta_{i}$",
                    _B.points[i] + _B_offsets[i],
                    color="blue")

    if dist:
        ax.add_collection(
            LineCollection([[_A.points[2], _B.points[2]]],
                           color="black",
                           linestyles="--"))
        ax.scatter(*_A.points[2], c="black")
        ax.scatter(*_B.points[2], c="black")
Exemple #27
0
def real_legend(axis: Axes = None,
                lines: Optional[List[Union[Line2D, List[Line2D]]]] = None,
                labels: Optional[List[str]] = None,
                text_positions: Optional[List[Optional[Tuple[float,
                                                             float]]]] = None,
                arrow_threshold: Optional[float] = None,
                textbox_margin: float = 1.0,
                resolution: int = DEFAULT_RESOLUTION,
                attraction: float = DEFAULT_ATTRACTION,
                repulsion: float = DEFAULT_REPULSION,
                sigma: float = DEFAULT_SIGMA,
                noise: float = DEFAULT_NOISE,
                noise_seed: int = DEFAULT_NOISE_SEED,
                debug: bool = False,
                **kwargs) -> List[Text]:
    """Applies the real legend to an axis object which removes the legend box and adds labels
    annotating important lines in the figure.

    It uses a method of greedy local optimization algorithms. In particular, we model the whole space as
    a square grid of pixels that correspond to placement potential. We black out all forbidden
    places (such as edges and other objects in the figure) and then define for each label a different
    optimization space. For each label, we model the target line to be "attractive" and all other objects
    to be "repulsive". Then we blur everything out to make the space smooth and simply pick the "best spot"
    given by the highest value of the placement potential.

    Parameters
    ----------
    axis: Axes, optional
        The `Axes` object to be targeted. If omitted, the current `Axes` object will be targeted.

    lines: list of Line2D or list of list of Line2D, optional
        List of lines which we want to label. Each item is either a single `Line2D` instance or a list (or tuple)
        of `Line2D` instances. In the latter case, multiple lines are treated as one object and will be assigned
        one label. If omitted, all lines from the given `Axes` instance will be targeted.

    labels: list of str, optional
        List of labels to assign to the lines. If omitted, the `label` property of each target's first
        `Line2D` is used.

    text_positions: list of (pair of float or None), optional
        Gives the ability to force one or more labels to be placed in specific positions. One entry per line:
        either a two-element `float` tuple in the coordinate system of the figure data, or `None` to let that
        label be placed automatically.

    arrow_threshold: float, optional
        If specified, it is the minimum distance that a label object will have from the line in order for
        an arrow to be drawn from the label to the line. By default no arrows are shown. The distance is given
        in the scale of the figure data.

    textbox_margin: float
        Allows us to specify a margin around label objects to prevent them from colliding or being too close
        to other objects. The margin is given relatively to the current size of the label. For example,
        the margin `1.0` means there will be no extra margin around the text (default behavior).
        As another example, the margin 2.0 means that the total bounding box around the label text will be twice
        as large as the original text.

    resolution: int
        Controls the resolution of the label placement space given in pixels. A higher number will mean more
        precision in placement but also more time to compute the positions. Lower values will be faster
        but with more rigid placement.

    attraction: float
        Controls the relative strength of how much the target line will attract the label.
        Tweak this parameter to fine tune placement.

    repulsion: float
        Controls the relative strength of how much all other non-target lines will repel the label.
        Tweak this parameter to fine tune placement.

    sigma: float
        Controls how much we will smooth out the label positioning horizon.
        Tweak this parameter to fine tune placement.

    noise: float
        Controls the noise power to inject into the label positioning process. More noise will increase the probability
        that the object will be placed further from some optimal position based on the model of this method.
        On the other hand it can allow placing nodes in more convenient places.

        Tweak this parameter to fine tune placement. The default is `DEFAULT_NOISE`; for best results use values
        between 0.0 and 0.5.

    noise_seed: int
        If noise is added to the placement process, this is the random seed. Change the seed to change the placement
        outcome. Note that it reseeds NumPy's global random generator.

    debug: bool
        If set to `True` then a debug figure will be shown with optimization heatmaps for every line object.
        Dark areas will show areas we were trying to avoid. Light areas will show areas where the label
        was likely to be placed. The red dot is the final placement.

    **kwargs
        Currently ignored.

    Returns
    -------
    list of Text
        One text artist per target, in the order of `lines`. A label that received an arrow is returned as the
        annotation that replaced its (now hidden) original text.
    """

    def _as_group(target):
        # A target is a single Line2D or a list/tuple of Line2D labelled as one object.
        if isinstance(target, (list, tuple)):
            return list(target)
        return [target]

    # If no axis is specified, we simply take the current one being drawn.
    if axis is None:
        axis = plt.gca()

    # If no lines are specified as targets, we simply target all lines.
    if lines is None:
        lines = axis.lines

    # Make sure if labels and/or text positions are specified, that they are the same length as lines.
    if labels is not None:
        assert (len(labels) == len(lines))
    if text_positions is not None:
        assert (len(text_positions) == len(lines))

    num_lines = len(lines)
    groups = [_as_group(target) for target in lines]
    # forced[l] is True when the caller pinned label l to an explicit position.
    forced = [text_positions is not None and text_positions[l] is not None
              for l in range(num_lines)]
    xmin, xmax = axis.get_xlim()
    ymin, ymax = axis.get_ylim()

    # Without explicit labels, fall back to the label of each target's first line.
    if labels is None:
        labels = [group[0].get_label() for group in groups]

    # Draw text to get the bounding boxes. We place it in the center of the plot to avoid any impact on the viewport.
    colors = []
    texts = []
    texts_bb = []
    xc, yc = xmin + (xmax - xmin) / 2, ymin + (ymax - ymin) / 2
    for l in range(num_lines):
        if forced[l]:
            assert (isinstance(text_positions[l], tuple))
            assert (len(text_positions[l]) == 2)

        color = groups[l][0].get_color()
        colors.append(color)

        text = axis.text(xc,
                         yc,
                         labels[l],
                         color=color,
                         horizontalalignment='center',
                         verticalalignment='center')
        texts.append(text)

        # Store the box re-centred on the origin so it can be translated to any candidate spot.
        text_bb = _get_text_bb(text, textbox_margin)
        texts_bb.append(
            text_bb.translated(-text_bb.width / 2 - text_bb.x0,
                               -text_bb.height / 2 - text_bb.y0))

        # A pinned label is moved to its requested spot once it has been measured.
        if forced[l]:
            text.set_position(text_positions[l])

    # Build the "points of presence" matrix with all that belong to certain lines.
    pop = np.zeros((num_lines, resolution, resolution), dtype=float)
    span = np.array([xmax - xmin, ymax - ymin])
    origin = np.array([xmin, ymin])

    for l in range(num_lines):
        group = groups[l]
        # Box occupied by a pinned label, in data coordinates.
        pinned_bb = texts_bb[l].translated(*text_positions[l]) if forced[l] else None

        for x_i, y_i in itertools.product(range(resolution),
                                          range(resolution)):
            x_f, y_f = np.array([x_i, y_i]) / resolution * span + origin
            text_bb_xy = texts_bb[l].translated(x_f, y_f)

            if text_bb_xy.x0 < xmin or text_bb_xy.x1 > xmax or text_bb_xy.y0 < ymin or text_bb_xy.y1 > ymax:
                pop[l, x_i, y_i] = 1.0

            elif any(line_part.get_path().intersects_bbox(text_bb_xy,
                                                          filled=False)
                     for line_part in group):
                pop[l, x_i, y_i] = 1.0

            # Space around an already-pinned label is marked as occupied right away.
            if pinned_bb is not None and pinned_bb.overlaps(text_bb_xy):
                pop[l, x_i, y_i] = 1.0

    # There is nothing to draw a debug figure for when no lines are targeted.
    show_debug = debug and num_lines > 0
    if show_debug:
        # squeeze=False keeps a 2-D array of axes even when num_lines == 1.
        debug_f, debug_grid = plt.subplots(nrows=1, ncols=num_lines, squeeze=False)
        debug_ax = debug_grid[0]

    for l in range(num_lines):

        # If the position of this label has been provided in the input arguments, we can just skip it.
        if forced[l]:
            continue

        # Find empty space, which is a nice place for labels.
        empty_space = 1.0 - (np.sum(pop, axis=0) > 0) * 1.0

        # Blur the pops; they change after every placed label, so this is redone each time.
        pop_blurred = pop.copy()
        for ll in range(num_lines):
            pop_blurred[ll] = ndimage.gaussian_filter(pop[ll],
                                                      sigma=sigma *
                                                      resolution / 5)

        # Positive weights for current line, negative weight for others....
        w = -repulsion * np.ones(num_lines, dtype=float)
        w[l] = attraction

        # calculate a field
        p = empty_space + np.sum(w[:, np.newaxis, np.newaxis] * pop_blurred,
                                 axis=0)

        # Add noise to the field if specified.
        if noise > 0.0:
            np.random.seed(noise_seed)
            p += np.random.normal(0.0, noise, p.shape)

        # Integer grid indices of the best pixel (argmax works on the flattened array).
        best_x, best_y = np.unravel_index(np.argmax(p), p.shape)
        x = xmin + (xmax - xmin) * best_x / resolution
        y = ymin + (ymax - ymin) * best_y / resolution

        if show_debug:
            im1 = debug_ax[l].imshow(p.T,
                                     interpolation='nearest',
                                     origin="lower")
            debug_ax[l].set_title("Heatmap for: " + texts[l].get_text())
            debug_ax[l].plot(best_x, best_y, 'ro')
            divider = make_axes_locatable(debug_ax[l])
            cax = divider.append_axes('right', size='5%', pad=0.05)
            debug_f.colorbar(im1, cax=cax, orientation='vertical')

        texts[l].set_position((x, y))

        # Prevent collision by blocking out the bounding box of this text box.
        text_bb_new = _get_text_bb(texts[l], textbox_margin)
        x_i_min, y_i_min = tuple(
            ((text_bb_new.min - origin) / span * resolution).astype(int))
        x_i_max, y_i_max = tuple(
            ((text_bb_new.max - origin) / span * resolution).astype(int))

        # Augment the barrier to prevent collision between labels.
        w_barrier = int(round((x_i_max - x_i_min) / 2))
        h_barrier = int(round((y_i_max - y_i_min) / 2))
        x_i_min = int(max(0, x_i_min - w_barrier))
        y_i_min = int(max(0, y_i_min - h_barrier))
        x_i_max = int(min(resolution - 1, x_i_max + w_barrier))
        y_i_max = int(min(resolution - 1, y_i_max + h_barrier))

        pop[l, x_i_min:x_i_max + 1, y_i_min:y_i_max + 1] = 1.0

    # If the arrow threshold has been specified, draw arrows where needed.
    if arrow_threshold is not None:
        for l in range(num_lines):

            # Get all points on the path (including some interpolated ones).
            points = np.vstack(
                [part.get_path().interpolated(10).vertices for part in groups[l]])

            # Get the midpoint of the text box.
            text_c = np.array(
                _get_midpoint(_get_text_bb(texts[l], textbox_margin)))

            # Get all distances.
            distances = np.linalg.norm(points - text_c, axis=1)
            d_min_idx = np.argmin(distances)

            # If the distance is larger than the threshold, draw the line.
            if distances[d_min_idx] > arrow_threshold:

                # Pick the closest point whose connecting segment crosses no other text box.
                other_bbs = [_get_text_bb(texts[i], textbox_margin)
                             for i in range(len(texts)) if i != l]
                xytext = texts[l].get_position()
                xy = points[d_min_idx, :]
                for idx in np.argsort(distances):
                    segment = Path([xytext, points[idx, :]])
                    crosses_other_label = any(
                        segment.intersects_bbox(bb) for bb in other_bbs)
                    if not crosses_other_label:
                        xy = points[idx, :]
                        break

                # Draw the new text with the arrow.
                a = axis.annotate(labels[l],
                                  xy=xy,
                                  xytext=xytext,
                                  ha="center",
                                  va="center",
                                  color=colors[l],
                                  arrowprops=dict(arrowstyle="->",
                                                  color=colors[l]))

                # Hide original text.
                texts[l].set_visible(False)
                texts[l] = a

    # Remove the ugly legend.
    ugly_legend = axis.get_legend()
    if ugly_legend is not None:
        ugly_legend.remove()

    if show_debug:
        debug_f.show()

    # We return all the placed labels.
    return texts
def draw_detections(
    ax: plt.Axes,
    *,
    boxes: np.ndarray,
    labels: Optional[np.ndarray] = None,
    label_associations: Dict[int, str] = dict(label_lookup),
    box_color: str = "r",
    font_color: Optional[str] = None,
    box_line_width: int = 2,
    label_fontsize: int = 24,
):
    """
    Draws detection boxes/labels on an existing image.

    Parameters
    ----------
    ax : Axes
        The image axis object on which the detections will be drawn.

    boxes : ndarray, shape-(N, 4)
        The detection boxes, each box is formatted as (xlo, ylo, xhi, yhi)
        in pixel space.

    labels : Optional[array-like], shape-(N,) or shape-(N, 1)
        The integer classification label associated with each box. Boxes
        whose label is 0 are treated as background and are not drawn. When
        omitted, every box is drawn without a text label.

    label_associations : Dict[int, str]
        int -> label text drawn at each box's (xlo, ylo) corner.

    box_color : str, optional (default=red)

    font_color : Optional[str]
        If not specified, matches ``box_color``

    box_line_width : int, optional (default=2)

    label_fontsize : int, optional (default=24)
    """
    assert boxes.ndim == 2 and boxes.shape[1] == 4

    if labels is None:
        labels = [None] * len(boxes)
    else:
        # Flatten so shape-(N,) and shape-(N, 1) labels behave alike. Unlike
        # ``squeeze``, this keeps a single detection as a 1-element array rather
        # than a 0-d scalar, which would otherwise add an axis to ``boxes``.
        labels = np.asarray(labels).reshape(-1)
        # Drop background detections up front (also keeps the slider efficient).
        not_background = labels != 0
        boxes = boxes[not_background]
        labels = labels[not_background]

    assert len(boxes) == len(labels)

    if font_color is None:
        font_color = box_color

    # Background entries were filtered above, so every remaining box is drawn.
    for class_pred, (x1, y1, x2, y2) in zip(labels, boxes):
        ax.add_patch(
            Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                color=box_color,
                fill=None,
                lw=box_line_width,
            )
        )
        if class_pred is not None:
            label = label_associations[int(class_pred)]
            ax.annotate(label, (x1, y1), color=font_color, fontsize=label_fontsize)
Exemple #29
0
def annotate_train_run(ax: plt.Axes, tr: TrainRun):
    """Label where a train run starts with its departure time as "HH:MM".

    The text is placed at ``(departure station, departure time)`` taken from
    ``tr.first_run()``.
    """
    station = tr.first_run().departure_station()
    departs_at = tr.first_run().departure_time
    ax.annotate(departs_at.strftime("%H:%M"), xy=(station, departs_at))
Exemple #30
0
    def highlight_data_outside_domain(
        self,
        ax: Axes,
        x: unyt_array,
        y: unyt_array,
        color: str,
        x_lim: List,
        y_lim: List,
    ) -> None:
        """
        Draw arrows on ``ax`` that point towards data points lying outside the
        plotted region.

        A point that is out of range along only one axis gets an arrow next to
        the edge it lies beyond. The arrow sits at the point's own coordinate
        along the other axis and points towards that edge. A point that is out of
        range along both axes gets a diagonal arrow in the matching corner.
        Nothing is drawn when ``x`` or ``y`` contains NaN.

        Parameters
        ----------

        ax: Axes
            Axes on which the arrows are drawn.

        x: unyt_array
            Horizontal-axis data.

        y: unyt_array
            Vertical-axis data.

        color: str
            Arrow color; normally the color used for the (missing) data points.

        x_lim: List
            ``[lower, upper]`` limits of the X-axis range.

        y_lim: List
            ``[lower, upper]`` limits of the Y-axis range.
        """

        # Only proceed when every coordinate is a real number.
        if isnan(x).any() or isnan(y).any():
            return

        # Arrow geometry, in fractions of the axes size.
        arrow_length = 0.07
        edge_gap = 0.01

        def draw_arrow(tail, head, transform):
            # ax.annotate (rather than ax.arrow) keeps the arrow head and tail the
            # same size whatever the axes aspect ratio or log/linear scaling.
            ax.annotate(
                "",
                xytext=tail,
                textcoords=transform,
                xy=head,
                xycoords=transform,
                arrowprops=dict(color=color, arrowstyle="->"),
            )

        # Classify every point relative to each axis range.
        below_x, above_x = x < x_lim[0], x > x_lim[1]
        inside_x = logical_and(x >= x_lim[0], x <= x_lim[1])
        below_y, above_y = y < y_lim[0], y > y_lim[1]
        inside_y = logical_and(y >= y_lim[0], y <= y_lim[1])

        # Out of range in Y only: vertical arrows, with X in data coordinates
        # and Y in axes-fraction coordinates.
        data_x = blended_transform_factory(ax.transData, ax.transAxes)
        for px in x[logical_and(below_y, inside_x)]:
            draw_arrow((px, arrow_length + edge_gap), (px, edge_gap), data_x)
        for px in x[logical_and(above_y, inside_x)]:
            draw_arrow(
                (px, 1.0 - arrow_length - edge_gap), (px, 1.0 - edge_gap), data_x
            )

        # Out of range in X only: horizontal arrows, with Y in data coordinates
        # and X in axes-fraction coordinates.
        data_y = blended_transform_factory(ax.transAxes, ax.transData)
        for py in y[logical_and(below_x, inside_y)]:
            draw_arrow((arrow_length + edge_gap, py), (edge_gap, py), data_y)
        for py in y[logical_and(above_x, inside_y)]:
            draw_arrow(
                (1.0 - arrow_length - edge_gap, py), (1.0 - edge_gap, py), data_y
            )

        # Out of range in both X and Y: diagonal arrows in the matching corner.
        corner = logical_and(
            logical_or(below_y, above_y),
            logical_or(below_x, above_x),
        )
        # A diagonal arrow spans both axes, so each projection is length / sqrt(2).
        diag = arrow_length / sqrt(2.0)
        axes_frac = blended_transform_factory(ax.transAxes, ax.transAxes)
        for px, py in zip(x[corner], y[corner]):
            if x_lim[0] > px:
                tail_x, head_x = diag + edge_gap, edge_gap
            else:
                tail_x, head_x = 1.0 - diag - edge_gap, 1.0 - edge_gap

            if y_lim[0] > py:
                tail_y, head_y = diag + edge_gap, edge_gap
            else:
                tail_y, head_y = 1.0 - diag - edge_gap, 1.0 - edge_gap

            draw_arrow((tail_x, tail_y), (head_x, head_y), axes_frac)