def _draw_ticks(self, renderer, pixel_array, angle_array, offset):
    """
    Draw one set of ticks (major or minor) as rotated markers.

    Parameters
    ----------
    renderer : matplotlib renderer
        Renderer providing ``new_gc`` and ``draw_markers``.
    pixel_array : dict
        Maps each axis name to a sequence of tick locations; each location is
        passed through ``self.get_transform()`` before drawing.
    angle_array : dict
        Maps each axis name to a sequence of tick angles in degrees, parallel
        to ``pixel_array[axis]``.
    offset : float
        Length of the tick in display pixels; used to scale the unit marker.
    """
    path_trans = self.get_transform()

    gc = renderer.new_gc()
    try:
        gc.set_foreground(self.get_color())
        gc.set_alpha(self.get_alpha())
        gc.set_linewidth(self.get_linewidth())

        marker_scale = Affine2D().scale(offset, offset)
        marker_rotation = Affine2D()
        # marker_rotation is mutated per tick below; the composite transform
        # is built once and reused for every tick.
        marker_transform = marker_scale + marker_rotation

        # Outward ticks are the same marker flipped by 180 degrees.
        initial_angle = 180. if self.get_tick_out() else 0.

        for axis in self.get_visible_axes():
            # get_visible_axes() may list axes that have no entry in this
            # particular tick set (presumably the minor ticks) - skip those
            # instead of raising KeyError.
            if axis not in pixel_array:
                continue
            for loc, angle in zip(pixel_array[axis], angle_array[axis]):
                # Set the rotation for this tick
                marker_rotation.rotate_deg(initial_angle + angle)

                # Draw the markers
                locs = path_trans.transform_non_affine(np.array([loc, loc]))
                renderer.draw_markers(gc, self._tickvert_path, marker_transform,
                                      Path(locs), path_trans.get_affine())

                # Reset the tick rotation before moving to the next tick
                marker_rotation.clear()
    finally:
        # Always release the graphics context, even if drawing fails.
        gc.restore()
def get_gridline_path(world, pixel):
    """
    Build a `~matplotlib.path.Path` for a grid line, breaking it at invalid points.

    Parameters
    ----------
    world : ndarray
        The world coordinates along the curve, given as a (n,2) array.
    pixel : ndarray
        The (n,2) pixel coordinates corresponding to ``world``.

    Returns
    -------
    path : `~matplotlib.path.Path`
        Path through ``pixel``. The pen is lifted (``MOVETO``) at the first
        point, at every point with a NaN pixel coordinate, and at the point
        that follows each such point. An empty input yields an empty path.
    """
    # Mask values with invalid pixel positions
    mask = np.isnan(pixel[:, 0]) | np.isnan(pixel[:, 1])

    codes = np.full(world.shape[0], Path.LINETO, dtype=np.uint8)
    # Slice assignment (instead of codes[0]) keeps an empty curve from raising.
    codes[:1] = Path.MOVETO
    codes[mask] = Path.MOVETO
    # Also need to move to point *after* a hidden value, so no segment is
    # drawn out of it.
    codes[1:][mask[:-1]] = Path.MOVETO

    return Path(pixel, codes=codes)
def path(self):
    """
    Return a Path through the ``data`` vertices of every item in this container.

    ``update_spines`` is called first. Items are visited in iteration order,
    and only the first two columns of each ``data`` array are used.
    """
    self.update_spines()
    spine_data = [self[axis].data for axis in self]
    xs = np.hstack([block[:, 0] for block in spine_data])
    ys = np.hstack([block[:, 1] for block in spine_data])
    return Path(np.vstack([xs, ys]).transpose())
def _update_patch_path(self):
    """
    Rebuild the patch path from the outer ellipse only.

    Only the ``'c'`` entry's vertices are used, so the major and minor axes
    drawn in the middle are not part of the patch outline.
    """
    self.update_spines()
    outer_vertices = self['c'].data
    if self._path is not None:
        # Reuse the existing Path object and just swap in new vertices.
        self._path.vertices = outer_vertices
        return
    self._path = Path(outer_vertices)
def _update_patch_path(self):
    """
    Refresh ``self._path`` with the concatenated vertices of all items.

    The first two columns of each item's ``data`` are stacked in iteration
    order. An existing Path object is updated in place; otherwise one is
    created.
    """
    self.update_spines()
    blocks = [self[axis].data for axis in self]
    xs = np.hstack([block[:, 0] for block in blocks])
    ys = np.hstack([block[:, 1] for block in blocks])
    vertices = np.vstack([xs, ys]).transpose()
    if self._path is not None:
        self._path.vertices = vertices
    else:
        self._path = Path(vertices)
