def draw(self, renderer):
    """Project the arrow's 3-D end points and draw it as a 2-D FancyArrowPatch.

    If both projected extents (|dx| and |dy|) are below a fixed threshold,
    the arrow head is resized in proportion to the larger extent.  A
    projection of zero length is drawn as a plain line ("-").  Otherwise
    the style stored in ``self.init_style`` is restored.
    """
    xs3d, ys3d, zs3d = self._verts3d
    # Older Matplotlib releases attach the 3-D projection matrix to the
    # renderer; fall back to the owning axes' matrix when it is absent.
    # NOTE(review): confirm against the Matplotlib version in use.
    proj_matrix = getattr(renderer, "M", None)
    if proj_matrix is None:
        proj_matrix = self.axes.M
    # Depth is not needed for the 2-D layout below.
    xs, ys, _ = proj3d.proj_transform(xs3d, ys3d, zs3d, proj_matrix)
    FancyArrowPatch.set_positions(self, (xs[0], ys[0]), (xs[1], ys[1]))

    dx = np.abs(xs[0] - xs[1])
    dy = np.abs(ys[0] - ys[1])
    thresh = 0.4  # projected extent below which the head is resized
    if dx < thresh and dy < thresh:
        factor = max(dx, dy) * 10.
        if factor > 0:
            self.set_arrowstyle("->", head_length=factor,
                                head_width=factor / 2.)
        else:
            # Degenerate (zero-length) projection: draw without a head.
            self.set_arrowstyle("-")
    else:
        self.set_arrowstyle(self.init_style)
    FancyArrowPatch.draw(self, renderer)
class AnnotationBbox(martist.Artist, _AnnotationBase):
    """
    Annotation-like class, but with offsetbox instead of Text.
    """
    zorder = 3

    def __str__(self):
        return "AnnotationBbox(%g,%g)" % (self.xy[0], self.xy[1])

    @docstring.dedent_interpd
    def __init__(self, offsetbox, xy,
                 xybox=None,
                 xycoords='data',
                 boxcoords=None,
                 frameon=True, pad=0.4,  # BboxPatch
                 annotation_clip=None,
                 box_alignment=(0.5, 0.5),
                 bboxprops=None,
                 arrowprops=None,
                 fontsize=None,
                 **kwargs):
        """
        *offsetbox* : OffsetBox instance

        *xycoords* : same as Annotation but can be a tuple of two
           strings which are interpreted as x and y coordinates.

        *boxcoords* : similar to textcoords as Annotation but can be a
           tuple of two strings which are interpreted as x and y
           coordinates.

        *box_alignment* : a tuple of two floats giving the horizontal and
           vertical alignment of the offset box w.r.t. the *boxcoords*,
           as fractions of the box width and height.  The lower-left
           corner is (0, 0) and the upper-right corner is (1, 1).

        *arrowprops* : dict or None.  If given, a FancyArrowPatch built
           from these properties is drawn from the box to *xy*.  The keys
           "relpos" (anchor on the box as fractions of its width/height,
           default (0.5, 0.5)), "mutation_scale" (in points, default: the
           font size) and "patchA" (default: the frame patch) are also
           used.  The dict is copied; the caller's dict is not modified.

        other parameters are identical to that of Annotation.
        """
        self.offsetbox = offsetbox
        # Work on a private copy: "relpos" is popped below, and mutating
        # the caller's dict would break reuse across several annotations.
        self.arrowprops = dict(arrowprops) if arrowprops is not None else None

        self.set_fontsize(fontsize)

        if self.arrowprops is not None:
            self._arrow_relpos = self.arrowprops.pop("relpos", (0.5, 0.5))
            self.arrow_patch = FancyArrowPatch((0, 0), (1, 1),
                                               **self.arrowprops)
        else:
            self._arrow_relpos = None
            self.arrow_patch = None

        _AnnotationBase.__init__(self,
                                 xy, xytext=xybox,
                                 xycoords=xycoords, textcoords=boxcoords,
                                 annotation_clip=annotation_clip)

        martist.Artist.__init__(self, **kwargs)

        # (horizontal, vertical) fractions of the offset box placed at xybox
        self._box_alignment = box_alignment

        # frame
        self.patch = FancyBboxPatch(
            xy=(0.0, 0.0), width=1., height=1.,
            facecolor='w', edgecolor='k',
            mutation_scale=self.prop.get_size_in_points(),
            snap=True)
        self.patch.set_boxstyle("square", pad=pad)
        if bboxprops:
            self.patch.set(**bboxprops)
        self._drawFrame = frameon

    def contains(self, event):
        # Hit-testing is delegated to the offset box only; the arrow patch
        # is not checked because it can be a mere line.
        t, tinfo = self.offsetbox.contains(event)
        return t, tinfo

    def get_children(self):
        children = [self.offsetbox, self.patch]
        if self.arrow_patch:
            children.append(self.arrow_patch)
        return children

    def set_figure(self, fig):
        if self.arrow_patch is not None:
            self.arrow_patch.set_figure(fig)
        self.offsetbox.set_figure(fig)
        martist.Artist.set_figure(self, fig)

    def set_fontsize(self, s=None):
        """
        set fontsize in points (defaults to rcParams["legend.fontsize"])
        """
        if s is None:
            s = rcParams["legend.fontsize"]

        self.prop = FontProperties(size=s)

    def get_fontsize(self, s=None):
        """
        return fontsize in points
        """
        return self.prop.get_size_in_points()

    def update_positions(self, renderer):
        "Update the pixel positions of the annotated point and the text."
        xy_pixel = self._get_position_xy(renderer)
        self._update_position_xybox(renderer, xy_pixel)

        mutation_scale = renderer.points_to_pixels(self.get_fontsize())
        self.patch.set_mutation_scale(mutation_scale)
        # The arrow's mutation scale is set by _update_position_xybox, which
        # honours arrowprops["mutation_scale"]; overwriting it here with the
        # font-size value would discard the user's setting.

    def _update_position_xybox(self, renderer, xy_pixel):
        "Update the pixel positions of the annotation text and the arrow patch."
        x, y = self.xytext
        if isinstance(self.textcoords, tuple):
            xcoord, ycoord = self.textcoords
            x1, _ = self._get_xy(renderer, x, y, xcoord)
            _, y2 = self._get_xy(renderer, x, y, ycoord)
            ox0, oy0 = x1, y2
        else:
            ox0, oy0 = self._get_xy(renderer, x, y, self.textcoords)

        w, h, xd, yd = self.offsetbox.get_extent(renderer)
        fw, fh = self._box_alignment
        self.offsetbox.set_offset((ox0 - fw * w + xd, oy0 - fh * h + yd))

        # update patch position
        bbox = self.offsetbox.get_window_extent(renderer)
        self.patch.set_bounds(bbox.x0, bbox.y0, bbox.width, bbox.height)

        # Test the patch rather than the arrowprops dict: the dict may be
        # empty (e.g. it only held "relpos") while an arrow still exists.
        if self.arrow_patch is None:
            return
        props = self.arrowprops or {}

        # Start the arrow at the relpos anchor of the box.
        # TODO : Rotation needs to be accounted.
        relpos = self._arrow_relpos
        ax0 = bbox.x0 + bbox.width * relpos[0]
        ay0 = bbox.y0 + bbox.height * relpos[1]

        # The arrow is drawn from (ax0, ay0) to the annotated point.  It is
        # first clipped by patchA and patchB, then shrunk by shrinkA and
        # shrinkB (in points).  If patchA is not given, self.patch is used.
        self.arrow_patch.set_positions((ax0, ay0), tuple(xy_pixel))

        mutation_scale = props.get("mutation_scale",
                                   self.prop.get_size_in_points())
        self.arrow_patch.set_mutation_scale(
            renderer.points_to_pixels(mutation_scale))

        self.arrow_patch.set_patchA(props.get("patchA", self.patch))

    def draw(self, renderer):
        """
        Draw the annotation to the given *renderer*: the arrow first, then
        the frame patch (if enabled), then the offset box.
        """
        if renderer is not None:
            self._renderer = renderer
        if not self.get_visible():
            return

        xy_pixel = self._get_position_xy(renderer)
        if not self._check_xy(renderer, xy_pixel):
            return

        self.update_positions(renderer)

        if self.arrow_patch is not None:
            if self.arrow_patch.figure is None and self.figure is not None:
                self.arrow_patch.figure = self.figure
            self.arrow_patch.draw(renderer)

        if self._drawFrame:
            self.patch.draw(renderer)

        self.offsetbox.draw(renderer)
class AnchoredCompass(AnchoredOffsetbox):
    """Anchored box holding two labelled arrows (E and N by default).

    Arrow directions come from ``estimate_angle(transSky2Pix, x0, y0)`` at
    the lower-left corner of the current view limits, shifted by
    *delta_a1* / *delta_a2*.  They are recomputed on every draw.
    """

    def __init__(self, ax, transSky2Pix, loc,
                 arrow_fraction=0.15,
                 txt1="E", txt2="N",
                 delta_a1=0, delta_a2=0,
                 pad=0.1, borderpad=0.5,
                 prop=None,
                 frameon=False,
                 ):
        """
        Draw arrows pointing in the directions of E & N.

        arrow_fraction : length of the arrow as a fraction of the smaller
            view-limit extent of *ax*
        pad, borderpad : in fraction of the legend font size (or prop)
        """
        self._ax = ax
        self._transSky2Pix = transSky2Pix
        self._box = AuxTransformBox(ax.transData)
        self.delta_a1, self.delta_a2 = delta_a1, delta_a2
        self.arrow_fraction = arrow_fraction

        arrow_kw = dict(arrowstyle="->",
                        arrow_transmuter=None,
                        connectionstyle="arc3",
                        connector=None,
                        mutation_scale=11, shrinkA=0, shrinkB=5)
        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1), **arrow_kw)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1), **arrow_kw)

        # Placeholder positions; _update_arrow moves the labels on draw.
        self.txt1 = Text(1, 1, txt1, rotation=0,
                         rotation_mode="anchor",
                         va="center", ha="right")
        self.txt2 = Text(1, 1, txt2, rotation=0,
                         rotation_mode="anchor",
                         va="bottom", ha="center")

        for artist in (self.arrow1, self.arrow2, self.txt1, self.txt2):
            self._box.add_artist(artist)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    def set_path_effects(self, path_effects):
        """Apply *path_effects* to both arrows and both labels."""
        for artist in (self.arrow1, self.arrow2, self.txt1, self.txt2):
            artist.set_path_effects(path_effects)

    def _update_arrow(self, renderer):
        """Recompute arrow tips and label placement from the view limits."""
        view = self._ax.viewLim
        x0, y0 = view.x0, view.y0

        ang1, ang2 = estimate_angle(self._transSky2Pix, x0, y0)
        ang1 = ang1 + self.delta_a1
        ang2 = ang2 + self.delta_a2

        length = min(view.width, view.height) * self.arrow_fraction
        rad1 = ang1 / 180. * np.pi
        rad2 = ang2 / 180. * np.pi
        tip1 = (x0 + length * np.cos(rad1), y0 + length * np.sin(rad1))
        tip2 = (x0 + length * np.cos(rad2), y0 + length * np.sin(rad2))

        origin = (x0, y0)
        self.arrow1.set_positions(origin, tip1)
        self.arrow2.set_positions(origin, tip2)

        # Each label sits exactly at its arrow's tip, rotated with it.
        self.txt1.set_position(tip1)
        self.txt1.set_rotation(ang1 - 180)
        self.txt2.set_position(tip2)
        self.txt2.set_rotation(ang2 - 90)

    def draw(self, renderer):
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
class AnnotationBbox(martist.Artist, _AnnotationBase):
    """
    Annotation-like class, but with offsetbox instead of Text.
    """
    zorder = 3

    def __str__(self):
        return "AnnotationBbox(%g,%g)" % (self.xy[0], self.xy[1])

    @docstring.dedent_interpd
    def __init__(
            self, offsetbox, xy,
            xybox=None,
            xycoords='data',
            boxcoords=None,
            frameon=True, pad=0.4,  # BboxPatch
            annotation_clip=None,
            box_alignment=(0.5, 0.5),
            bboxprops=None,
            arrowprops=None,
            fontsize=None,
            **kwargs):
        """
        *offsetbox* : OffsetBox instance

        *xycoords* : same as Annotation but can be a tuple of two
           strings which are interpreted as x and y coordinates.

        *boxcoords* : similar to textcoords as Annotation but can be a
           tuple of two strings which are interpreted as x and y
           coordinates.

        *box_alignment* : a tuple of two floats giving the horizontal and
           vertical alignment of the offset box w.r.t. the *boxcoords*,
           as fractions of the box width and height.  The lower-left
           corner is (0, 0) and the upper-right corner is (1, 1).

        *arrowprops* : dict or None.  If given, a FancyArrowPatch built
           from these properties is drawn from the box to *xy*.  The keys
           "relpos" (anchor on the box as fractions of its width/height,
           default (0.5, 0.5)), "mutation_scale" (in points, default: the
           font size) and "patchA" (default: the frame patch) are also
           used.  The dict is copied; the caller's dict is not modified.

        other parameters are identical to that of Annotation.
        """
        self.offsetbox = offsetbox
        # Work on a private copy: "relpos" is popped below, and mutating
        # the caller's dict would break reuse across several annotations.
        self.arrowprops = dict(arrowprops) if arrowprops is not None else None

        self.set_fontsize(fontsize)

        if self.arrowprops is not None:
            self._arrow_relpos = self.arrowprops.pop("relpos", (0.5, 0.5))
            self.arrow_patch = FancyArrowPatch((0, 0), (1, 1),
                                               **self.arrowprops)
        else:
            self._arrow_relpos = None
            self.arrow_patch = None

        _AnnotationBase.__init__(self,
                                 xy, xytext=xybox,
                                 xycoords=xycoords, textcoords=boxcoords,
                                 annotation_clip=annotation_clip)

        martist.Artist.__init__(self, **kwargs)

        # (horizontal, vertical) fractions of the offset box placed at xybox
        self._box_alignment = box_alignment

        # frame
        self.patch = FancyBboxPatch(
            xy=(0.0, 0.0), width=1., height=1.,
            facecolor='w', edgecolor='k',
            mutation_scale=self.prop.get_size_in_points(),
            snap=True)
        self.patch.set_boxstyle("square", pad=pad)
        if bboxprops:
            self.patch.set(**bboxprops)
        self._drawFrame = frameon

    def contains(self, event):
        # Hit-testing is delegated to the offset box only; the arrow patch
        # is not checked because it can be a mere line.
        t, tinfo = self.offsetbox.contains(event)
        return t, tinfo

    def get_children(self):
        children = [self.offsetbox, self.patch]
        if self.arrow_patch:
            children.append(self.arrow_patch)
        return children

    def set_figure(self, fig):
        if self.arrow_patch is not None:
            self.arrow_patch.set_figure(fig)
        self.offsetbox.set_figure(fig)
        martist.Artist.set_figure(self, fig)

    def set_fontsize(self, s=None):
        """
        set fontsize in points (defaults to rcParams["legend.fontsize"])
        """
        if s is None:
            s = rcParams["legend.fontsize"]

        self.prop = FontProperties(size=s)

    def get_fontsize(self, s=None):
        """
        return fontsize in points
        """
        return self.prop.get_size_in_points()

    def update_positions(self, renderer):
        "Update the pixel positions of the annotated point and the text."
        xy_pixel = self._get_position_xy(renderer)
        self._update_position_xybox(renderer, xy_pixel)

        mutation_scale = renderer.points_to_pixels(self.get_fontsize())
        self.patch.set_mutation_scale(mutation_scale)
        # The arrow's mutation scale is set by _update_position_xybox, which
        # honours arrowprops["mutation_scale"]; overwriting it here with the
        # font-size value would discard the user's setting.

    def _update_position_xybox(self, renderer, xy_pixel):
        "Update the pixel positions of the annotation text and the arrow patch."
        x, y = self.xytext
        if isinstance(self.textcoords, tuple):
            xcoord, ycoord = self.textcoords
            x1, _ = self._get_xy(renderer, x, y, xcoord)
            _, y2 = self._get_xy(renderer, x, y, ycoord)
            ox0, oy0 = x1, y2
        else:
            ox0, oy0 = self._get_xy(renderer, x, y, self.textcoords)

        w, h, xd, yd = self.offsetbox.get_extent(renderer)
        fw, fh = self._box_alignment
        self.offsetbox.set_offset((ox0 - fw * w + xd, oy0 - fh * h + yd))

        # update patch position
        bbox = self.offsetbox.get_window_extent(renderer)
        self.patch.set_bounds(bbox.x0, bbox.y0, bbox.width, bbox.height)

        # Test the patch rather than the arrowprops dict: the dict may be
        # empty (e.g. it only held "relpos") while an arrow still exists.
        if self.arrow_patch is None:
            return
        props = self.arrowprops or {}

        # Start the arrow at the relpos anchor of the box.
        # TODO : Rotation needs to be accounted.
        relpos = self._arrow_relpos
        ax0 = bbox.x0 + bbox.width * relpos[0]
        ay0 = bbox.y0 + bbox.height * relpos[1]

        # The arrow is drawn from (ax0, ay0) to the annotated point.  It is
        # first clipped by patchA and patchB, then shrunk by shrinkA and
        # shrinkB (in points).  If patchA is not given, self.patch is used.
        self.arrow_patch.set_positions((ax0, ay0), tuple(xy_pixel))

        mutation_scale = props.get("mutation_scale",
                                   self.prop.get_size_in_points())
        self.arrow_patch.set_mutation_scale(
            renderer.points_to_pixels(mutation_scale))

        self.arrow_patch.set_patchA(props.get("patchA", self.patch))

    def draw(self, renderer):
        """
        Draw the annotation to the given *renderer*: the arrow first, then
        the frame patch (if enabled), then the offset box.
        """
        if renderer is not None:
            self._renderer = renderer
        if not self.get_visible():
            return

        xy_pixel = self._get_position_xy(renderer)
        if not self._check_xy(renderer, xy_pixel):
            return

        self.update_positions(renderer)

        if self.arrow_patch is not None:
            if self.arrow_patch.figure is None and self.figure is not None:
                self.arrow_patch.figure = self.figure
            self.arrow_patch.draw(renderer)

        if self._drawFrame:
            self.patch.draw(renderer)

        self.offsetbox.draw(renderer)