def get_gridline_path(ax, transform, world):
    """
    Build a `~matplotlib.path.Path` for a grid line, breaking it at invalid points.

    Parameters
    ----------
    ax : ~matplotlib.axes.Axes
        The axes in which to plot the grid. Not used by this function; kept
        for interface compatibility.
    transform : transformation class
        The transformation between the world and pixel coordinates.
    world : `~numpy.ndarray`
        The world coordinates along the curve, given as a (n,2) array.

    Returns
    -------
    path : `~matplotlib.path.Path`
        Path through the pixel coordinates of ``world``. The pen is lifted
        (``MOVETO``) at the first point, at every point whose pixel position
        is NaN, and at the point that follows each such point. An empty input
        yields an empty path.
    """
    # Transform line to pixel coordinates
    pixel = transform.transform(world)

    # Mask values with invalid pixel positions
    mask = np.isnan(pixel[:, 0]) | np.isnan(pixel[:, 1])

    # Points outside the viewport are intentionally *not* masked: that
    # approach assumed rectangular axes and did not work when a coordinate
    # direction is flipped.

    codes = np.full(world.shape[0], Path.LINETO, dtype=np.uint8)
    # Slice assignment (instead of codes[0]) keeps an empty curve from raising.
    codes[:1] = Path.MOVETO
    codes[mask] = Path.MOVETO
    # Also need to move to point *after* a hidden value
    codes[1:][mask[:-1]] = Path.MOVETO

    return Path(pixel, codes=codes)
def _update_patch_path(self):
    """
    Refresh ``self._path`` as the closed rectangle spanned by the parent axes limits.

    The rectangle is traced (xmin, ymin) -> (xmax, ymin) -> (xmax, ymax) ->
    (xmin, ymax) -> (xmin, ymin). An existing Path object is updated in
    place; otherwise one is created.
    """
    self.update_spines()
    x_lo, x_hi = self.parent_axes.get_xlim()
    y_lo, y_hi = self.parent_axes.get_ylim()
    corners = np.array([[x_lo, x_hi, x_hi, x_lo, x_lo],
                        [y_lo, y_lo, y_hi, y_hi, y_lo]]).transpose()
    if self._path is not None:
        self._path.vertices = corners
    else:
        self._path = Path(corners)
def get_lon_lat_path(lon_lat, pixel, lon_lat_check):
    """
    Draw a curve, taking into account discontinuities.

    Parameters
    ----------
    lon_lat : ndarray
        The longitude and latitude values along the curve, given as a (n,2)
        array in degrees (they are converted with ``np.radians``). At least
        two points are required, because the first two define the scale size.
    pixel : ndarray
        The pixel coordinates corresponding to ``lon_lat``.
    lon_lat_check : ndarray
        The world coordinates derived from converting from ``pixel``, which is
        used to ensure round-tripping.

    Returns
    -------
    path : `~matplotlib.path.Path`
        Path through ``pixel`` with ``MOVETO`` codes wherever the curve fails
        to round-trip, has an invalid pixel position, or jumps discontinuously.
    """
    # In some spherical projections, some parts of the curve are 'behind' or
    # 'in front of' the plane of the image, so we find those by reversing the
    # transformation and finding points where the result is not consistent.
    sep = angular_separation(np.radians(lon_lat[:, 0]),
                             np.radians(lon_lat[:, 1]),
                             np.radians(lon_lat_check[:, 0]),
                             np.radians(lon_lat_check[:, 1]))

    # Define the relevant scale size using the separation between the first
    # two points, so the round-trip tolerance is relative to the sampling.
    scale_size = angular_separation(*np.radians(lon_lat[0, :]),
                                    *np.radians(lon_lat[1, :]))

    # Suppress 'invalid value' warnings, e.g. from comparisons involving NaN
    # separations where the round trip failed.
    with np.errstate(invalid='ignore'):
        sep[sep > np.pi] -= 2. * np.pi
        # Compare the *magnitude* of the wrapped separation with the tolerance
        # (taking abs of the boolean comparison would ignore negative values).
        mask = np.abs(sep) > ROUND_TRIP_RTOL * scale_size

    # Mask values with invalid pixel positions
    mask = mask | np.isnan(pixel[:, 0]) | np.isnan(pixel[:, 1])

    codes = np.full(lon_lat.shape[0], Path.LINETO, dtype=np.uint8)
    codes[:1] = Path.MOVETO
    codes[mask] = Path.MOVETO
    # Also need to move to point *after* a hidden value
    codes[1:][mask[:-1]] = Path.MOVETO

    # Search for discontinuities caused by the curve leaving the field of
    # view, invalid WCS values, or discontinuities in the projection. First
    # pre-compute the step length in pixel coordinates from one point to the
    # next; large jumps indicate discontinuities.
    step = np.sqrt((pixel[1:, 0] - pixel[:-1, 0]) ** 2 +
                   (pixel[1:, 1] - pixel[:-1, 1]) ** 2)

    # A step is a discontinuity when it is more than DISCONT_FACTOR times
    # longer than the step before it.
    discontinuous = step[1:] > DISCONT_FACTOR * step[:-1]

    # Skip over discontinuities
    codes[2:][discontinuous] = Path.MOVETO

    # The above missed the first step, so check that too. This needs at least
    # two steps (three points); guard so that two-point curves do not raise.
    if step.size > 1 and step[0] > DISCONT_FACTOR * step[1]:
        codes[1] = Path.MOVETO

    return Path(pixel, codes=codes)