class AnchoredCompass(AnchoredOffsetbox):
    """
    AnchoredOffsetbox that draws a compass (E and N arrows) on the plot,
    given a value for *ori* (degrees E of N of the y axis).
    """

    def __init__(self, ax, ori, loc=4, arrow_fraction=0.15,
                 txt1="E", txt2="N", pad=0.3, borderpad=0.5, prop=None,
                 frameon=False, txt2_xoffset=20):
        """
        Parameters
        ----------
        ax : Axes
            Axes whose view limits set the arrow origin and length.
        ori : float
            Orientation, in degrees E of N of the y axis.
        loc : location code forwarded to AnchoredOffsetbox (default 4).
        arrow_fraction : float
            Arrow length as a fraction of the smaller view-limit extent.
        txt1, txt2 : str
            Labels placed at the tips of the first and second arrows.
        pad, borderpad, prop, frameon :
            Forwarded to AnchoredOffsetbox.
        txt2_xoffset : float
            Horizontal shift applied to the second label, in the same data
            units as the arrow end points.  The default of 20 reproduces
            the previously hard-coded shift; adapt it to the data range of
            *ax*.
        """
        self._ax = ax
        self.ori = ori
        self.txt2_xoffset = txt2_xoffset
        self._box = AuxTransformBox(ax.transData)
        self.arrow_fraction = arrow_fraction
        # White outline so the black arrows/labels stay legible on any image.
        path_effects = [PathEffects.withStroke(linewidth=3, foreground="w")]
        kwargs = dict(mutation_scale=14,
                      shrinkA=0,
                      shrinkB=7)
        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      arrowstyle="-|>",
                                      arrow_transmuter=None,
                                      connectionstyle="arc3",
                                      connector=None, color="k",
                                      path_effects=path_effects,
                                      **kwargs)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      arrowstyle="-|>",
                                      arrow_transmuter=None,
                                      connectionstyle="arc3",
                                      connector=None, color="k",
                                      path_effects=path_effects,
                                      **kwargs)
        # Placeholder positions; _update_arrow moves the labels on draw.
        self.txt1 = Text(1, 1, txt1, rotation=0,
                         rotation_mode="anchor", path_effects=path_effects,
                         va="center", ha="center")
        self.txt2 = Text(2, 2, txt2, rotation=0,
                         rotation_mode="anchor", path_effects=path_effects,
                         va="center", ha="center")
        self._box.add_artist(self.arrow1)
        self._box.add_artist(self.arrow2)
        self._box.add_artist(self.txt1)
        self._box.add_artist(self.txt2)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    def _update_arrow(self, renderer):
        """Recompute arrow and label positions from the current view limits."""
        ax = self._ax

        x0, y0 = ax.viewLim.x0, ax.viewLim.y0
        a1, a2 = 180 - self.ori, 180 - self.ori - 90
        d = min(ax.viewLim.width, ax.viewLim.height) * self.arrow_fraction
        r1 = a1 / 180. * np.pi
        r2 = a2 / 180. * np.pi
        x1, y1 = x0 + d * np.cos(r1), y0 + d * np.sin(r1)
        x2, y2 = x0 + d * np.cos(r2), y0 + d * np.sin(r2)
        self.arrow1.set_positions((x0, y0), (x1, y1))
        self.arrow2.set_positions((x0, y0), (x2, y2))

        # Labels sit at the arrow tips and are kept upright (rotation 0)
        # rather than rotated along the arrows.
        self.txt1.set_position((x1, y1))
        self.txt1.set_rotation(0)
        self.txt2.set_position((x2 + self.txt2_xoffset, y2))
        self.txt2.set_rotation(0)

    def draw(self, renderer):
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
class AnchoredCompass(AnchoredOffsetbox):
    """Anchored box holding two labelled arrows (E and N by default).

    Arrow directions come from ``estimate_angle_trans(transSky2Pix, x0, y0)``
    at the lower-left corner of the current view limits, shifted by
    *delta_a1* / *delta_a2*.  They are recomputed on every draw.
    """

    def __init__(self, ax, transSky2Pix, loc,
                 arrow_fraction=0.15,
                 txt1="E", txt2="N",
                 delta_a1=0, delta_a2=0,
                 pad=0.1, borderpad=0.5,
                 prop=None,
                 frameon=False,
                 color=None):
        """
        Draw arrows pointing in the directions of E & N.

        arrow_fraction : length of the arrow as a fraction of the smaller
            view-limit extent of *ax*
        pad, borderpad : in fraction of the legend font size (or prop)
        color : colour shared by both arrows and both labels
        prop : also used as the font properties of the labels
        """
        self._ax = ax
        self._transSky2Pix = transSky2Pix
        self._box = AuxTransformBox(ax.transData)
        self.delta_a1, self.delta_a2 = delta_a1, delta_a2
        self.arrow_fraction = arrow_fraction

        arrow_kw = dict(arrowstyle="->",
                        arrow_transmuter=None,
                        connectionstyle="arc3",
                        connector=None,
                        color=color,
                        mutation_scale=11, shrinkA=0, shrinkB=5)
        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1), **arrow_kw)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1), **arrow_kw)

        # Placeholder positions; _update_arrow moves the labels on draw.
        label_kw = dict(rotation=0, rotation_mode="anchor",
                        color=color, fontproperties=prop)
        self.txt1 = Text(1, 1, txt1, va="center", ha="right", **label_kw)
        self.txt2 = Text(1, 1, txt2, va="bottom", ha="center", **label_kw)

        for artist in (self.arrow1, self.arrow2, self.txt1, self.txt2):
            self._box.add_artist(artist)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    def set_path_effects(self, path_effects):
        """Apply *path_effects* to both arrows and both labels."""
        for artist in (self.arrow1, self.arrow2, self.txt1, self.txt2):
            artist.set_path_effects(path_effects)

    def _update_arrow(self, renderer):
        """Recompute arrow tips and label placement from the view limits."""
        view = self._ax.viewLim
        x0, y0 = view.x0, view.y0

        ang1, ang2 = estimate_angle_trans(self._transSky2Pix, x0, y0)
        ang1 = ang1 + self.delta_a1
        ang2 = ang2 + self.delta_a2

        length = min(view.width, view.height) * self.arrow_fraction
        rad1 = ang1 / 180. * np.pi
        rad2 = ang2 / 180. * np.pi
        tip1 = (x0 + length * np.cos(rad1), y0 + length * np.sin(rad1))
        tip2 = (x0 + length * np.cos(rad2), y0 + length * np.sin(rad2))

        origin = (x0, y0)
        self.arrow1.set_positions(origin, tip1)
        self.arrow2.set_positions(origin, tip2)

        # Each label sits exactly at its arrow's tip, rotated with it.
        self.txt1.set_position(tip1)
        self.txt1.set_rotation(ang1 - 180)
        self.txt2.set_position(tip2)
        self.txt2.set_rotation(ang2 - 90)

    def draw(self, renderer):
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
class CorrespondenceEditor(object):
    """Interactive matplotlib editor for point correspondences.

    ``matches`` holds two parallel lists, ``[sources, targets]``.  The i-th
    correspondence is drawn on *ax* as an arrow from ``sources[i]`` to
    ``targets[i]``.  Mouse and keyboard handlers (see ``key_handlers``)
    create, select, drag, nudge and delete arrows, with undo/redo.  A
    ``PolynomialTransform`` fitted to the current correspondences predicts
    the missing end point of a newly created arrow.
    """
    ax: plt.Axes
    canvas: plt.FigureCanvasBase

    def __init__(self, ax, on_commit=None, facecolor=None, edgecolor=None, pick_radius=15):
        """
        :param ax: Axes to draw the arrows on and to receive events from.
        :param on_commit: Optional callable, invoked as ``on_commit(self)``
            by :meth:`commit`.
        :param facecolor: Arrow face colour; defaults to opaque blue.
        :param edgecolor: Arrow edge colour; defaults to opaque blue.
        :param pick_radius: Maximum distance (in data coordinates) at which
            the mouse selects an arrow; also used as the ``pickradius`` of
            the marker lines.
        """
        self.arrowstyle = "-|>,head_width=2.5, head_length=5"
        self.facecolor = facecolor or (0, 0.5, 1.0, 1.0)
        self.edgecolor = edgecolor or (0, 0.5, 1.0, 1.0)
        self.pick_radius = pick_radius
        # Negative so that a positive scroll step shrinks the view (zooms in).
        self.scroll_speed = -0.5

        # Possible semantics:
        #   An integer > 1    -- fit a polynomial of that order
        #   A float in [0, 1] -- use that fraction of the number of matches as the order
        self.order = 0.5
        # Lazily fitted warps; reset to None whenever the matches change.
        self._warp = None
        self._inverse_warp = None

        self.on_finish = on_commit
        self.ax = ax
        self.canvas = ax.figure.canvas
        # Two parallel lists: matches[0] holds sources, matches[1] targets.
        self.matches = [[], []]
        self._active_point = -1  # 0 = source, 1 = target, -1 = none
        self._active_arrow = -1  # index into the matches, -1 = none
        self._arrow_patches = []
        self._arrow_tails = []  # Zero-length arrows are invisible, so markers are put at the tails

        self._arrow_highlight = FancyArrowPatch((0, 0), (0, 0), visible=False, arrowstyle=self.arrowstyle,
                                                color='y', linewidth=3)
        self._arrow_highlight_head = Line2D([0], [0], visible=False, marker='o', markersize=5,
                                            markerfacecolor='y', markeredgecolor='y', pickradius=self.pick_radius)
        self._arrow_highlight_tail = Line2D([0], [0], visible=False, marker='o', markersize=5,
                                            markerfacecolor='y', markeredgecolor='y', pickradius=self.pick_radius)
        self._point_highlight = Line2D([0], [0], visible=False, marker='o', markersize=5,
                                       markerfacecolor='r', markeredgecolor='r')
        self._arrow_highlight = self.ax.add_patch(self._arrow_highlight)
        self._arrow_highlight_head = self.ax.add_artist(self._arrow_highlight_head)
        self._arrow_highlight_tail = self.ax.add_artist(self._arrow_highlight_tail)
        self._point_highlight = self.ax.add_artist(self._point_highlight)

        self.cids = []
        self._history = []  # saved states; the last entry is the current state
        self._future = []  # states undone and available for redo
        self.save_state()  # Initialize history (to empty doc)
        self._event = None  # Most recent mouse/key event; supplies default positions
        self._mouse_down = None  # Set to the event that started a drag. None if not dragging.
        self._editing = 0  # Whether we are editing. Counts up and down to allow compound edits.

        self.key_handlers = {
            ' ': self.commit,
            's': partial(self.set_active_point, 0),
            't': partial(self.set_active_point, 1),
            'n': partial(self.new_arrow, selected=True, drag_target=False),
            'delete': self.delete_arrow,
            'e': self.delete_arrow,
            'left': partial(self.nudge, dx=-1, dy=0),
            'right': partial(self.nudge, dx=1, dy=0),
            'up': partial(self.nudge, dx=0, dy=-1),
            'down': partial(self.nudge, dx=0, dy=1),
            'ctrl+z': self.undo,
            'ctrl+Z': self.redo,
            'enter': self.commit,
        }

        # Create the arrows
        self.refresh_arrows()
        self.connect_events()

    def __len__(self):
        return len(self.matches[0])

    def __getitem__(self, i):
        return self.matches[0][i], self.matches[1][i]

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def begin_editing(self):
        """Call before changing the data"""
        self._editing += 1

    def finish_editing(self):
        """Call after you are done modifying the data.

        Saves to history after all edits are complete.
        """
        assert self._editing > 0
        self._editing -= 1
        if self._editing == 0:
            self._warp = None  # We changed the matches --> the warp is dirty
            self._inverse_warp = None  # We changed the matches --> the warp is dirty
            self.save_state()
            self.refresh_highlights()
            self.canvas.draw_idle()

    def get_state(self):
        return self.matches

    def set_state(self, state):
        """Replace the matches with *state*, a ``(sources, targets)`` pair, and redraw."""
        s, t = state
        self.matches[0][:] = s
        self.matches[1][:] = t
        # The fitted warps depend on the matches, so they are now stale.
        self._warp = None
        self._inverse_warp = None
        self.refresh_arrows()

    def save_state(self):
        """Push a snapshot of the current state unless it equals the last one."""
        if self._history and self._history[-1] == self.get_state():
            return
        self._history.append(deepcopy(self.get_state()))
        del self._future[:]

    def undo(self):
        """Restore the previously saved state, if any.

        After every completed edit ``save_state`` pushes the current state,
        so the top of ``_history`` is normally the state we are already in;
        it is moved to ``_future`` and the entry below it is restored.
        """
        current = deepcopy(self.get_state())
        if self._history and self._history[-1] == current:
            self._history.pop()
        if not self._history:
            # Nothing earlier to go back to; keep the current state recorded.
            self._history.append(current)
            return
        self._future.append(current)
        self.set_state(deepcopy(self._history[-1]))

    def redo(self):
        """Re-apply the most recently undone state, if any."""
        if not self._future:
            return
        state = self._future.pop()
        if not self._history or self._history[-1] != state:
            self._history.append(deepcopy(state))
        self.set_state(deepcopy(state))

    def connect(self, event, handler):
        self.cids.append(self.canvas.mpl_connect(event, handler))

    def connect_events(self):
        self.connect('motion_notify_event', self._on_motion)
        self.connect('button_press_event', self._on_button_press)
        self.connect('button_release_event', self._on_button_release)
        self.connect('key_press_event', self._on_key_press)
        self.connect('scroll_event', self._on_scroll)

    def disconnect_events(self):
        """Disconnect all event handlers"""
        for cid in self.cids:
            self.canvas.mpl_disconnect(cid)
        self.cids = []

    def ignore(self, event):
        return event.inaxes != self.ax

    def zoom(self, amount=1, xy=None):
        """Scale the view limits by ``2 ** amount`` about *xy* (default: view centre)."""
        amount = 2 ** amount
        xmin, xmax = self.ax.get_xlim()
        ymin, ymax = self.ax.get_ylim()
        if xy is None:
            x = (xmin + xmax) / 2.
            y = (ymin + ymax) / 2.
        else:
            x, y = xy
        xmin = x + (xmin - x) * amount
        ymin = y + (ymin - y) * amount
        xmax = x + (xmax - x) * amount
        ymax = y + (ymax - y) * amount
        self.ax.set_xbound(xmin, xmax)
        self.ax.set_ybound(ymin, ymax)

    def _on_scroll(self, event: MouseEvent):
        # Outside the axes xdata/ydata are None, which zoom() cannot use.
        if self.ignore(event):
            return
        self.zoom(event.step * self.scroll_speed, (event.xdata, event.ydata))
        self.canvas.draw_idle()

    def _on_motion(self, event: MouseEvent):
        if self.ignore(event):
            return
        if self._mouse_down:
            self.set_point((event.xdata, event.ydata))
        else:
            self.select_arrow(event)

    def _distance_to_line_segment(self, p0, p1, xy, epsilon=1e-4):
        """Return the Euclidean distance from *xy* to the segment p0-p1."""
        p0 = np.array(p0)
        p1 = np.array(p1)
        xy = np.array(xy)
        v = p1 - p0
        v2 = v @ v
        if v2 > epsilon:
            q = (v @ (xy - p0)) / v2  # Distance along the edge
            q = np.clip(q, 0, 1)  # Make sure we are on the line segment
        else:
            q = 0  # Degenerate segment: measure to p0
        p = p0 + v * q  # Closest point on the line segment
        return np.linalg.norm(xy - p)

    def select_arrow(self, event: MouseEvent):
        """Activate the arrow nearest the mouse, and its nearer end point.

        Only arrows closer than ``pick_radius`` qualify; if none does, the
        active arrow is cleared (set to -1).
        """
        min_distance = self.pick_radius
        argmin_distance = -1
        p = np.array((event.xdata, event.ydata))
        for i, (s, t) in enumerate(self):
            d = self._distance_to_line_segment(s, t, p)
            if d < min_distance:
                min_distance = d
                argmin_distance = i
        self.set_active_arrow(argmin_distance)
        if argmin_distance >= 0:
            s, t = self[argmin_distance]
            if norm(p - s) < norm(p - t):
                self.set_active_point(0)
            else:
                self.set_active_point(1)

    def _start_drag(self, event: matplotlib.backend_bases.MouseEvent):
        self._mouse_down = event
        self.begin_editing()
        self.set_point((event.xdata, event.ydata))

    def _on_button_press(self, event: matplotlib.backend_bases.MouseEvent):
        if self.ignore(event):
            return
        self._event = event
        if event.button != 1:
            return
        if self.has_active_arrow():
            # There are a couple of ways we can miss a mouse up event...
            if not self._mouse_down:
                self._start_drag(event)
        else:
            self.new_arrow(s=(event.xdata, event.ydata),
                           t=(event.xdata, event.ydata),
                           selected=True, drag_target=True)

    def _on_button_release(self, event: MouseEvent):
        if self.ignore(event):
            return
        if self._mouse_down:
            self._mouse_down = None
            self.finish_editing()

    def _on_key_press(self, event: KeyEvent):
        if self.ignore(event):
            return
        self._event = event
        if event.key in self.key_handlers:
            self.key_handlers[event.key]()

    def refresh_arrows(self):
        """Rebuild every arrow artist from ``matches`` and schedule a redraw."""
        # Remove the old arrows
        while self._arrow_patches:
            self._arrow_patches.pop().remove()
        # Don't forget the point marker at the tail of every arrow
        while self._arrow_tails:
            self._arrow_tails.pop().remove()
        # Add the new ones
        for s, t in zip(*self.matches):
            self._make_arrow_patch(s, t)
        # Refresh the highlights
        self.refresh_highlights()
        # Schedule a redraw
        self.canvas.draw_idle()

    def _set_highlight_arrow(self, s=None, t=None, visible=True):
        if s is not None and t is not None:
            self._arrow_highlight.set_positions(s, t)
            # Line2D.set_data expects sequences, not scalars.
            self._arrow_highlight_tail.set_data([s[0]], [s[1]])
            self._arrow_highlight_head.set_data([t[0]], [t[1]])
        self._arrow_highlight.set_visible(visible)
        self._arrow_highlight_head.set_visible(visible)
        self._arrow_highlight_tail.set_visible(visible)

    def refresh_highlights(self):
        """Sync the highlight artists with the active arrow and point."""
        if self.has_active_arrow():
            self._set_highlight_arrow(self.matches[0][self._active_arrow],
                                      self.matches[1][self._active_arrow],
                                      visible=True)
        else:
            self._set_highlight_arrow(visible=False)

        if self.has_active_arrow() and self.has_active_point():
            x, y = self.matches[self._active_point][self._active_arrow]
            self._point_highlight.set_data([x], [y])
            self._point_highlight.set_visible(True)
        else:
            self._point_highlight.set_visible(False)
        self.canvas.draw_idle()

    def commit(self, _unused=None):
        if self.on_finish:
            self.on_finish(self)

    def get_active_arrow(self):
        return self._active_arrow

    def set_active_arrow(self, i):
        if i != self._active_arrow:
            self._active_arrow = i
            self.refresh_highlights()

    def has_active_arrow(self):
        return 0 <= self._active_arrow < len(self)

    def get_active_point(self):
        return self._active_point

    def has_active_point(self):
        return 0 <= self._active_point < 2

    def set_active_point(self, i):
        if i != self._active_point:
            self._active_point = i
            self.refresh_highlights()

    def _make_arrow_patch(self, s, t):
        """Add an arrow patch from *s* to *t* plus a marker at its tail."""
        a = FancyArrowPatch(s, t, arrowstyle=self.arrowstyle, facecolor=self.facecolor, edgecolor=self.edgecolor)
        a = self.ax.add_patch(a)
        self._arrow_patches.append(a)
        tail = Line2D([s[0]], [s[1]], marker='o', markersize=2.5,
                      markeredgecolor=self.edgecolor,
                      markerfacecolor=self.facecolor,
                      pickradius=self.pick_radius)
        tail = self.ax.add_artist(tail)
        self._arrow_tails.append(tail)

    @mutator
    def new_arrow(self, s=None, t=None, selected=True, drag_target=False):
        """Append a correspondence from *s* to *t*.

        A missing end is predicted from the other one via the fitted warp;
        if both are missing, *s* is the position of the most recent event.
        """
        if s is None:
            if t is None:
                s = self._event.xdata, self._event.ydata
            else:
                s = self.predict_source(t)
        if t is None:
            # s is always known at this point.
            t = self.predict_target(s)

        # Add the new source and target
        self.matches[0].append(s)
        self.matches[1].append(t)

        # Add a patch for the new arrow
        self._make_arrow_patch(s, t)

        # Users expect the new arrow to be selected
        if selected:
            self.set_active_arrow(len(self) - 1)
            self.set_active_point(1)  # Presumably we clicked on the tail and will click on the head next

        if drag_target:
            self._start_drag(self._event)

    # noinspection PyUnusedLocal
    @mutator
    def set_point(self, xy, point=None, arrow=None):
        """Move end *point* (0 = source, 1 = target) of *arrow* to *xy*.

        Both default to the active ones; does nothing if *arrow* is invalid.
        """
        if arrow is None:
            arrow = self._active_arrow
        if point is None:
            point = self._active_point
        if not 0 <= arrow < len(self):
            return  # No arrows to select yet

        self.matches[point][arrow] = xy

        # Update the plot elements for the arrow
        patch = self._arrow_patches[arrow]
        patch.set_positions(self.matches[0][arrow], self.matches[1][arrow])
        # And also move the tail
        sx, sy = self.matches[0][arrow]
        self._arrow_tails[arrow].set_data([sx], [sy])

        # Update the highlights if we are moving the selected item
        if arrow == self._active_arrow:
            self.refresh_highlights()

    def ensure_selected_arrow(self):
        """Make sure a valid arrow is active, creating one if there are none."""
        if len(self) == 0:
            self.new_arrow()
        # Also recovers from a stale index, e.g. after deleting the last arrow.
        if not self.has_active_arrow():
            self.set_active_arrow(len(self) - 1)

    def ensure_selected_point(self):
        """Make sure a valid arrow and end point are active (default: the target)."""
        self.ensure_selected_arrow()
        if not self.has_active_point():
            self.set_active_point(1)

    @mutator
    def nudge(self, dx=0, dy=0):
        self.ensure_selected_point()
        x, y = self.matches[self._active_point][self._active_arrow]
        self.set_point((x + dx, y + dy))

    @mutator
    def delete_arrow(self, arrow=None):
        """Remove *arrow* (default: the active one) and its artists."""
        if arrow is None:
            self.ensure_selected_arrow()
            arrow = self._active_arrow

        del self.matches[0][arrow]
        del self.matches[1][arrow]

        # Remove the patch from the plot
        self._arrow_patches.pop(arrow).remove()
        self._arrow_tails.pop(arrow).remove()

        # Update the active arrow index (it might have shifted).
        # If the active arrow itself was deleted, the same index now refers
        # to the next arrow; otherwise the selection stays the same arrow.
        if self._active_arrow > arrow:
            self._active_arrow -= 1

    def get_sources(self):
        return self.matches[0]

    def get_targets(self):
        return self.matches[1]

    def get_transform_order(self):
        if self.order > 1:
            order = min(self.order, len(self))
        else:
            order = round(self.order * len(self))
        return order

    def get_warp(self, recompute=False):
        """Return the (cached) polynomial warp mapping sources to targets."""
        if recompute:
            self._warp = None
        if self._warp is None:
            self._warp = PolynomialTransform()
            self._warp.estimate(np.array(self.get_sources()),
                                np.array(self.get_targets()),
                                self.get_transform_order())
        return self._warp

    def get_inverse_warp(self, recompute=False):
        """Return the (cached) polynomial warp mapping targets to sources."""
        if recompute:
            self._inverse_warp = None
        if self._inverse_warp is None:
            self._inverse_warp = PolynomialTransform()
            self._inverse_warp.estimate(np.array(self.get_targets()),
                                        np.array(self.get_sources()),
                                        self.get_transform_order())
        return self._inverse_warp

    def predict_target(self, s):
        """Predict the target matching source *s* via the forward warp."""
        if len(self) == 0:
            # No correspondences to fit a warp to; use the identity map.
            return tuple(s)
        t = self.get_warp()(np.array([s]))[0]
        return tuple(t)

    def predict_source(self, t):
        """Predict the source matching target *t* via the inverse warp."""
        if len(self) == 0:
            # No correspondences to fit a warp to; use the identity map.
            return tuple(t)
        s = self.get_inverse_warp()(np.array([t]))[0]
        return tuple(s)