class Ticks(Line2D):
    """
    Ticks are derived from Line2D, and note that ticks themselves
    are markers. Thus, you should use set_mec, set_mew, etc.

    To change the tick size (length), you need to use
    set_ticksize. To change the direction of the ticks (ticks are
    in opposite direction of ticklabels by default), use
    set_tick_out(False).

    Note that Matplotlib's defaults dictionary :data:`~matplotlib.rcParams`
    contains default settings (color, size, width) of the form `xtick.*` and
    `ytick.*`. In a WCS projection, there may not be a clear relationship
    between axes of the projection and 'x' or 'y' axes. For this reason,
    we read defaults from `xtick.*`. The following settings affect the
    default appearance of ticks:

    * `xtick.direction`
    * `xtick.major.size`
    * `xtick.major.width`
    * `xtick.color`

    Explicit ``ticksize`` / ``tick_out`` arguments and Line2D keyword
    arguments (e.g. ``color``, ``linewidth``) take precedence over these
    defaults.
    """

    def __init__(self, ticksize=None, tick_out=None, **kwargs):
        if ticksize is None:
            ticksize = rcParams['xtick.major.size']
        # Only fall back to rcParams when the caller did not pick a direction.
        if tick_out is None:
            tick_out = rcParams.get('xtick.direction', 'in') == 'out'
        self.set_ticksize(ticksize)
        self.set_tick_out(tick_out)
        self.clear()
        line2d_kwargs = {
            'color': rcParams['xtick.color'],
            # For the linewidth we need to set a default since old versions of
            # matplotlib don't have this.
            'linewidth': rcParams.get('xtick.major.width', 1),
        }
        line2d_kwargs.update(kwargs)
        Line2D.__init__(self, [0.], [0.], **line2d_kwargs)
        self.set_visible_axes('all')
        self._display_minor_ticks = False

    def display_minor_ticks(self, display_minor_ticks):
        """Enable or disable drawing of the minor ticks."""
        self._display_minor_ticks = display_minor_ticks

    def get_display_minor_ticks(self):
        """Return whether the minor ticks are drawn."""
        return self._display_minor_ticks

    def set_tick_out(self, tick_out):
        """
        set True if tick need to be rotated by 180 degree.
        """
        self._tick_out = tick_out

    def get_tick_out(self):
        """
        Return True if the tick will be rotated by 180 degree.
        """
        return self._tick_out

    def set_ticksize(self, ticksize):
        """
        set length of the ticks in points.
        """
        self._ticksize = ticksize

    def get_ticksize(self):
        """
        Return length of the ticks in points.
        """
        return self._ticksize

    def set_visible_axes(self, visible_axes):
        """Select the axes that get ticks: ``'all'`` or an iterable of axis names."""
        self._visible_axes = visible_axes

    def get_visible_axes(self):
        """Return the selected axes that have at least one major tick registered."""
        if self._visible_axes == 'all':
            return self.world.keys()
        return [x for x in self._visible_axes if x in self.world]

    def clear(self):
        """Forget all registered major and minor ticks."""
        self.world = {}
        self.pixel = {}
        self.angle = {}
        self.disp = {}
        self.minor_world = {}
        self.minor_pixel = {}
        self.minor_angle = {}
        self.minor_disp = {}

    def add(self, axis, world, pixel, angle, axis_displacement):
        """Register a major tick on ``axis``."""
        self.world.setdefault(axis, []).append(world)
        self.pixel.setdefault(axis, []).append(pixel)
        self.angle.setdefault(axis, []).append(angle)
        self.disp.setdefault(axis, []).append(axis_displacement)

    def get_minor_world(self):
        """Return the mapping from axis name to minor-tick world values."""
        return self.minor_world

    def add_minor(self, minor_axis, minor_world, minor_pixel, minor_angle,
                  minor_axis_displacement):
        """Register a minor tick on ``minor_axis``."""
        self.minor_world.setdefault(minor_axis, []).append(minor_world)
        self.minor_pixel.setdefault(minor_axis, []).append(minor_pixel)
        self.minor_angle.setdefault(minor_axis, []).append(minor_angle)
        self.minor_disp.setdefault(minor_axis, []).append(minor_axis_displacement)

    def __len__(self):
        return len(self.world)

    # Unit-length horizontal marker; scaled and rotated per tick when drawn.
    _tickvert_path = Path([[0., 0.], [1., 0.]])

    def draw(self, renderer, ticks_locs):
        """
        Draw the major ticks and, if enabled, the minor ticks at half length.

        Parameters
        ----------
        renderer : matplotlib renderer
            Renderer used for drawing.
        ticks_locs : mapping
            The transformed location of every drawn tick is appended to
            ``ticks_locs[axis]``; presumably a ``defaultdict(list)`` supplied
            by the caller - TODO confirm.
        """
        if not self.get_visible():
            return

        offset = renderer.points_to_pixels(self.get_ticksize())
        self._draw_ticks(renderer, self.pixel, self.angle, offset, ticks_locs)
        if self._display_minor_ticks:
            offset = offset * 0.5  # for minor ticksize
            self._draw_ticks(renderer, self.minor_pixel, self.minor_angle,
                             offset, ticks_locs)

    def _draw_ticks(self, renderer, pixel_array, angle_array, offset, ticks_locs):
        """
        Draw one set of ticks (major or minor) as rotated markers.

        Parameters
        ----------
        renderer : matplotlib renderer
            Renderer providing ``new_gc`` and ``draw_markers``.
        pixel_array, angle_array : dict
            Map each axis name to parallel sequences of tick locations and
            tick angles (degrees).
        offset : float
            Tick length in display pixels.
        ticks_locs : mapping
            Receives the transformed location of each drawn tick per axis.
        """
        path_trans = self.get_transform()

        gc = renderer.new_gc()
        try:
            gc.set_foreground(self.get_color())
            gc.set_alpha(self.get_alpha())
            gc.set_linewidth(self.get_linewidth())

            marker_scale = Affine2D().scale(offset, offset)
            marker_rotation = Affine2D()
            marker_transform = marker_scale + marker_rotation

            # Outward ticks are the same marker flipped by 180 degrees.
            initial_angle = 180. if self.get_tick_out() else 0.

            for axis in self.get_visible_axes():
                # Visible axes come from the major ticks; a given set (e.g.
                # the minor ticks) may have nothing on some of them.
                if axis not in pixel_array:
                    continue

                for loc, angle in zip(pixel_array[axis], angle_array[axis]):
                    # Set the rotation for this tick
                    marker_rotation.rotate_deg(initial_angle + angle)

                    # Draw the markers
                    locs = path_trans.transform_non_affine(np.array([loc, loc]))
                    renderer.draw_markers(gc, self._tickvert_path, marker_transform,
                                          Path(locs), path_trans.get_affine())

                    # Reset the tick rotation before moving to the next tick
                    marker_rotation.clear()

                    ticks_locs[axis].append(locs)
        finally:
            # Always release the graphics context, even if drawing fails.
            gc.restore()
def get_lon_lat_path(ax, transform, lon_lat):
    """
    Draw a curve, taking into account discontinuities.

    Parameters
    ----------
    ax : ~matplotlib.axes.Axes
        The axes in which to plot the grid. Not used by this function; kept
        for interface compatibility.
    transform : transformation class
        The transformation between the world and pixel coordinates
    lon_lat : `~numpy.ndarray`
        The longitude and latitude values along the curve, given as a (n,2)
        array in degrees (they are converted with ``np.radians``).

    Returns
    -------
    path : `~matplotlib.path.Path`
        Path through the pixel coordinates of ``lon_lat`` with ``MOVETO``
        codes wherever the curve fails to round-trip, has an invalid pixel
        position, or jumps discontinuously.
    """
    # Transform line to pixel coordinates
    pixel = transform.transform(lon_lat)

    # In some spherical projections, some parts of the curve are 'behind' or
    # 'in front of' the plane of the image, so we find those by reversing the
    # transformation and finding points where the result is not consistent.
    lon_lat_check = transform.inverted().transform(pixel)

    sep = angular_separation(np.radians(lon_lat[:, 0]),
                             np.radians(lon_lat[:, 1]),
                             np.radians(lon_lat_check[:, 0]),
                             np.radians(lon_lat_check[:, 1]))

    # Suppress 'invalid value' warnings, e.g. from comparisons involving NaN
    # separations where the round trip failed.
    with np.errstate(invalid='ignore'):
        sep[sep > np.pi] -= 2. * np.pi
        # Compare the *magnitude* of the wrapped separation with the tolerance.
        mask = np.abs(sep) > ROUND_TRIP_TOL

    # Mask values with invalid pixel positions
    mask = mask | np.isnan(pixel[:, 0]) | np.isnan(pixel[:, 1])

    # Points outside the viewport are intentionally *not* masked: that
    # approach assumed rectangular axes and did not work when a coordinate
    # direction is flipped.

    codes = np.full(lon_lat.shape[0], Path.LINETO, dtype=np.uint8)
    # Slice assignment (instead of codes[0]) keeps an empty curve from raising.
    codes[:1] = Path.MOVETO
    codes[mask] = Path.MOVETO
    # Also need to move to point *after* a hidden value
    codes[1:][mask[:-1]] = Path.MOVETO

    # Search for discontinuities caused by the curve leaving the field of
    # view, invalid WCS values, or discontinuities in the projection. First
    # pre-compute the step length in pixel coordinates from one point to the
    # next; large jumps indicate discontinuities.
    step = np.sqrt((pixel[1:, 0] - pixel[:-1, 0]) ** 2 +
                   (pixel[1:, 1] - pixel[:-1, 1]) ** 2)

    # A step is a discontinuity when it is more than DISCONT_FACTOR times
    # longer than the step before it.
    discontinuous = step[1:] > DISCONT_FACTOR * step[:-1]

    # Skip over discontinuities
    codes[2:][discontinuous] = Path.MOVETO

    # The above missed the first step, so check that too. This needs at least
    # two steps (three points); guard so that shorter curves do not raise.
    if step.size > 1 and step[0] > DISCONT_FACTOR * step[1]:
        codes[1] = Path.MOVETO

    return Path(pixel, codes=codes)