class CartPole:

    def __init__(self):

        # Global time of the simulation
        self.time = 0.0

        # CartPole is initialized with state, control input, target position all zero
        # This is however usually changed before running the simulation. Treat it just as placeholders.
        # Container for the augmented state (angle, position and their first and second derivatives) of the cart
        self.s = s0  # (s like state)
        # Variables for control input and target position.
        self.u = 0.0  # Physical force acting on the cart
        self.Q = 0.0  # Dimensionless motor power in the range [-1,1] from which force is calculated with Q2u() method
        self.target_position = 0.0

        # Physical parameters of the CartPole
        self.p = p_globals

        # region Time scales for simulation step, controller update and saving data
        # See last paragraph of "Time scales" section for explanations
        # ∆t in number of steps (related to simulation time step)
        # is set while setting corresponding dt through @property
        self.dt_controller_number_of_steps = 0
        self.dt_save_number_of_steps = 0

        # Counts time steps from last controller update or saving
        # is set while setting corresponding dt through @property
        self.dt_controller_steps_counter = 0
        self.dt_save_steps_counter = 0

        # Helper variables to set timescales
        self._dt_simulation = None
        self._dt_controller = None
        self._dt_save = None

        self.dt_simulation = None  # s, Update CartPole dynamical state every dt_simulation seconds
        self.dt_controller = None  # s, Recalculate control input every dt_controller seconds
        self.dt_save = None  # s, Save CartPole state every dt_save seconds
        # endregion

        # region Variables controlling operation of the program - can be modified directly from CartPole environment
        self.rounding_decimals = 5  # Number of digits after the decimal point saved in the experiment history
        self.save_data_in_cart = True  # Decides whether to store whole data of the experiment in dict_history or not
        self.stop_at_90 = False  # If true pole is blocked after reaching the horizontal position
        # endregion

        # region Variables controlling operation of the program - should not be modified directly
        self.save_flag = False  # Signals that the current time step should be saved
        self.csv_filepath = None  # Where to save the experiment history.
        self.controller = None  # Placeholder for the currently used controller function
        self.controller_name = ''  # Placeholder for the currently used controller name
        self.controller_idx = None  # Placeholder for the currently used controller index
        self.controller_names = self.get_available_controller_names()  # controllers available in controllers folder
        # endregion

        # region Variables for generating experiments with random target trace
        # Parameters for random trace generation
        # These need to be set, before CartPole can generate random trace and random experiment
        self.track_relative_complexity = None  # randomly placed target points/s
        self.length_of_experiment = None  # seconds, length of the random length trace
        self.interpolation_type = None  # Sets how to interpolate between turning points of random trace
        # Possible choices: '0-derivative-smooth', 'linear', 'previous'
        # '0-derivative-smooth'
        #       -> turning points are connected with smooth interpolation curve having derivative = 0 at each turning p.
        # 'linear' -> turning points are connected with line segments
        # 'previous' -> between two turning points the value of the preceding point is kept constant.
        #               In this last setting endpoint if set has no visible effect
        #               (may however appear in the last line of the recording - TODO: not checked)
        self.turning_points_period = None  # How turning points should be distributed
        # Possible choices: 'regular', 'random'
        # Regular means that they are equidistant from each other
        # Random means we pick randomly points at time axis at which we place turning points
        # Where the target position of the random experiment starts and ends:
        self.start_random_target_position_at = None
        self.end_random_target_position_at = None
        # Alternatively you can provide a list of target positions.
        # e.g. self.turning_points = [10.0, 0.0, 0.0]
        # If not None this variable has precedence -
        # track_relative_complexity, start/end_random_target_position_at have no effect.
        self.turning_points = None

        self.random_track_f = None  # Function interpolating the random target position between turning points
        self.new_track_generated = False  # Flag informing that a new target position track is generated
        self.t_max_pre = None  # Placeholder for the end time of the generated random experiment
        self.number_of_timesteps_in_random_experiment = None
        self.use_pregenerated_target_position = False  # Informs method performing experiment
        #                                                not to take target position from environment
        # endregion

        self.dict_history = {}  # Dictionary holding the experiment history

        # region Variables initialization for drawing/animating a CartPole
        # DIMENSIONS OF THE DRAWING ONLY!!!
        # NOTHING TO DO WITH THE SIMULATION AND NOT INTENDED TO BE MANIPULATED BY USER !!!

        # Variable relevant for interactive use of slider
        self.slider_max = 0.0
        self.slider_value = 0.0

        # Parameters needed to display CartPole in GUI
        # They are assigned with values in self.init_graphical_elements()
        self.CartLength = None
        self.WheelRadius = None
        self.WheelToMiddle = None
        self.y_plane = None
        self.y_wheel = None
        self.MastHight = None  # For drawing only. For calculation see L
        self.MastThickness = None
        self.HalfLength = None  # Length of the track

        # Elements of the drawing
        self.Mast = None
        self.Chassis = None
        self.WheelLeft = None
        self.WheelRight = None

        # Arrow indicating acceleration (=motor power)
        self.Acceleration_Arrow = None

        self.y_acceleration_arrow = None
        self.scaling_dx_acceleration_arrow = None
        self.x_acceleration_arrow = None

        # Depending on mode, slider may be displayed either as bar or as an arrow
        self.Slider_Bar = None
        self.Slider_Arrow = None

        self.t2 = None  # An abstract container for the transform rotating the mast

        self.init_graphical_elements()  # Assign proper object to the above variables
        # endregion

        # region Initialize CartPole in manual-stabilization mode
        self.set_controller('manual-stabilization')
        # endregion

    def _history_snapshot(self):
        """Return the current values of all recorded features, keyed by history column name.

        The key order defines the column order of the saved .csv file.
        """
        return {
            'time': self.time,
            's.position': self.s.position,
            's.positionD': self.s.positionD,
            's.positionDD': self.s.positionDD,
            's.angle': self.s.angle,
            's.angleD': self.s.angleD,
            's.angleDD': self.s.angleDD,
            'u': self.u,
            'Q': self.Q,
            # The target_position is not always meaningful
            # If it is not meaningful all values in this column are set to 0
            'target_position': self.target_position,
            's.angle.sin': np.sin(self.s.angle),
            's.angle.cos': np.cos(self.s.angle),
        }

    # region 1. Methods related to dynamic evolution of CartPole system

    # @profile(precision=4)
    def update_state(self):
        """Advance the CartPole from time t to t + dt_simulation.

        Updates the target position, integrates the state, applies the optional
        +/-90 deg stop and the elastic track border, updates the control input
        and (every dt_save_number_of_steps steps) records the state into
        dict_history and sets save_flag.
        """
        # Update the total time of the simulation
        self.time = self.time + self.dt_simulation

        # Update target position depending on the mode of operation
        if self.use_pregenerated_target_position:

            # If time exceeds the max time for which target position was defined
            if self.time >= self.t_max_pre:
                return

            self.target_position = self.random_track_f(self.time)
            self.slider_value = self.target_position  # Assign target position to slider to display it
        else:
            if self.controller_name == 'manual-stabilization':
                self.target_position = 0.0  # In this case target position is not used.
                # This just fills the corresponding column in history with zeros
            else:
                self.target_position = self.slider_value  # Get target position from slider

        # Calculate the next state
        self.cartpole_integration()

        # Snippet to stop pole at +/- 90 deg if enabled
        zero_DD = None
        if self.stop_at_90:
            if self.s.angle >= np.pi / 2:
                self.s.angle = np.pi / 2
                self.s.angleD = 0.0
                zero_DD = True  # Zero the angular acceleration after it is calculated
            elif self.s.angle <= -np.pi / 2:
                self.s.angle = -np.pi / 2
                self.s.angleD = 0.0
                zero_DD = True  # Zero the angular acceleration after it is calculated
            else:
                zero_DD = False

        # Wrap angle to +/-π
        self.s.angle = wrap_angle_rad(self.s.angle)

        # In case in the next step the wheel of the cart
        # went beyond the track
        # Bump elastically into an (invisible) border
        if (abs(self.s.position) + self.WheelToMiddle) > self.HalfLength:
            self.s.positionD = -self.s.positionD

        # Determine the dimensionless [-1,1] value of the motor power Q
        self.Update_Q()

        # Convert dimensionless motor power to a physical force acting on the Cart
        self.u = Q2u(self.Q, self.p)

        # Update second derivatives
        self.s.angleDD, self.s.positionDD = cartpole_ode(self.p, self.s, self.u)

        if zero_DD:
            # NOTE(review): only the angular acceleration is zeroed; positionDD is kept.
            self.s.angleDD = 0.0

        # Count time steps since the last save
        # (the counter starts at 0 - the 0th step is saved manually)
        self.dt_save_steps_counter += 1

        # If update time interval elapsed save current state and zero the counter
        if self.dt_save_steps_counter == self.dt_save_number_of_steps:
            snapshot = self._history_snapshot()
            # If user chose to save history of the simulation it is saved now
            # It is saved first internally to a dictionary in the Cart instance
            if self.save_data_in_cart:
                # Append to the whole experiment history
                for key, value in snapshot.items():
                    self.dict_history[key].append(value)
            else:
                # Keep only the current step (it is written out online)
                self.dict_history = {key: [value] for key, value in snapshot.items()}

            self.save_flag = True
            self.dt_save_steps_counter = 0

    def cartpole_integration(self):
        """Integrate the CartPole state in place over one dt_simulation step.

        Uses simple single-step (explicit) Euler integration of self.s:
        position/angle are advanced with the current first derivatives,
        first derivatives with the current second derivatives.
        """
        self.s.position = self.s.position + self.s.positionD * self.dt_simulation
        self.s.positionD = self.s.positionD + self.s.positionDD * self.dt_simulation

        self.s.angle = self.s.angle + self.s.angleD * self.dt_simulation
        self.s.angleD = self.s.angleD + self.s.angleDD * self.dt_simulation

    def Update_Q(self):
        """Update the dimensionless [-1,1] motor power Q every dt_controller_number_of_steps calls.

        In 'manual-stabilization' mode Q is the slider value; otherwise it is
        computed by the loaded controller's step(). This function should be
        called for the first time to calculate the 0th time step, otherwise
        it goes out of sync with saving.
        """
        # Calculate time steps from last update
        # The counter should be initialized at max-1 to start with a control input update
        self.dt_controller_steps_counter += 1

        # If update time interval elapsed update control input and zero the counter
        if self.dt_controller_steps_counter == self.dt_controller_number_of_steps:

            if self.controller_name == 'manual-stabilization':
                # in this case slider corresponds already to the power of the motor
                self.Q = self.slider_value
            else:  # in this case slider gives a target position, the controller computes Q
                self.Q = self.controller.step(self.s, self.target_position, self.time)

            self.dt_controller_steps_counter = 0

    # endregion

    # region 2. Methods related to experiment history as a whole: saving, loading, plotting, resetting

    def save_history_csv(self, csv_name=None, mode='init', length_of_experiment='unknown'):
        """Save the experiment history (dict_history) to a .csv file.

        :param csv_name: file name inside PATH_TO_EXPERIMENT_RECORDINGS ('init' mode only);
            if None or '' a time-stamped name is generated. An existing file is never
            overwritten - an index is appended to the name instead.
        :param mode: 'init' creates the file and writes the commented header and the column
            names; 'save online' / 'save offline' round dict_history to rounding_decimals
            and append its rows to the file.
        :param length_of_experiment: written into the header ('init' mode only).
        """
        if mode == 'init':
            # Make folder to save data (if not yet existing)
            try:
                os.makedirs(PATH_TO_EXPERIMENT_RECORDINGS[:-1])
            except FileExistsError:
                pass

            # Set path where to save the data
            if csv_name is None or csv_name == '':
                self.csv_filepath = PATH_TO_EXPERIMENT_RECORDINGS + 'CP_' + self.controller_name + str(
                    datetime.now().strftime('_%Y-%m-%d_%H-%M-%S')) + '.csv'
            else:
                self.csv_filepath = PATH_TO_EXPERIMENT_RECORDINGS + csv_name
                if csv_name[-4:] != '.csv':
                    self.csv_filepath += '.csv'

            # If such file exists, append index to the end (do not overwrite)
            net_index = 1
            logpath_new = self.csv_filepath
            while True:
                if os.path.isfile(logpath_new):
                    logpath_new = self.csv_filepath[:-4]
                else:
                    self.csv_filepath = logpath_new
                    break
                logpath_new = logpath_new + '-' + str(net_index) + '.csv'
                net_index += 1

            # Write the .csv file header
            # newline='' as required by the csv module (avoids blank lines on Windows)
            with open(self.csv_filepath, "a", newline='') as outfile:
                writer = csv.writer(outfile)

                now = datetime.now()  # single timestamp so date and time are consistent
                writer.writerow(['# ' + 'This is CartPole experiment from {} at time {}'
                                .format(now.strftime('%d.%m.%Y'), now.strftime('%H:%M:%S'))])
                try:
                    repo = Repo()
                    git_revision = repo.head.object.hexsha
                except Exception:  # not a git repo / git unavailable: best effort only
                    git_revision = 'unknown'
                writer.writerow(['# ' + 'Done with git-revision: {}'.format(git_revision)])

                writer.writerow(['#'])
                writer.writerow(['# Length of experiment: {} s'.format(str(length_of_experiment))])

                writer.writerow(['#'])
                writer.writerow(['# Time intervals dt:'])
                writer.writerow(['# Simulation: {} s'.format(str(self.dt_simulation))])
                writer.writerow(['# Controller update: {} s'.format(str(self.dt_controller))])
                writer.writerow(['# Saving: {} s'.format(str(self.dt_save))])

                writer.writerow(['#'])
                writer.writerow(['# Controller: {}'.format(self.controller_name)])

                writer.writerow(['#'])
                writer.writerow(['# Parameters:'])
                for k in self.p.__dict__:
                    writer.writerow(['# ' + k + ': ' + str(getattr(self.p, k))])
                writer.writerow(['#'])

                writer.writerow(['# Data:'])
                writer.writerow(self.dict_history.keys())

        elif mode in ('save online', 'save offline'):
            # Round data to a set precision and append it to the file
            with open(self.csv_filepath, "a", newline='') as outfile:
                writer = csv.writer(outfile)
                self.dict_history = {key: np.around(value, self.rounding_decimals)
                                     for key, value in self.dict_history.items()}
                writer.writerows(zip(*self.dict_history.values()))
            self.save_now = False
            # Another possibility to save data.
            # DF_history = pd.DataFrame.from_dict(self.dict_history).round(self.rounding_decimals)
            # DF_history.to_csv(self.csv_filepath, index=False, header=False, mode='a')  # Mode (a)ppend

    def load_history_csv(self, csv_name=None):
        """Load a csv file with an experiment recording (e.g. for replay).

        :param csv_name: name or path of the recording; '.csv' is appended if missing.
            It is looked up first as given, then inside PATH_TO_EXPERIMENT_RECORDINGS.
            If None or '', the most recently created recording in
            PATH_TO_EXPERIMENT_RECORDINGS is loaded.
        :return: pandas DataFrame (lines starting with '#' are skipped as comments),
            or False if no file was found or it could not be read.
        """
        # Resolve the path of the file to load
        if csv_name is None or csv_name == '':
            # get the latest file
            list_of_files = glob.glob(PATH_TO_EXPERIMENT_RECORDINGS + '/*.csv')
            try:
                file_path = max(list_of_files, key=os.path.getctime)
            except (ValueError, FileNotFoundError):
                # ValueError: max() of an empty list - no recordings at all;
                # FileNotFoundError: a file vanished between glob() and getctime()
                print('Cannot load: No experiment recording found in data folder ' + './data/')
                return False
        else:
            filename = csv_name
            if csv_name[-4:] != '.csv':
                filename += '.csv'

            # check if file found at local starting point or in the recordings folder
            file_path = filename
            if not os.path.isfile(filename):
                file_path = os.path.join(PATH_TO_EXPERIMENT_RECORDINGS, filename)
                if not os.path.isfile(file_path):
                    print('Cannot load: There is no experiment recording file with name {} at local folder or in {}'
                          .format(filename, PATH_TO_EXPERIMENT_RECORDINGS))
                    return False

        # Get race recording
        print('Loading file {}'.format(file_path))
        try:
            data: pd.DataFrame = pd.read_csv(file_path, comment='#')  # skip comment lines starting with #
        except Exception as e:
            print('Cannot load: Caught {} trying to read CSV file {}'.format(e, file_path))
            return False

        return data

    def summary_plots(self):
        """Plot angle, position, motor force and target position over time from dict_history.

        It should be called after an experiment and only if experiment data was saved.
        :return: (fig, axs) of the created figure.
        """
        fontsize_labels = 14
        fontsize_ticks = 12
        fig, axs = plt.subplots(4, 1, figsize=(16, 9), sharex=True)  # share x axis so zoom zooms all plots

        # Plot angle error
        axs[0].set_ylabel("Angle (deg)", fontsize=fontsize_labels)
        axs[0].plot(np.array(self.dict_history['time']), np.array(self.dict_history['s.angle']) * 180.0 / np.pi,
                    'b', markersize=12, label='Ground Truth')
        axs[0].tick_params(axis='both', which='major', labelsize=fontsize_ticks)

        # Plot position
        axs[1].set_ylabel("position (m)", fontsize=fontsize_labels)
        axs[1].plot(self.dict_history['time'], self.dict_history['s.position'], 'g', markersize=12,
                    label='Ground Truth')
        axs[1].tick_params(axis='both', which='major', labelsize=fontsize_ticks)

        # Plot motor input command
        axs[2].set_ylabel("motor (N)", fontsize=fontsize_labels)
        axs[2].plot(self.dict_history['time'], self.dict_history['u'], 'r', markersize=12,
                    label='motor')
        axs[2].tick_params(axis='both', which='major', labelsize=fontsize_ticks)

        # Plot target position
        axs[3].set_ylabel("position target (m)", fontsize=fontsize_labels)
        axs[3].plot(self.dict_history['time'], self.dict_history['target_position'], 'k')
        axs[3].tick_params(axis='both', which='major', labelsize=fontsize_ticks)

        axs[3].set_xlabel('Time (s)', fontsize=fontsize_labels)

        fig.align_ylabels()

        plt.show()

        return fig, axs

    # endregion

    # region 3. Methods for generating random target position for generation of random experiment

    def Generate_Random_Trace_Function(self):
        """Generate a random target-position trace as a function interpolating between turning points.

        Sets self.t_max_pre, self.random_track_f (target position clipped to
        +/-80% of HalfLength) and self.new_track_generated.
        :raises NotImplementedError: unknown turning_points_period.
        :raises ValueError: unknown interpolation_type.
        """
        if (self.turning_points is None) or (self.turning_points == []):

            number_of_turning_points = int(np.floor(self.length_of_experiment * self.track_relative_complexity))

            y = 2.0 * (np.random.random(number_of_turning_points) - 0.5)
            y = y * 0.5 * self.HalfLength

            if number_of_turning_points == 0:
                y = np.append(y, 0.0)
                y = np.append(y, 0.0)
            elif number_of_turning_points == 1:
                if self.start_random_target_position_at is not None:
                    y[0] = self.start_random_target_position_at
                elif self.end_random_target_position_at is not None:
                    y[0] = self.end_random_target_position_at
                y = np.append(y, y[0])
            else:
                if self.start_random_target_position_at is not None:
                    y[0] = self.start_random_target_position_at
                if self.end_random_target_position_at is not None:
                    y[-1] = self.end_random_target_position_at

        else:
            number_of_turning_points = len(self.turning_points)
            if number_of_turning_points == 0:
                raise ValueError('You should not be here!')
            elif number_of_turning_points == 1:
                y = np.array([self.turning_points[0], self.turning_points[0]])
            else:
                y = np.array(self.turning_points)

        number_of_timesteps = np.ceil(self.length_of_experiment / self.dt_simulation)
        self.t_max_pre = number_of_timesteps * self.dt_simulation

        # Interior turning points; the first and last are pinned to t=0 and t_max_pre
        random_samples = max(number_of_turning_points - 2, 0)

        if self.turning_points_period == 'random':
            t_init = np.sort(np.random.uniform(self.dt_simulation, self.t_max_pre - self.dt_simulation,
                                               random_samples))
            t_init = np.insert(t_init, 0, 0.0)
            t_init = np.append(t_init, self.t_max_pre)
        elif self.turning_points_period == 'regular':
            t_init = np.linspace(0, self.t_max_pre, num=random_samples + 2, endpoint=True)
        else:
            raise NotImplementedError('There is no mode corresponding to this value of turning_points_period variable')

        # Try algorithm setting derivative to 0 at each point
        if self.interpolation_type == '0-derivative-smooth':
            yder = [[y[i], 0] for i in range(len(y))]
            random_track_f = BPoly.from_derivatives(t_init, yder)
        elif self.interpolation_type == 'linear':
            random_track_f = interp1d(t_init, y, kind='linear')
        elif self.interpolation_type == 'previous':
            random_track_f = interp1d(t_init, y, kind='previous')
        else:
            raise ValueError('Unknown interpolation type.')

        # Truncate the target position to be not greater than 80% of track length
        def random_track_f_truncated(time):

            target_position = random_track_f(time)
            if target_position > 0.8 * self.HalfLength:
                target_position = 0.8 * self.HalfLength
            elif target_position < -0.8 * self.HalfLength:
                target_position = -0.8 * self.HalfLength

            return target_position

        self.random_track_f = random_track_f_truncated

        self.new_track_generated = True

    def setup_cartpole_random_experiment(self,

                                         # Initial state
                                         s0=None,

                                         controller=None,

                                         dt_simulation=None,
                                         dt_controller=None,
                                         dt_save=None,

                                         # Settings related to random trace generation
                                         track_relative_complexity=None,
                                         length_of_experiment=None,
                                         interpolation_type=None,
                                         turning_points_period=None,
                                         start_random_target_position_at=None,
                                         end_random_target_position_at=None,
                                         turning_points=None):
        """Prepare the CartPole instance to perform an experiment with a random target position trace.

        Sets time scales, (optionally) the controller and the initial state,
        generates the random trace and initializes the state at t=0
        (reset_mode=2) with an up-to-date control input.
        """
        # Set time scales:
        self.dt_simulation = dt_simulation
        self.dt_controller = dt_controller
        self.dt_save = dt_save

        # Set CartPole in the right (automatic control) mode
        # You may want to provide it before this function not to reload it every time
        if controller is not None:
            self.set_controller(controller)

        # Set initial state
        self.s = s0

        self.track_relative_complexity = track_relative_complexity
        self.length_of_experiment = length_of_experiment
        self.interpolation_type = interpolation_type
        self.turning_points_period = turning_points_period
        self.start_random_target_position_at = start_random_target_position_at
        self.end_random_target_position_at = end_random_target_position_at
        self.turning_points = turning_points

        self.Generate_Random_Trace_Function()
        self.use_pregenerated_target_position = True

        self.number_of_timesteps_in_random_experiment = int(np.ceil(self.length_of_experiment / self.dt_simulation))

        # Target position at time 0
        self.target_position = self.random_track_f(self.time)

        # Make already in the first timestep Q appropriate to the initial state, target position and controller
        self.Update_Q()

        self.set_cartpole_state_at_t0(reset_mode=2, s=self.s, Q=self.Q, target_position=self.target_position)

    # @profile(precision=4)
    def run_cartpole_random_experiment(self,
                                       csv=None,
                                       save_mode='offline'):
        """
        Run a random CartPole experiment set up with setup_cartpole_random_experiment,
        save the recording to a csv file and return the history of CartPole states,
        control inputs and desired cart position as a pandas DataFrame.

        :param csv: name of the csv file (see save_history_csv).
        :param save_mode: 'offline' keeps the whole history in memory and writes it at the end
            (and shows summary plots); 'online' writes every saved step immediately.
        :raises ValueError: unknown save_mode.
        """
        if save_mode == 'offline':
            self.save_data_in_cart = True
        elif save_mode == 'online':
            self.save_data_in_cart = False
        else:
            raise ValueError('Unknown save mode value')

        # Create csv file for saving
        self.save_history_csv(csv_name=csv, mode='init', length_of_experiment=self.length_of_experiment)

        # Save 0th timestep
        if save_mode == 'online':
            self.save_history_csv(csv_name=csv, mode='save online')

        # Run the CartPole experiment for number of time
        for _ in trange(self.number_of_timesteps_in_random_experiment):

            # Print an error message if it runs already to long (should stop before)
            if self.time > self.t_max_pre:
                raise Exception('ERROR: It seems the experiment is running too long...')

            self.update_state()

            # Additional option to stop the experiment
            if abs(self.s.position) > 45.0:
                # Report before leaving the loop (previously placed after break, so never printed)
                print('Cart went out of safety boundaries')
                break

            if save_mode == 'online' and self.save_flag:
                self.save_history_csv(csv_name=csv, mode='save online')
                self.save_flag = False

        data = pd.DataFrame(self.dict_history)

        if save_mode == 'offline':
            self.save_history_csv(csv_name=csv, mode='save offline')
            self.summary_plots()

        # Set CartPole state - the only use is to make sure that experiment history is discarded
        # Maybe you can delete this line
        self.set_cartpole_state_at_t0(reset_mode=0)

        return data

    # endregion

    # region 4. Methods "Get, set, reset"

    def get_available_controller_names(self):
        """
        Return the list of controllers available in the PATH_TO_CONTROLLERS folder:
        'manual-stabilization' followed by the sorted names derived from
        controller_*.py files ('_' replaced with '-').
        """
        controller_files = glob.glob(PATH_TO_CONTROLLERS + 'controller_' + '*.py')
        controller_names = ['manual-stabilization']
        controller_names.extend(np.sort(
            [os.path.basename(item)[len('controller_'):-len('.py')].replace('_', '-') for item in controller_files]
        ))

        return controller_names

    def set_controller(self, controller_name=None, controller_idx=None):
        """
        Set a new controller as the current controller of the CartPole instance.
        The controller may be indicated either by its name or by the index on the controller list
        (see get_available_controller_names method).

        :raises ValueError: neither or both arguments given, or unknown controller name.
        """

        # Check if the proper information was provided: either controller_name or controller_idx
        if (controller_name is None) and (controller_idx is None):
            raise ValueError('You have to specify either controller_name or controller_idx to set a new controller.'
                             'You have specified none of the two.')
        elif (controller_name is not None) and (controller_idx is not None):
            raise ValueError('You have to specify either controller_name or controller_idx to set a new controller.'
                             'You have specified both.')

        # If controller name provided get controller index and vice versa
        if controller_name is not None:
            try:
                controller_idx = self.controller_names.index(controller_name)
            except ValueError:
                raise ValueError('{} is not in list. \n In list are: {}'.format(controller_name, self.controller_names))
        else:
            controller_name = self.controller_names[controller_idx]

        # save controller name and index to variables in the CartPole namespace
        self.controller_name = controller_name
        self.controller_idx = controller_idx

        # Load controller
        if self.controller_name == 'manual-stabilization':
            self.controller = None
        else:
            import importlib  # local import: keeps this block self-contained

            controller_full_name = 'controller_' + self.controller_name.replace('-', '_')
            # Turn the controllers folder path into a dotted package prefix.
            # The first two characters are dropped (presumably './' or '.\\' - confirm).
            # '\\' is a single backslash (Windows separator); the previous r'\\'
            # only matched a doubled backslash.
            path_import = PATH_TO_CONTROLLERS[2:].replace('/', '.').replace('\\', '.')
            # importlib instead of exec()/eval(): exec cannot reliably bind
            # function locals (PEP 667, Python 3.13+), and no code strings are evaluated.
            controller_module = importlib.import_module(path_import + controller_full_name)
            self.controller = getattr(controller_module, controller_full_name)()

        # Set the maximal allowed value of the slider - relevant only for GUI
        if self.controller_name == 'manual-stabilization':
            self.slider_max = 1.0
            self.Slider_Arrow.set_positions((0, 0), (0, 0))
        else:
            self.slider_max = self.p.TrackHalfLength
            self.Slider_Bar.set_width(0.0)

    def set_cartpole_state_at_t0(self, reset_mode=1, s=None, Q=None, target_position=None):
        """Reset the internal state of the CartPole instance to time t = 0.

        :param reset_mode: 0 - all zeros;
            1 - state set in this function (small random angle, target from slider);
            2 - provided by the user via s, Q and target_position.
        :raises ValueError: reset_mode 2 without s, Q or target_position.
        The experiment history is reset to contain only the t = 0 state.
        """
        self.time = 0.0
        if reset_mode == 0:  # Don't change it
            self.s.position = self.s.positionD = self.s.positionDD = 0.0
            self.s.angle = self.s.angleD = self.s.angleDD = 0.0
            self.Q = self.u = 0.0
            # slider_value is the attribute used everywhere else; `slider` is
            # kept only for backward compatibility with possible external readers.
            self.slider = self.slider_value = self.target_position = 0.0

        elif reset_mode == 1:  # You may change this but be careful with other users. Better use mode 2
            # You can change here with which initial parameters you wish to start the simulation
            self.s.position = 0.0
            self.s.positionD = 0.0
            self.s.angle = (2.0 * np.random.normal() - 1.0) * np.pi / 180.0
            self.s.angleD = 0.0

            self.target_position = self.slider_value

            self.Q = 0.0

            self.u = Q2u(self.Q, self.p)
            self.s.angleDD, self.s.positionDD = cartpole_ode(self.p, self.s, self.u)

        elif reset_mode == 2:  # Don't change it
            if (s is not None) and (Q is not None) and (target_position is not None):

                self.s = s
                self.Q = Q
                # See reset_mode 0 for why both slider attributes are set
                self.slider = self.slider_value = self.target_position = target_position

                self.u = Q2u(self.Q, self.p)  # Calculate CURRENT control input
                self.s.angleDD, self.s.positionDD = cartpole_ode(self.p, self.s, self.u)  # CURRENT second derivatives
            else:
                raise ValueError('s, Q or target position not provided for initial state')

        # Reset the dict keeping the experiment history and save the state for t = 0
        self.dict_history = {key: [value] for key, value in self._history_snapshot().items()}

    # region Get and set timescales

    # Makes sure that when dt is updated also related variables are updated

    @property
    def dt_simulation(self):
        return self._dt_simulation

    @dt_simulation.setter
    def dt_simulation(self, value):
        self._dt_simulation = value
        if self._dt_controller is not None:
            self.dt_controller_number_of_steps = np.rint(self._dt_controller / value).astype(np.int32)
            if self.dt_controller_number_of_steps == 0:
                self.dt_controller_number_of_steps = 1
            # Initialize counter at max value to start with update
            self.dt_controller_steps_counter = self.dt_controller_number_of_steps - 1
        if self._dt_save is not None:
            self.dt_save_number_of_steps = np.rint(self._dt_save / value).astype(np.int32)
            if self.dt_save_number_of_steps == 0:
                self.dt_save_number_of_steps = 1
            self.dt_save_steps_counter = 0

    @property
    def dt_controller(self):
        return self._dt_controller

    @dt_controller.setter
    def dt_controller(self, value):
        self._dt_controller = value
        if self._dt_simulation is not None:
            self.dt_controller_number_of_steps = np.rint(value / self._dt_simulation).astype(np.int32)
            if self.dt_controller_number_of_steps == 0:
                self.dt_controller_number_of_steps = 1
            # Initialize counter at max value to start with update
            self.dt_controller_steps_counter = self.dt_controller_number_of_steps - 1

    @property
    def dt_save(self):
        return self._dt_save

    @dt_save.setter
    def dt_save(self, value):
        self._dt_save = value
        if self._dt_simulation is not None:
            self.dt_save_number_of_steps = np.rint(value / self._dt_simulation).astype(np.int32)
            if self.dt_save_number_of_steps == 0:
                self.dt_save_number_of_steps = 1
            # This counter is initialized at 0 - 0th step is saved manually
            self.dt_save_steps_counter = 0

    # endregion

    # endregion

    # region 5. Methods needed to display CartPole in GUI
    # This section contains methods related to displaying CartPole in GUI of the simulator.
    # One could think of moving these functions outside of CartPole class and connecting them
    # more tightly with the GUI of the simulator.
    # We leave them however as a part of CartPole class as they rely on variables of the CartPole.

    def init_graphical_elements(self):
        """Initialize the dimensions and matplotlib patches used to draw the CartPole in the GUI."""

        self.CartLength = 10.0
        self.WheelRadius = 0.5
        self.WheelToMiddle = 4.0
        self.y_plane = 0.0
        self.y_wheel = self.y_plane + self.WheelRadius
        self.MastHight = 10.0  # For drawing only. For calculation see L
        self.MastThickness = 0.05
        self.HalfLength = 50.0  # Length of the track

        self.y_acceleration_arrow = 1.5 * self.WheelRadius
        self.scaling_dx_acceleration_arrow = 20.0
        self.x_acceleration_arrow = (self.s.position +
                                     # np.sign(self.Q) * (self.CartLength / 2.0) +
                                     self.scaling_dx_acceleration_arrow * self.Q)

        # Initialize elements of the drawing
        self.Mast = FancyBboxPatch(xy=(self.s.position - (self.MastThickness / 2.0), 1.25 * self.WheelRadius),
                                   width=self.MastThickness,
                                   height=self.MastHight,
                                   fc='g')

        self.Chassis = FancyBboxPatch((self.s.position - (self.CartLength / 2.0), self.WheelRadius),
                                      self.CartLength,
                                      1 * self.WheelRadius,
                                      fc='r')

        self.WheelLeft = Circle((self.s.position - self.WheelToMiddle, self.y_wheel),
                                radius=self.WheelRadius,
                                fc='y',
                                ec='k',
                                lw=5)

        self.WheelRight = Circle((self.s.position + self.WheelToMiddle, self.y_wheel),
                                 radius=self.WheelRadius,
                                 fc='y',
                                 ec='k',
                                 lw=5)

        self.Acceleration_Arrow = FancyArrowPatch((self.s.position, self.y_acceleration_arrow),
                                                  (self.x_acceleration_arrow, self.y_acceleration_arrow),
                                                  arrowstyle='simple', mutation_scale=10,
                                                  facecolor='gold', edgecolor='orange')

        self.Slider_Arrow = FancyArrowPatch((self.slider_value, 0), (self.slider_value, 0),
                                            arrowstyle='fancy', mutation_scale=50)
        self.Slider_Bar = Rectangle((0.0, 0.0), self.slider_value, 1.0)
        self.t2 = transforms.Affine2D().rotate(0.0)  # An abstract container for the transform rotating the mast

    def update_slider(self, mouse_position):
        """Set slider_value from the mouse position, saturated to [-slider_max, slider_max].

        The mouse position has to be captured by a function not included in this class.
        """
        # The if statement formulates a saturation condition
        if mouse_position > self.slider_max:
            self.slider_value = self.slider_max
        elif mouse_position < -self.slider_max:
            self.slider_value = -self.slider_max
        else:
            self.slider_value = mouse_position

    def draw_constant_elements(self, fig, AxCart, AxSlider):
        """Draw elements and set properties of the CartPole figure which do not change between frames.

        :return: (fig, AxCart, AxSlider)
        """

        # Delete all elements of the Figure
        AxCart.clear()
        AxSlider.clear()

        ## Upper chart with Cart Picture
        # Set x and y limits
        AxCart.set_xlim((-self.HalfLength * 1.1, self.HalfLength * 1.1))
        AxCart.set_ylim((-1.0, 15.0))
        # Remove ticks on the y-axes
        AxCart.yaxis.set_major_locator(plt.NullLocator())  # NullLocator is used to disable ticks on the Figures

        # Draw track
        Floor = Rectangle((-self.HalfLength, -1.0),
                          2 * self.HalfLength,
                          1.0,
                          fc='brown')
        AxCart.add_patch(Floor)

        # Draw an invisible point at constant position
        # Thanks to it the axes is drawn high enough for the mast
        InvisiblePointUp = Rectangle((0, self.MastHight + 2.0),
                                     self.MastThickness,
                                     0.0001,
                                     fc='w',
                                     ec='w')

        AxCart.add_patch(InvisiblePointUp)
        # Apply scaling
        AxCart.axis('scaled')

        ## Lower Chart with Slider
        # Set x limits
        AxSlider.set(xlim=(-1.1 * self.slider_max, self.slider_max * 1.1))
        # Remove ticks on the y-axes
        AxSlider.yaxis.set_major_locator(plt.NullLocator())  # NullLocator is used to disable ticks on the Figures
        # Apply scaling
        AxSlider.set_aspect("auto")

        return fig, AxCart, AxSlider

    def update_drawing(self):
        """Update the elements of the Cart figure which change at every frame.

        The elements are not plotted here but returned, so they can be used by
        another function, e.g. a matplotlib animation function.
        :return: (Mast, t2, Chassis, WheelRight, WheelLeft, Slider_Bar, Slider_Arrow, Acceleration_Arrow)
        """

        self.x_acceleration_arrow = (self.s.position +
                                     # np.sign(self.Q) * (self.CartLength / 2.0) +
                                     self.scaling_dx_acceleration_arrow * self.Q)

        self.Acceleration_Arrow.set_positions((self.s.position, self.y_acceleration_arrow),
                                              (self.x_acceleration_arrow, self.y_acceleration_arrow))

        # Draw mast
        mast_position = (self.s.position - (self.MastThickness / 2.0))
        self.Mast.set_x(mast_position)
        # Draw rotated mast: move pivot to origin, rotate, move back
        t21 = transforms.Affine2D().translate(-mast_position, -1.25 * self.WheelRadius)
        if ANGLE_CONVENTION == 'CLOCK-NEG':
            t22 = transforms.Affine2D().rotate(self.s.angle)
        elif ANGLE_CONVENTION == 'CLOCK-POS':
            t22 = transforms.Affine2D().rotate(-self.s.angle)
        else:
            raise ValueError('Unknown angle convention')
        t23 = transforms.Affine2D().translate(mast_position, 1.25 * self.WheelRadius)
        self.t2 = t21 + t22 + t23
        # Draw Chassis
        self.Chassis.set_x(self.s.position - (self.CartLength / 2.0))
        # Draw Wheels
        self.WheelLeft.center = (self.s.position - self.WheelToMiddle, self.y_wheel)
        self.WheelRight.center = (self.s.position + self.WheelToMiddle, self.y_wheel)
        # Draw Slider
        if self.controller_name == 'manual-stabilization':
            self.Slider_Bar.set_width(self.slider_value)
        else:
            self.Slider_Arrow.set_positions((self.slider_value, 0), (self.slider_value, 1.0))

        return self.Mast, self.t2, self.Chassis, self.WheelRight, self.WheelLeft, \
            self.Slider_Bar, self.Slider_Arrow, self.Acceleration_Arrow

    def run_animation(self, fig):
        """Create a matplotlib FuncAnimation redrawing the changing elements of *fig*.

        *fig* is expected to carry AxCart and AxSlider attributes.
        :return: the FuncAnimation object (keep a reference to it while animating).
        """

        def init():
            # Adding variable elements to the Figure
            fig.AxCart.add_patch(self.Mast)
            fig.AxCart.add_patch(self.Chassis)
            fig.AxCart.add_patch(self.WheelLeft)
            fig.AxCart.add_patch(self.WheelRight)
            fig.AxCart.add_patch(self.Acceleration_Arrow)
            fig.AxSlider.add_patch(self.Slider_Bar)
            fig.AxSlider.add_patch(self.Slider_Arrow)
            return self.Mast, self.Chassis, self.WheelLeft, self.WheelRight, \
                self.Slider_Bar, self.Slider_Arrow, self.Acceleration_Arrow

        def animationManage(i):
            # Updating variable elements
            self.update_drawing()
            # Special care has to be taken of the mast rotation
            self.t2 = self.t2 + fig.AxCart.transData
            self.Mast.set_transform(self.t2)
            return self.Mast, self.Chassis, self.WheelLeft, self.WheelRight, \
                self.Slider_Bar, self.Slider_Arrow, self.Acceleration_Arrow

        # Initialize animation object
        anim = animation.FuncAnimation(fig, animationManage,
                                       init_func=init,
                                       frames=300,
                                       # fargs=(CartPoleInstance,), # It was used when this function was a part of GUI class. Now left as an example how to add arguments to FuncAnimation
                                       interval=10,
                                       blit=True,
                                       repeat=True)
        return anim

    # endregion