def path(self):
    """
    Return a Path through the ``pixel`` vertices of every item in this container.

    Items are visited in iteration order, and only the first two columns of
    each ``pixel`` array are used.
    """
    pixel_blocks = [self[axis].pixel for axis in self]
    xs = np.hstack([block[:, 0] for block in pixel_blocks])
    ys = np.hstack([block[:, 1] for block in pixel_blocks])
    return Path(np.vstack([xs, ys]).transpose())
class Ticks(Line2D):
    """
    Ticks are derived from Line2D, and note that ticks themselves
    are markers. Thus, you should use set_mec, set_mew, etc.

    To change the tick size (length), you need to use
    set_ticksize. To change the direction of the ticks (ticks are
    in opposite direction of ticklabels by default), use
    set_tick_out(False).

    Ticks are black unless a ``color`` (or ``c``) keyword is passed.
    """

    def __init__(self, ticksize=5., tick_out=False, **kwargs):
        self.set_ticksize(ticksize)
        self.set_tick_out(tick_out)
        self.clear()
        # Default to black, but let an explicit color keyword win. Calling
        # set_color('black') after __init__ would silently override it.
        if 'color' not in kwargs and 'c' not in kwargs:
            kwargs['color'] = 'black'
        Line2D.__init__(self, [0.], [0.], **kwargs)
        self.set_visible_axes('all')

    def set_tick_out(self, tick_out):
        """
        set True if tick need to be rotated by 180 degree.
        """
        self._tick_out = tick_out

    def get_tick_out(self):
        """
        Return True if the tick will be rotated by 180 degree.
        """
        return self._tick_out

    def set_ticksize(self, ticksize):
        """
        set length of the ticks in points.
        """
        self._ticksize = ticksize

    def get_ticksize(self):
        """
        Return length of the ticks in points.
        """
        return self._ticksize

    def set_visible_axes(self, visible_axes):
        """Select the axes that get ticks: ``'all'`` or an iterable of axis names."""
        self._visible_axes = visible_axes

    def get_visible_axes(self):
        """Return the selected axes that have at least one tick registered."""
        if self._visible_axes == 'all':
            return self.world.keys()
        return [x for x in self._visible_axes if x in self.world]

    def clear(self):
        """Forget all registered ticks."""
        self.world = {}
        self.pixel = {}
        self.angle = {}
        self.disp = {}

    def add(self, axis, world, pixel, angle, axis_displacement):
        """Register a tick on ``axis``."""
        self.world.setdefault(axis, []).append(world)
        self.pixel.setdefault(axis, []).append(pixel)
        self.angle.setdefault(axis, []).append(angle)
        self.disp.setdefault(axis, []).append(axis_displacement)

    def __len__(self):
        return len(self.world)

    # Unit-length horizontal marker; scaled and rotated per tick when drawn.
    _tickvert_path = Path([[0., 0.], [1., 0.]])

    def draw(self, renderer):
        """
        Draw the ticks as rotated markers at their registered locations.
        """
        if not self.get_visible():
            return

        path_trans = self.get_transform()

        gc = renderer.new_gc()
        try:
            gc.set_foreground(self.get_color())
            gc.set_alpha(self.get_alpha())

            offset = renderer.points_to_pixels(self.get_ticksize())
            marker_scale = Affine2D().scale(offset, offset)
            marker_rotation = Affine2D()
            marker_transform = marker_scale + marker_rotation

            # Outward ticks are the same marker flipped by 180 degrees.
            initial_angle = 180. if self.get_tick_out() else 0.

            for axis in self.get_visible_axes():
                for loc, angle in zip(self.pixel[axis], self.angle[axis]):
                    # Set the rotation for this tick
                    marker_rotation.rotate_deg(initial_angle + angle)

                    # Draw the markers
                    locs = path_trans.transform_non_affine(np.array([loc, loc]))
                    renderer.draw_markers(gc, self._tickvert_path, marker_transform,
                                          Path(locs), path_trans.get_affine())

                    # Reset the tick rotation before moving to the next tick
                    marker_rotation.clear()
        finally:
            # Always release the graphics context, even if drawing fails.
            gc.restore()