class EnvironmentReader(Reader):
    """Animated window showing the animal's path, velocity, orientation and target.

    Environment data are taken from the log ``vl`` iteration by iteration with
    :meth:`read` and drawn with :meth:`draw` using blitting.
    """

    def __init__(self, vl, room_shape, start_iter):
        """
        :param vl: log indexed by column name; this class reads the 'Iter num',
            'xs', 'ys', 'vxs', 'vys', 'txs' and 'tys' entries (presumably
            equal-length numpy arrays -- TODO confirm).
        :param room_shape: ``room_shape[0]`` / ``room_shape[1]`` are the x / y
            ranges used as axis limits; index 1 of each is taken as the maximum.
        :param start_iter: iteration number the first :meth:`read` returns.
        """
        # NOTE(review): `pos` is not defined in this method; it must be a
        # module-level name -- verify, otherwise this raises NameError.
        super(EnvironmentReader, self).__init__(win_size=[700, 700], win_loc=pos,
                                                title='Environment')
        self.vl = vl
        self.cur_iter = start_iter
        # index of the first log row that read() has not returned yet
        self.cur_i = 0
        self.max_iter = np.max(vl['Iter num'])
        self.maxx = room_shape[0][1]
        self.maxy = room_shape[1][1]
        self.cntr_x = np.mean(room_shape[0])
        self.cntr_y = np.mean(room_shape[1])
        self.x_hist = []
        self.y_hist = []

        # update_background comes from Reader (presumably re-captures the
        # blitting background -- confirm)
        self.canvas.mpl_connect('resize_event', lambda x: self.update_background())

        self.pos, = self.ax.plot([], [], color='g', animated=True)
        self.vel = Arrow([0, 0], [1, 0], arrowstyle='-|>, head_length=3, head_width=3',
                         animated=True, linewidth=4)
        self.ax.add_patch(self.vel)
        self.clockcounter = self.ax.text(self.maxx, room_shape[1][0], '', ha='right',
                                         size='large', animated=True)
        self.iters = self.ax.text(self.maxx - 1, room_shape[1][0] + 3, '', ha='right',
                                  animated=True)
        self.target = Circle([0, 0], 0, animated=True, color='r')
        self.target.set_radius(11)
        self.ax.add_artist(self.target)
        self.ax.set_xlim(room_shape[0])
        self.ax.set_ylim(room_shape[1])
        self.draw_plc_flds()
        self.draw_HMM_sectors()
        self.predicted_counter = self.ax.text(self.maxx, room_shape[1][0] + 10, '',
                                              ha='right', size='large', animated=True)
        self.prediction_hist = None
        self.canvas.draw()

    def read(self, iteration=None):
        '''Return the environment data corresponding to the next iteration.

        Note that 0 or multiple data points can be associated with a single
        iteration number. All of the environment data before self.cur_i has
        already been recorded.

        :param iteration: if given, jump to this iteration number first.
        :return: tuple ``(xs, ys, vxs, vys)`` holding every log row from
            ``cur_i`` up to and including the last row of the iteration, or a
            list of four NaNs when no later row has that iteration number.
            ``cur_iter`` advances by one in both cases.
        '''
        if iteration is not None:
            self.cur_iter = iteration
        matches = np.nonzero(self.vl['Iter num'][self.cur_i:] == self.cur_iter)[0]
        if matches.size == 0:
            # There is no iteration with that value
            self.cur_iter += 1
            return [np.nan, np.nan, np.nan, np.nan]
        # one past the last row that belongs to the current iteration
        cur_j = 1 + self.cur_i + matches[-1]
        cur_x = self.vl['xs'][self.cur_i:cur_j]
        cur_y = self.vl['ys'][self.cur_i:cur_j]
        cur_vx = self.vl['vxs'][self.cur_i:cur_j]
        cur_vy = self.vl['vys'][self.cur_i:cur_j]
        self.cur_iter += 1
        self.cur_i = cur_j
        return (cur_x, cur_y, cur_vx, cur_vy)

    def draw_HMM_sectors(self):
        '''Draw the HMM sector grid. Currently disabled.

        NOTE(review): everything after ``return`` is unreachable and uses
        ``self.HMM`` and ``self.ax_env``, which this class never sets.
        '''
        return
        for i in range(self.HMM.cll_p_sd):
            curx = self.HMM.xrange[0] + i * self.HMM.dx
            cury = self.HMM.yrange[0] + i * self.HMM.dy
            self.ax_env.plot([curx, curx], [0, self.maxy], 'k')
            self.ax_env.plot([0, self.maxx], [cury, cury], 'k')

    def draw_plc_flds(self):
        '''Draw the place fields of each context. Currently disabled.

        NOTE(review): everything after ``return`` is unreachable and uses
        ``self.WR`` and ``self.ax_env``, which this class never sets.
        '''
        return
        clrs = ['b', 'g']
        legend_widgets = []
        for i, plc_flds in enumerate(self.WR.contexts.values()):
            added = False
            for plc_fld in plc_flds:
                circ = Circle([plc_fld.x, plc_fld.y], plc_fld.r, color=clrs[i])
                self.ax_env.add_patch(circ)
                if not added:
                    legend_widgets.append(circ)
                    added = True
        self.ax.legend(legend_widgets, ('counterclockwise', 'clockwise'), loc='lower right')

    def draw(self, xs, ys, vxs, vys):
        '''Display the environment data to the window.

        The environment data are inputs, normally the batch returned by
        read(). Nothing is drawn when ``xs`` contains NaN (read() found no
        data for the iteration).
        '''
        import logging

        if np.any(np.isnan(xs)):
            return
        # Display how far along in the animation we are
        self.iters.set_text('%i/%i' % (self.cur_iter, self.max_iter))
        # Begin the drawing process
        self.canvas.restore_region(self.background)
        # Calculate and draw rat's physical position from the newest sample
        x, y, vx, vy = xs[-1], ys[-1], vxs[-1], vys[-1]
        self.x_hist.extend(xs)
        self.y_hist.extend(ys)
        self.pos.set_data(self.x_hist, self.y_hist)
        # Adjust velocity vector (kept as is while the rat is not moving)
        if vx != 0 or vy != 0:
            self.vel.set_positions([x, y], [x + vx, y + vy])
        # Calculate physical orientation and display it
        cur_orientation = self.is_clockwise(x, y, vx, vy)
        if cur_orientation == 1:
            self.clockcounter.set_text('Clockwise')
        else:
            self.clockcounter.set_text('Counterclockwise')
        # Adjust the location of the rat's chased target. read() leaves cur_i
        # one past the last consumed row, so the row matching (x, y) is
        # cur_i - 1; indexing cur_i would show the next sample and fail after
        # the final iteration of the log.
        target_i = max(self.cur_i - 1, 0)
        self.target.center = [self.vl['txs'][target_i], self.vl['tys'][target_i]]
        # Make a context prediction; best effort, a failure must not stop drawing
        try:
            self._make_prediction(cur_orientation)
        except Exception:
            logging.debug('Make prediction failed.', exc_info=True)
        # Update the drawing window
        for itm in [self.pos, self.vel, self.clockcounter, self.iters,
                    self.predicted_counter, self.target]:
            self.ax.draw_artist(itm)
        self.canvas.blit(self.ax.bbox)

    def is_clockwise(self, x, y, vx, vy):
        '''Classify the motion direction around the center of the room.

        Returns 1 when the 2-D cross product of (position - room center) and
        the velocity is positive, otherwise -1. The output is mapped to
        {-1, 1} to conform with vl['Task'] labels.
        '''
        cross_prod = (x - self.cntr_x) * vy - (y - self.cntr_y) * vx
        clockwise = 2 * (cross_prod > 0) - 1
        return clockwise