def real_legend(axis: Optional[Axes] = None,
                lines: Optional[List[Union[Line2D, List[Line2D]]]] = None,
                labels: Optional[List[str]] = None,
                text_positions: Optional[List[Optional[Tuple[float, float]]]] = None,
                arrow_threshold: Optional[float] = None,
                textbox_margin: float = 1.0,
                resolution: int = DEFAULT_RESOLUTION,
                attraction: float = DEFAULT_ATTRACTION,
                repulsion: float = DEFAULT_REPULSION,
                sigma: float = DEFAULT_SIGMA,
                noise: float = DEFAULT_NOISE,
                noise_seed: int = DEFAULT_NOISE_SEED,
                debug: bool = False,
                **kwargs) -> List[Text]:
    """Applies the real legend to an axis object which removes the legend box and adds labels annotating important
    lines in the figure.

    It uses a method of greedy local optimization algorithms. In particular, we model the whole space as a square grid
    of pixels that correspond to placement potential. We black out all forbidden places (such as edges and other
    objects in the figure) and then define for each label a different optimization space. For each label, we model the
    target line to be "attractive" and all other objects to be "repulsive". Then we blur everything out to make the
    space smooth and simply pick the "best spot" given by the highest value of the placement potential.

    Parameters
    ----------
    axis: Axes, optional
        The `Axes` object to be targeted. If omitted, the current `Axes` object will be targeted.
    lines: list of Line2D or list of list of Line2D, optional
        List of lines which we want to label. Each item is either a single `Line2D` instance or a list (or tuple)
        of `Line2D` instances. In the latter case, multiple lines are treated as one object and will be assigned one
        label; the label text and color are taken from the first line of the group. If omitted, all lines from the
        given `Axes` instance will be targeted.
    labels: list of str, optional
        List of labels to assign to the lines. If omitted, the `label` properties from the `Line2D` objects will be
        used.
    text_positions: list of pair of float, optional
        Gives the ability to force one or more labels to be placed in specific positions. Given as a list of
        two-element `float` tuples that represent coordinates in the coordinate system of the figure data. Use `None`
        for entries whose label should be placed automatically.
    arrow_threshold: float, optional
        If specified, it is the minimum distance that a label object will have from the line in order for an arrow to
        be drawn from the label to the line. By default no arrows are shown. The distance is given in the scale of the
        figure data.
    textbox_margin: float
        Allows us to specify a margin around label objects to prevent them from colliding or being too close to other
        objects. The margin is given relatively to the current size of the label. For example, the margin `1.0` means
        there will be no extra margin around the text (default behavior). As another example, the margin 2.0 means
        that the total bounding box around the label text will be twice as large as the original text.
    resolution: int
        Controls the resolution of the label placement space given in pixels. A higher number will mean more
        precision in placement but also more time to compute the positions. Lower values will be faster but with more
        rigid placement.
    attraction: float
        Controls the relative strength of how much the target line will attract the label. Tweak this parameter to
        fine tune placement.
    repulsion: float
        Controls the relative strength of how much all other non-target lines will repel the label. Tweak this
        parameter to fine tune placement.
    sigma: float
        Controls how much we will smooth out the label positioning horizon. Tweak this parameter to fine tune
        placement.
    noise: float
        Controls the noise power to inject into the label positioning process. More noise will increase the
        probability that the object will be placed further from some optimal position based on the model of this
        method. On the other hand it can allow placing nodes in more convenient places. Tweak this parameter to fine
        tune placement. Default is 0.0 but for best results use values between 0.0 and 0.5.
    noise_seed: int
        If noise is added to the placement process, this is the random seed. Change the seed to change the placement
        outcome. The global NumPy random state is not modified.
    debug: bool
        If set to `True` then a debug figure will be shown with optimization heatmaps for every line object. Dark
        areas will show areas we were trying to avoid. Light areas will show areas where the label was likely to be
        placed. The red dot is the final placement.
    **kwargs
        Currently unused.

    Returns
    -------
    list of Text
        One label per target, in the order of ``lines``. Labels that received an arrow are the `Annotation` objects
        returned by ``axis.annotate`` (the original text is hidden and replaced).
    """
    # If no axis is specified, we simply take the current one being drawn.
    if axis is None:
        axis = plt.gca()

    # If no lines are specified as targets, we simply target all lines.
    if lines is None:
        lines = axis.lines

    # Make sure if labels and/or text positions are specified, that they are the same length as lines.
    if labels is not None:
        assert (len(labels) == len(lines))
    if text_positions is not None:
        assert (len(text_positions) == len(lines))

    num_lines = len(lines)

    # Each target is either a single Line2D or a list/tuple of Line2D objects sharing one label;
    # normalise every target to a list so all stages below treat groups the same way.
    line_groups = [list(item) if isinstance(item, (list, tuple)) else [item] for item in lines]

    xmin, xmax = axis.get_xlim()
    ymin, ymax = axis.get_ylim()

    # Labels given by the caller are used as-is; otherwise they are read from the lines.
    labels_specified = labels is not None
    if not labels_specified:
        labels = []

    colors = []
    texts = []
    texts_bb = []
    xc, yc = xmin + (xmax - xmin) / 2, ymin + (ymax - ymin) / 2
    for l in range(num_lines):
        line = line_groups[l][0]
        if not labels_specified:
            labels.append(line.get_label())
        label = labels[l]
        color = line.get_color()
        colors.append(color)

        fixed_position = None
        if text_positions is not None and text_positions[l] is not None:
            fixed_position = text_positions[l]
            assert (isinstance(fixed_position, tuple))
            assert (len(fixed_position) == 2)

        # Draw the text in the centre of the plot to measure its bounding box without any impact on the viewport.
        text = axis.text(xc, yc, label, color=color, horizontalalignment='center', verticalalignment='center')
        texts.append(text)
        text_bb = _get_text_bb(text, textbox_margin)
        # Store the box relative to the text centre so it can be translated to any candidate position.
        texts_bb.append(
            text_bb.translated(-text_bb.width / 2 - text_bb.x0, -text_bb.height / 2 - text_bb.y0))

        # Labels with a caller-specified position are moved there right away; the search below skips them.
        if fixed_position is not None:
            text.set_position(fixed_position)

    # Build the "points of presence" matrix with all that belong to certain lines.
    pop = np.zeros((num_lines, resolution, resolution), dtype=float)
    for l in range(num_lines):
        group = line_groups[l]
        # Bounding box of this label at its fixed position, if it has one.
        fixed_bb = None
        if text_positions is not None and text_positions[l] is not None:
            fixed_bb = texts_bb[l].translated(*text_positions[l])
        for x_i, y_i in itertools.product(range(resolution), range(resolution)):
            x_f, y_f = (np.array([x_i, y_i]) / resolution) * ([xmax - xmin, ymax - ymin]) + [xmin, ymin]
            text_bb_xy = texts_bb[l].translated(x_f, y_f)
            if text_bb_xy.x0 < xmin or text_bb_xy.x1 > xmax or text_bb_xy.y0 < ymin or text_bb_xy.y1 > ymax:
                pop[l, x_i, y_i] = 1.0
            elif any(part.get_path().intersects_bbox(text_bb_xy, filled=False) for part in group):
                pop[l, x_i, y_i] = 1.0

            # A fixed label occupies the cells where it would overlap itself, which keeps other labels away.
            if fixed_bb is not None and fixed_bb.overlaps(text_bb_xy):
                pop[l, x_i, y_i] = 1.0

    if debug:
        # squeeze=False keeps a 2-D array of Axes even when there is only one line.
        debug_f, debug_axes = plt.subplots(nrows=1, ncols=num_lines, squeeze=False)
        debug_ax = debug_axes[0]

    for l in range(num_lines):
        # If the position of this label has been provided in the input arguments, we can just skip it.
        if text_positions is not None and text_positions[l] is not None:
            continue

        # Find empty space, which is a nice place for labels.
        empty_space = 1.0 - (np.sum(pop, axis=0) > 0) * 1.0

        # blur the pop's
        pop_blurred = pop.copy()
        for ll in range(num_lines):
            pop_blurred[ll] = ndimage.gaussian_filter(pop[ll], sigma=sigma * resolution / 5)

        # Positive weights for current line, negative weight for others....
        w = -repulsion * np.ones(num_lines, dtype=float)
        w[l] = attraction

        # calculate a field
        p = empty_space + np.sum(w[:, np.newaxis, np.newaxis] * pop_blurred, axis=0)

        # Add noise to the field if specified. A private RandomState draws the same values that
        # np.random.seed(noise_seed) + np.random.normal would, without reseeding the global generator.
        if noise > 0.0:
            rng = np.random.RandomState(noise_seed)
            p += rng.normal(0.0, noise, p.shape)

        # argmax flattens the (resolution, resolution) field in C order, so divmod recovers integer grid indices.
        best_x, best_y = divmod(int(np.argmax(p)), resolution)
        x = xmin + (xmax - xmin) * best_x / resolution
        y = ymin + (ymax - ymin) * best_y / resolution

        if debug:
            im1 = debug_ax[l].imshow(p.T, interpolation='nearest', origin="lower")
            debug_ax[l].set_title("Heatmap for: " + texts[l].get_text())
            debug_ax[l].plot(best_x, best_y, 'ro')
            divider = make_axes_locatable(debug_ax[l])
            cax = divider.append_axes('right', size='5%', pad=0.05)
            debug_f.colorbar(im1, cax=cax, orientation='vertical')

        texts[l].set_position((x, y))

        # Prevent collision by blocking out the bounding box of this text box.
        text_bb_new = _get_text_bb(texts[l], textbox_margin)
        x_i_min, y_i_min = tuple(
            ((text_bb_new.min - [xmin, ymin]) / ([xmax - xmin, ymax - ymin]) * resolution).astype(int))
        x_i_max, y_i_max = tuple(
            ((text_bb_new.max - [xmin, ymin]) / ([xmax - xmin, ymax - ymin]) * resolution).astype(int))

        # Augment the barrier to prevent collision between labels.
        w_barrier = int(round((x_i_max - x_i_min) / 2))
        h_barrier = int(round((y_i_max - y_i_min) / 2))
        x_i_min = int(max(0, x_i_min - w_barrier))
        y_i_min = int(max(0, y_i_min - h_barrier))
        x_i_max = int(min(resolution - 1, x_i_max + w_barrier))
        y_i_max = int(min(resolution - 1, y_i_max + h_barrier))
        pop[l, x_i_min:x_i_max + 1, y_i_min:y_i_max + 1] = 1.0

    # If the arrow threshold has been specified, draw arrows where needed.
    if arrow_threshold is not None:
        for l in range(num_lines):
            # Get all points on the path (including some interpolated ones).
            points = np.vstack([part.get_path().interpolated(10).vertices for part in line_groups[l]])

            # Get the midpoint of the text box.
            text_c = np.array(_get_midpoint(_get_text_bb(texts[l], textbox_margin)))

            # Get all distances.
            distances = [np.linalg.norm(text_c - point) for point in points]
            d_min_idx = np.argmin(distances)

            # If the distance is larger than the threshold, draw the line.
            if distances[d_min_idx] > arrow_threshold:
                # Find the nearest point whose connecting segment crosses no other text box.
                d_sorted_idx = np.argsort(distances)
                xytext = texts[l].get_position()
                xy = points[d_min_idx, :]
                for idx in d_sorted_idx:
                    tmp_line = Path([xytext, points[idx, :]])
                    clear_of_other_labels = all(
                        not tmp_line.intersects_bbox(_get_text_bb(texts[i], textbox_margin))
                        for i in range(len(texts)) if i != l)
                    if clear_of_other_labels:
                        xy = points[idx, :]
                        break

                # Draw the new text with the arrow.
                a = axis.annotate(labels[l], xy=xy, xytext=xytext, ha="center", va="center", color=colors[l],
                                  arrowprops=dict(arrowstyle="->", color=colors[l]))
                # Hide original text.
                texts[l].set_visible(False)
                texts[l] = a

    # Remove the ugly legend.
    ugly_legend = axis.get_legend()
    if ugly_legend is not None:
        ugly_legend.remove()

    if debug:
        debug_f.show()

    # We return all the placed labels.
    return texts