class ThreeCompVisualisation:
    """
    Basis to visualise power flow within the hydraulics model as an animation or simulation
    """

    def __init__(self, agent: ThreeCompHydAgent, axis: plt.axis = None, animated: bool = False,
                 detail_annotations: bool = False, basic_annotations: bool = True,
                 black_and_white: bool = False):
        """
        Whole visualisation setup using given agent's parameters
        :param agent: The agent to be visualised
        :param axis: If set, the visualisation will be drawn using the provided axis object.
        Useful for animations or to display multiple models in one plot
        :param animated: If true the visualisation is set up to deal with frame updates.
        See animation script for more details.
        :param detail_annotations: If true, tank distances and sizes are annotated as well.
        :param basic_annotations: If true, U, LF, and LS are visible
        :param black_and_white: If true, the visualisation is in black and white
        :raises UserWarning: if animated and detail_annotations are both set
        """
        # Both layouts assign the same flow arrow/annotation attributes, so they
        # cannot be combined. Check before anything is created, so no figure is
        # left open and a provided axis is left untouched.
        if detail_annotations and animated:
            raise UserWarning("Detailed annotations and animation cannot be combined")

        # matplotlib fontsize (global rcParams setting)
        rcParams['font.size'] = 10

        # plot if no axis was assigned
        if axis is None:
            fig = plt.figure(figsize=(8, 5))
            self._ax1 = fig.add_subplot(1, 1, 1)
        else:
            fig = None
            self._ax1 = axis

        if black_and_white:
            self.__u_color = (0.7, 0.7, 0.7)
            self.__lf_color = (0.5, 0.5, 0.5)
            self.__ls_color = (0.3, 0.3, 0.3)
            self.__ann_color = (0, 0, 0)
            self.__p_color = (0.5, 0.5, 0.5)
        else:
            self.__u_color = "tab:cyan"
            self.__lf_color = "tab:orange"
            self.__ls_color = "tab:red"
            self.__ann_color = "tab:blue"
            self.__p_color = "tab:green"

        # basic parameters for setup
        self._animated = animated
        self.__detail_annotations = detail_annotations
        self._agent = agent
        self.__offset = 0.2

        # U tank with three stripes
        self.__width_u = 0.3
        self._u = None
        self._u1 = None
        self._u2 = None
        self._r1 = None  # line marking flow from U to LF
        self._ann_u = None  # U annotation

        # LF tank
        self._lf = None
        self._h = None  # fill state
        self._ann_lf = None  # annotation

        # LS tank
        self._ls = None
        self._g = None
        self._ann_ls = None  # annotation LS
        self._r2 = None  # line marking flow from LS to LF

        # finish the basic layout
        self.__set_basic_layout()
        self.update_basic_layout(agent)

        # now the animation components
        if self._animated:
            # U flow
            self._arr_u_flow = None
            self._ann_u_flow = None

            # flow out of tap
            self._arr_power_flow = None
            self._ann_power_flow = None

            # LS flow (R2)
            self._arr_r2_l_pos = None
            self._arr_r2_flow = None
            self._ann_r2_flow = None

            # time information annotation
            self._ann_time = None

            self.__set_animation_layout()
            self._ax1.add_artist(self._ann_time)

        # basic annotations are U, LF, and LS
        if not basic_annotations:
            self.hide_basic_annotations()

        # detail annotations add greek letters for distances and positions
        if self.__detail_annotations:
            self.__set_detailed_annotations_layout()
            self._ax1.set_xlim(0, 1.05)
        else:
            self._ax1.set_xlim(0, 1.0)
        self._ax1.set_ylim(0, 1.2)

        self._ax1.set_axis_off()

        # display plot if no axis object was assigned
        if fig is not None:
            plt.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
            plt.show()
            plt.close(fig)

    def __tank_geometry(self, agent):
        """
        Computes the horizontal layout of the LF and LS tanks for an agent.
        LF (height 1) and LS (height agent.height_ls) share the width right of U
        minus a 0.1 gap, split so that width * height is proportional to the
        tank capacities agent.lf and agent.ls.
        :param agent: agent to be visualised
        :return: (lf_left, lf_width, ls_left, ls_width, phi_o, gamma_o) where
        phi_o and gamma_o are phi and gamma shifted up by the bottom offset
        """
        offset = self.__offset
        lf_left = self.__width_u + 0.1
        lf_area = agent.lf * agent.height_ls
        lf_width = (lf_area * (1 - lf_left - 0.1)) / (agent.ls + lf_area)
        ls_left = lf_left + lf_width + 0.1
        ls_width = 1 - ls_left
        return lf_left, lf_width, ls_left, ls_width, agent.phi + offset, agent.gamma + offset

    def __annotate_arrow(self, label, xy, xytext):
        """
        Adds label at xytext with an arrow pointing to xy.
        Used in pairs sharing xytext to mark the distance between two points.
        """
        self._ax1.annotate(label, xy=xy, xytext=xytext, ha='center', fontsize="xx-large",
                           arrowprops=dict(arrowstyle='-|>', ls='-', fc=self.__ann_color))

    def __set_detailed_annotations_layout(self):
        """
        Adds components required for a detailed annotations view with denoted positions and distances
        """
        u_width = self.__width_u
        ls_left = self._ls.get_x()
        ls_width = self._ls.get_width()
        lf_left = self._lf.get_x()
        lf_width = self._lf.get_width()
        ls_height = self._ls.get_height()
        lf_label_x = self._ann_lf.get_position()[0]
        ls_center_x = ls_left + ls_width / 2

        # some offset to the bottom
        offset = self.__offset
        phi_o = self._agent.phi + offset
        gamma_o = self._agent.gamma + offset

        # labels are rendered with LaTeX (global rcParams setting)
        rcParams['text.usetex'] = True

        # flow from U into LF
        self._ann_u_flow = Text(text="$M_{U}$", ha='right', fontsize="xx-large",
                                x=u_width + 0.09, y=phi_o - 0.08)
        self._arr_u_flow = FancyArrowPatch((u_width, phi_o), (u_width + 0.1, phi_o),
                                           arrowstyle='-|>', mutation_scale=30, lw=2,
                                           color=self.__u_color)

        # phi: from the bottom line up to the bottom of U
        phi_text = (u_width / 2, (phi_o - offset) / 2 + offset - 0.015)
        self.__annotate_arrow('$\\phi$', (u_width / 2, phi_o), phi_text)
        self.__annotate_arrow('$\\phi$', (u_width / 2, offset), phi_text)

        # power flowing out of the bottom of LF
        self._ann_power_flow = Text(text="$p$", ha='center', fontsize="xx-large",
                                    x=lf_label_x, y=offset - 0.06)
        self._arr_power_flow = FancyArrowPatch((lf_label_x, offset - 0.078), (lf_label_x, 0.0),
                                               arrowstyle='-|>', mutation_scale=30, lw=2,
                                               color=self.__p_color)

        # h: from the top of LF down to its (example) fill level
        h_x = lf_label_x + 0.07
        h_text = (h_x, 1 + offset - 0.30)
        self.__annotate_arrow('$h$', (h_x, 1 + offset), h_text)
        self.__annotate_arrow('$h$', (h_x, 1 + offset - 0.55), h_text)
        self._h.update(dict(xy=(lf_left, offset), width=lf_width, height=1 - 0.55,
                            color=self.__lf_color))

        # g: from the top of LS down to its (example) fill level
        g_x = ls_left + ls_width + 0.02
        g_text = (g_x, ls_height * 0.61 + gamma_o)
        self.__annotate_arrow('$g$', (g_x, ls_height + gamma_o), g_text)
        self.__annotate_arrow('$g$', (g_x, ls_height * 0.3 + gamma_o), g_text)
        self._g.update(dict(xy=(ls_left, gamma_o), width=ls_width, height=ls_height * 0.3,
                            color=self.__ls_color))

        # flow between LS and LF in both directions
        ann_arr_flow = Text(text="$M_{LS}$", ha='left', usetex=True, fontsize="xx-large",
                            x=ls_left - 0.09, y=gamma_o + 0.03)
        self._ann_r2_flow = Text(text="$M_{LF}$", ha='left', usetex=True, fontsize="xx-large",
                                 x=ls_left - 0.04, y=gamma_o - 0.07)
        self._arr_r2_flow = FancyArrowPatch((ls_left - 0.1, gamma_o), (ls_left, gamma_o),
                                            arrowstyle='<|-|>', mutation_scale=30, lw=2,
                                            color=self.__ls_color)

        # theta: from the top line down to the top of LS
        theta_text = (ls_center_x,
                      1 - (1 - (ls_height + gamma_o - offset)) / 2 + offset - 0.015)
        self.__annotate_arrow('$\\theta$', (ls_center_x, 1 + offset), theta_text)
        self.__annotate_arrow('$\\theta$', (ls_center_x, ls_height + gamma_o), theta_text)

        # gamma: from the bottom line up to the bottom of LS
        gamma_text = (ls_center_x, (gamma_o - offset) / 2 + offset - 0.015)
        self.__annotate_arrow('$\\gamma$', (ls_center_x, offset), gamma_text)
        self.__annotate_arrow('$\\gamma$', (ls_center_x, gamma_o), gamma_text)

        # total height 1 between the two dashed lines
        self.__annotate_arrow('$1$', (1.05, 0 + offset), (1.05, 0.5 + offset))
        self.__annotate_arrow('$1$', (1.05, 1 + offset), (1.05, 0.5 + offset))

        self._ax1.add_artist(ann_arr_flow)
        # dashed lines marking the bottom and the top of the model
        self._ax1.axhline(offset, linestyle='--', color=self.__ann_color)
        self._ax1.axhline(1 + offset - 0.001, linestyle='--', color=self.__ann_color)
        for artist in (self._ann_power_flow, self._arr_power_flow, self._arr_u_flow,
                       self._ann_u_flow, self._arr_r2_flow, self._ann_r2_flow):
            self._ax1.add_artist(artist)

    def __set_animation_layout(self):
        """
        Adds layout components that are required for an animation.
        Arrows start with mutation_scale 0 (zero size); update_animation_data
        scales them with the current flows.
        """
        offset = self.__offset
        o_width = self.__width_u
        phi_o = self._agent.phi + offset
        gamma_o = self._agent.gamma + offset
        lf_label_x = self._ann_lf.get_position()[0]
        ls_left = self._ls.get_x()

        # U flow (R1)
        self._arr_u_flow = FancyArrowPatch((o_width, phi_o), (o_width + 0.1, phi_o),
                                           arrowstyle='simple', mutation_scale=0, ec='white',
                                           fc=self.__u_color)
        self._ann_u_flow = Text(text="flow: ", ha='right', fontsize="large",
                                x=o_width, y=phi_o - 0.05)

        # Tap flow (Power)
        self._arr_power_flow = FancyArrowPatch((lf_label_x, offset - 0.05), (lf_label_x, 0.0),
                                               arrowstyle='simple', mutation_scale=0,
                                               ec='white', color=self.__p_color)
        self._ann_power_flow = Text(text="flow: ", ha='center', fontsize="large",
                                    x=lf_label_x, y=offset - 0.05)

        # LS flow (R2); both end points are kept so the direction can be flipped
        self._arr_r2_l_pos = [(ls_left, gamma_o), (ls_left - 0.1, gamma_o)]
        self._arr_r2_flow = FancyArrowPatch(self._arr_r2_l_pos[0], self._arr_r2_l_pos[1],
                                            arrowstyle='simple', mutation_scale=0,
                                            ec='white', color=self.__ls_color)
        self._ann_r2_flow = Text(text="flow: ", ha='left', fontsize="large",
                                 x=ls_left, y=gamma_o - 0.05)

        # information annotation
        self._ann_time = Text(x=1, y=0.9 + offset, ha="right")

        for artist in (self._ann_power_flow, self._arr_power_flow, self._arr_u_flow,
                       self._ann_u_flow, self._arr_r2_flow, self._ann_r2_flow):
            self._ax1.add_artist(artist)

    def __set_basic_layout(self):
        """
        Creates the tanks, flow lines and labels for the current agent and adds them to the axis
        """
        agent = self._agent
        ls_height = agent.height_ls
        u_width = self.__width_u
        offset = self.__offset
        lf_left, lf_width, ls_left, ls_width, phi_o, gamma_o = self.__tank_geometry(agent)

        # U tank as three stripes of increasing opacity
        self._u = Rectangle((0.0, phi_o), 0.05, 1 - agent.phi, color=self.__u_color, alpha=0.3)
        self._u1 = Rectangle((0.05, phi_o), 0.05, 1 - agent.phi, color=self.__u_color,
                             alpha=0.6)
        self._u2 = Rectangle((0.1, phi_o), u_width - 0.1, 1 - agent.phi, color=self.__u_color)
        self._r1 = Line2D([u_width, u_width + 0.1], [phi_o, phi_o], color=self.__u_color)
        self._ann_u = Text(text="$U$", ha='center', fontsize="xx-large",
                           x=u_width / 2, y=((1 - agent.phi) / 2) + phi_o - 0.02)

        # LF vessel (outline + fill)
        self._lf = Rectangle((lf_left, offset), lf_width, 1, fill=False, ec="black")
        self._h = Rectangle((lf_left, offset), lf_width, 1, color=self.__lf_color)
        self._ann_lf = Text(text="$LF$", ha='center', fontsize="xx-large",
                            x=lf_left + (lf_width / 2), y=offset + 0.5 - 0.02)

        # LS vessel (outline + fill)
        self._ls = Rectangle((ls_left, gamma_o), ls_width, ls_height, fill=False, ec="black")
        self._g = Rectangle((ls_left, gamma_o), ls_width, ls_height, color=self.__ls_color)
        self._r2 = Line2D([ls_left, ls_left - 0.1], [gamma_o, gamma_o], color=self.__ls_color)
        self._ann_ls = Text(text="$LS$", ha='center', fontsize="xx-large",
                            x=ls_left + (ls_width / 2), y=gamma_o + (ls_height / 2) - 0.02)

        # the basic layout
        self._ax1.add_line(self._r1)
        self._ax1.add_line(self._r2)
        for artist in (self._u, self._u1, self._u2, self._lf, self._ls, self._h, self._g,
                       self._ann_u, self._ann_lf, self._ann_ls):
            self._ax1.add_artist(artist)

    def update_basic_layout(self, agent):
        """
        updates tank positions and sizes according to new agent
        :param agent: agent to be visualised
        """
        self._agent = agent
        ls_height = agent.height_ls
        u_width = self.__width_u
        offset = self.__offset
        lf_left, lf_width, ls_left, ls_width, phi_o, gamma_o = self.__tank_geometry(agent)

        # U tank
        self._u.set_bounds(0.0, phi_o, 0.05, 1 - agent.phi)
        self._u1.set_bounds(0.05, phi_o, 0.05, 1 - agent.phi)
        self._u2.set_bounds(0.1, phi_o, u_width - 0.1, 1 - agent.phi)
        self._r1.set_xdata([u_width, u_width + 0.1])
        self._r1.set_ydata([phi_o, phi_o])
        self._ann_u.set_position(xy=(u_width / 2, ((1 - agent.phi) / 2) + phi_o - 0.02))

        # LF vessel
        self._lf.set_bounds(lf_left, offset, lf_width, 1)
        self._h.set_bounds(lf_left, offset, lf_width, 1)
        self._ann_lf.set_position(xy=(lf_left + (lf_width / 2), offset + 0.5 - 0.02))

        # LS vessel
        self._ls.set_bounds(ls_left, gamma_o, ls_width, ls_height)
        self._g.set_bounds(ls_left, gamma_o, ls_width, ls_height)
        self._r2.set_xdata([ls_left, ls_left - 0.1])
        self._r2.set_ydata([gamma_o, gamma_o])
        self._ann_ls.set_position(xy=(ls_left + (ls_width / 2),
                                      gamma_o + (ls_height / 2) - 0.02))

        # update levels
        self._h.set_height(1 - agent.get_h())
        self._g.set_height(agent.height_ls - agent.get_g())

    def hide_basic_annotations(self):
        """
        Simply hides the S, LF, and LS text
        """
        self._ann_u.set_text("")
        self._ann_lf.set_text("")
        self._ann_ls.set_text("")

    def update_animation_data(self, frame_number):
        """
        For animations and simulations.
        The function to call at each frame.
        :param frame_number: frame number has to be taken because of parent class method
        :return: an iterable of artists
        :raises UserWarning: if the visualisation was not created with animated=True
        """
        if not self._animated:
            raise UserWarning("Animation flag has to be enabled in order to use this function")

        # perform one step
        cur_time = self._agent.get_time()
        power = self._agent.perform_one_step()

        # draw some information
        self._ann_time.set_text("agent \n time: {}".format(int(cur_time)))

        # power arrow, size grows logarithmically with the power
        self._ann_power_flow.set_text("power: {}".format(round(power)))
        self._arr_power_flow.set_mutation_scale(math.log(power + 1) * 10)

        # oxygen arrow
        p_u = round(self._agent.get_p_u() * self._agent.hz, 1)
        max_str = "(MAX)" if p_u == self._agent.m_u else ""
        self._ann_u_flow.set_text("flow: {} {}".format(p_u, max_str))
        self._arr_u_flow.set_mutation_scale(math.log(p_u + 1) * 10)

        # lactate arrow; its direction follows the sign of the flow
        p_g = round(self._agent.get_p_l() * self._agent.hz, 1)
        if p_g < 0:
            # NOTE(review): p_g is negative here, so this only matches if m_lf is
            # stored as a negative value -- confirm whether abs(p_g) was intended.
            max_str = "(MAX)" if p_g == self._agent.m_lf else ""
            self._arr_r2_flow.set_positions(self._arr_r2_l_pos[1], self._arr_r2_l_pos[0])
        else:
            max_str = "(MAX)" if p_g == self._agent.m_ls else ""
            self._arr_r2_flow.set_positions(self._arr_r2_l_pos[0], self._arr_r2_l_pos[1])
        self._ann_r2_flow.set_text("flow: {} {}".format(p_g, max_str))
        self._arr_r2_flow.set_mutation_scale(math.log(abs(p_g) + 1) * 10)

        # update levels
        self._h.set_height(1 - self._agent.get_h())
        self._g.set_height(self._agent.height_ls - self._agent.get_g())

        # list of artists to be drawn; when blitting only these are redrawn, so
        # every artist changed above has to be included
        return [self._ann_time, self._ann_power_flow, self._arr_power_flow, self._g, self._h,
                self._ann_u_flow, self._arr_u_flow, self._ann_r2_flow, self._arr_r2_flow]