예제 #1
0
    def draw(self, renderer):
        """Project the 3-D endpoints to 2-D and draw the arrow.

        When the projected arrow is short (both its x and y extent below a
        threshold), the arrow head is scaled with the arrow length.  A
        zero-length arrow is drawn without a head, so a tiny arrow does not
        render as an oversized head.  Otherwise ``self.init_style`` is used.

        Parameters
        ----------
        renderer : RendererBase
            Renderer to draw with; its ``M`` matrix is passed to
            ``proj3d.proj_transform`` to project ``self._verts3d``.
        """
        xs3d, ys3d, zs3d = self._verts3d
        # NOTE(review): newer mplot3d versions expose the projection matrix
        # as ``self.axes.M`` rather than ``renderer.M`` -- confirm against
        # the matplotlib version in use.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
        FancyArrowPatch.set_positions(self, (xs[0], ys[0]), (xs[1], ys[1]))

        # Only the projected 2-D extent drives the head size; the depth
        # difference is not needed.
        dx = np.abs(xs[0] - xs[1])
        dy = np.abs(ys[0] - ys[1])

        # Below this projected extent the default style is replaced by one
        # whose head shrinks with the arrow.
        thresh = 0.4
        if dx < thresh and dy < thresh:
            factor = max(dx, dy) * 10.
            if factor > 0:
                self.set_arrowstyle("->", head_length=factor,
                                    head_width=factor / 2.)
            else:
                # Degenerate (zero-length) arrow: draw it without a head.
                self.set_arrowstyle("-")
        else:
            self.set_arrowstyle(self.init_style)
        FancyArrowPatch.draw(self, renderer)
예제 #2
0
class AnnotationBbox(martist.Artist, _AnnotationBase):
    """
    Annotation-like class, but with offsetbox instead of Text.

    The offsetbox is placed at *xybox* (in *boxcoords*), optionally framed
    by a patch, and an optional arrow configured by *arrowprops* is drawn
    from the box towards the annotated point *xy* (in *xycoords*).
    """

    zorder = 3

    def __str__(self):
        return "AnnotationBbox(%g,%g)" % (self.xy[0], self.xy[1])

    @docstring.dedent_interpd
    def __init__(self, offsetbox, xy,
                 xybox=None,
                 xycoords='data',
                 boxcoords=None,
                 frameon=True, pad=0.4,  # BboxPatch
                 annotation_clip=None,
                 box_alignment=(0.5, 0.5),
                 bboxprops=None,
                 arrowprops=None,
                 fontsize=None,
                 **kwargs):
        """
        *offsetbox* : OffsetBox instance

        *xycoords* : same as Annotation but can be a tuple of two
           strings which are interpreted as x and y coordinates.

        *boxcoords* : similar to textcoords as Annotation but can be a
           tuple of two strings which are interpreted as x and y
           coordinates.

        *box_alignment* : a tuple of two floats giving the horizontal and
           vertical alignment of the offset box w.r.t. the *boxcoords*.
           The lower-left corner is (0, 0) and the upper-right corner
           is (1, 1).

        *arrowprops* : dict of FancyArrowPatch properties, or None for no
           arrow.  The extra key "relpos" (default (0.5, 0.5)) selects the
           arrow's starting point on the box, as fractions of its width
           and height.  "mutation_scale" is in points and defaults to the
           font size; "patchA" defaults to the frame patch.  The dict is
           copied, so the caller's dict is never modified.

        other parameters are identical to that of Annotation.
        """
        self.offsetbox = offsetbox

        # Keep a private copy: "relpos" is popped below, and popping it from
        # the caller's dict would silently drop it for any other annotation
        # built from the same arrowprops.
        self.arrowprops = (dict(arrowprops) if arrowprops is not None
                           else None)

        self.set_fontsize(fontsize)

        if self.arrowprops is not None:
            self._arrow_relpos = self.arrowprops.pop("relpos", (0.5, 0.5))
            self.arrow_patch = FancyArrowPatch((0, 0), (1, 1),
                                               **self.arrowprops)
        else:
            self._arrow_relpos = None
            self.arrow_patch = None

        _AnnotationBase.__init__(self,
                                 xy, xytext=xybox,
                                 xycoords=xycoords, textcoords=boxcoords,
                                 annotation_clip=annotation_clip)

        martist.Artist.__init__(self, **kwargs)

        self._box_alignment = box_alignment

        # frame
        self.patch = FancyBboxPatch(
            xy=(0.0, 0.0), width=1., height=1.,
            facecolor='w', edgecolor='k',
            mutation_scale=self.prop.get_size_in_points(),
            snap=True
            )
        self.patch.set_boxstyle("square", pad=pad)
        if bboxprops:
            self.patch.set(**bboxprops)
        self._drawFrame = frameon

    def contains(self, event):
        """
        Test whether *event* occurred inside the offsetbox.

        Only the offsetbox is tested.  The arrow patch is deliberately not
        checked, as it can be a line.  Returns ``(hit, info)`` as reported
        by the offsetbox.
        """
        t, tinfo = self.offsetbox.contains(event)
        return t, tinfo

    def get_children(self):
        """Return the offsetbox, the frame patch and, if any, the arrow."""
        children = [self.offsetbox, self.patch]
        if self.arrow_patch is not None:
            children.append(self.arrow_patch)
        return children

    def set_figure(self, fig):
        """Set *fig* on the arrow (if any), the offsetbox and this artist."""
        if self.arrow_patch is not None:
            self.arrow_patch.set_figure(fig)
        self.offsetbox.set_figure(fig)
        martist.Artist.set_figure(self, fig)

    def set_fontsize(self, s=None):
        """
        set fontsize in points; None means rcParams["legend.fontsize"]
        """
        if s is None:
            s = rcParams["legend.fontsize"]

        self.prop = FontProperties(size=s)

    def get_fontsize(self, s=None):
        """
        return fontsize in points (*s* is ignored)
        """
        return self.prop.get_size_in_points()

    def update_positions(self, renderer):
        "Update the pixel positions of the annotated point and the text."
        xy_pixel = self._get_position_xy(renderer)
        self._update_position_xybox(renderer, xy_pixel)

        # The frame scales with the font size.  The arrow's mutation scale
        # is set in _update_position_xybox, which honours an explicit
        # "mutation_scale" in arrowprops, so it must not be overwritten here.
        mutation_scale = renderer.points_to_pixels(self.get_fontsize())
        self.patch.set_mutation_scale(mutation_scale)

    def _update_position_xybox(self, renderer, xy_pixel):
        "Update the pixel positions of the annotation text and the arrow patch."

        x, y = self.xytext
        if isinstance(self.textcoords, tuple):
            # Separate systems for x and y: x comes from the first, y from
            # the second.
            xcoord, ycoord = self.textcoords
            x1, y1 = self._get_xy(renderer, x, y, xcoord)
            x2, y2 = self._get_xy(renderer, x, y, ycoord)
            ox0, oy0 = x1, y2
        else:
            ox0, oy0 = self._get_xy(renderer, x, y, self.textcoords)

        w, h, xd, yd = self.offsetbox.get_extent(renderer)

        # Place the box so the point at fractions (_fw, _fh) of its width
        # and height lands on (ox0, oy0), shifted by the extent's xd/yd.
        _fw, _fh = self._box_alignment
        self.offsetbox.set_offset((ox0 - _fw * w + xd, oy0 - _fh * h + yd))

        # update patch position
        bbox = self.offsetbox.get_window_extent(renderer)
        self.patch.set_bounds(bbox.x0, bbox.y0, bbox.width, bbox.height)

        if self.arrowprops is None or self.arrow_patch is None:
            return

        # The arrow starts at relpos (fractions of the box's width and
        # height) and ends at the annotated point.
        # TODO : Rotation needs to be accounted.
        relpos = self._arrow_relpos
        ox0 = bbox.x0 + bbox.width * relpos[0]
        oy0 = bbox.y0 + bbox.height * relpos[1]
        ox1, oy1 = xy_pixel

        # The arrow is first clipped by patchA and patchB, then shrunk by
        # shrinkA and shrinkB (in points).  patchA defaults to the frame.
        self.arrow_patch.set_positions((ox0, oy0), (ox1, oy1))

        # mutation_scale in arrowprops is in points; default to font size.
        fs = self.prop.get_size_in_points()
        mutation_scale = self.arrowprops.get("mutation_scale", fs)
        self.arrow_patch.set_mutation_scale(
            renderer.points_to_pixels(mutation_scale))

        self.arrow_patch.set_patchA(self.arrowprops.get("patchA", self.patch))

    def draw(self, renderer):
        """
        Draw the arrow (if any), then the frame (if enabled), then the
        offsetbox, to the given *renderer*.
        """
        if renderer is not None:
            self._renderer = renderer
        if not self.get_visible():
            return

        xy_pixel = self._get_position_xy(renderer)

        # Skip drawing entirely when _check_xy rejects the annotated point.
        if not self._check_xy(renderer, xy_pixel):
            return

        self.update_positions(renderer)

        if self.arrow_patch is not None:
            if self.arrow_patch.figure is None and self.figure is not None:
                self.arrow_patch.figure = self.figure
            self.arrow_patch.draw(renderer)

        if self._drawFrame:
            self.patch.draw(renderer)

        self.offsetbox.draw(renderer)
예제 #3
0
class AnchoredCompass(AnchoredOffsetbox):
    """
    Anchored box holding two arrows that point towards East and North.
    """
    def __init__(self, ax, transSky2Pix, loc,
                 arrow_fraction=0.15,
                 txt1="E", txt2="N",
                 delta_a1=0, delta_a2=0,
                 pad=0.1, borderpad=0.5, prop=None, frameon=False,
                 ):
        """
        Draw arrows pointing in the directions of E & N.

        ax : axes whose view limits give the arrow origin and length
        transSky2Pix : transform passed to estimate_angle to find the
            directions (in degrees) of the two arrows
        loc : anchor location, passed to AnchoredOffsetbox
        arrow_fraction : length of the arrows as a fraction of the smaller
            of the axes' view-limit width and height
        txt1, txt2 : labels drawn at the tips of the first (E) and second
            (N) arrow
        delta_a1, delta_a2 : extra angles in degrees added to the
            estimated directions
        pad, borderpad : in fraction of the legend font size (or prop)
        """

        self._ax = ax
        self._transSky2Pix = transSky2Pix
        self._box = AuxTransformBox(ax.transData)
        self.delta_a1, self.delta_a2 = delta_a1, delta_a2
        self.arrow_fraction = arrow_fraction

        arrow_kwargs = dict(mutation_scale=11,
                            shrinkA=0,
                            shrinkB=5,
                            arrowstyle="->",
                            arrow_transmuter=None,
                            connectionstyle="arc3",
                            connector=None)

        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)

        # Placeholder positions; _update_arrow sets the real ones.
        self.txt1 = Text(1, 1, txt1, rotation=0,
                         rotation_mode="anchor",
                         va="center", ha="right")
        self.txt2 = Text(1, 1, txt2, rotation=0,
                         rotation_mode="anchor",
                         va="bottom", ha="center")

        self._box.add_artist(self.arrow1)
        self._box.add_artist(self.arrow2)

        self._box.add_artist(self.txt1)
        self._box.add_artist(self.txt2)

        AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    def set_path_effects(self, path_effects):
        """Apply *path_effects* to both arrows and both labels."""
        for a in [self.arrow1, self.arrow2, self.txt1, self.txt2]:
            a.set_path_effects(path_effects)

    @staticmethod
    def _tip(x0, y0, length, angle):
        """Return the point *length* away from (x0, y0) at *angle* degrees."""
        rad = np.deg2rad(angle)
        return x0 + length * np.cos(rad), y0 + length * np.sin(rad)

    def _update_arrow(self, renderer):
        """
        Re-anchor both arrows at (viewLim.x0, viewLim.y0) and place the
        labels at their tips.  *renderer* is unused.
        """
        ax = self._ax

        x0, y0 = ax.viewLim.x0, ax.viewLim.y0
        a1, a2 = estimate_angle(self._transSky2Pix, x0, y0)
        a1, a2 = a1 + self.delta_a1, a2 + self.delta_a2

        d = min(ax.viewLim.width, ax.viewLim.height) * self.arrow_fraction
        tip1 = self._tip(x0, y0, d, a1)
        tip2 = self._tip(x0, y0, d, a2)

        self.arrow1.set_positions((x0, y0), tip1)
        self.arrow2.set_positions((x0, y0), tip2)

        # Labels sit at the arrow tips, rotated along with the arrows.
        self.txt1.set_position(tip1)
        self.txt1.set_rotation(a1 - 180)
        self.txt2.set_position(tip2)
        self.txt2.set_rotation(a2 - 90)

    def draw(self, renderer):
        """Update the arrow geometry, then draw the anchored box."""
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
예제 #4
0
class AnnotationBbox(martist.Artist, _AnnotationBase):
    """
    Annotation-like class, but with offsetbox instead of Text.

    The offsetbox is placed at *xybox* (in *boxcoords*), optionally framed
    by a patch, and an optional arrow configured by *arrowprops* is drawn
    from the box towards the annotated point *xy* (in *xycoords*).
    """

    zorder = 3

    def __str__(self):
        return "AnnotationBbox(%g,%g)" % (self.xy[0], self.xy[1])

    @docstring.dedent_interpd
    def __init__(
            self,
            offsetbox,
            xy,
            xybox=None,
            xycoords='data',
            boxcoords=None,
            frameon=True,
            pad=0.4,  # BboxPatch
            annotation_clip=None,
            box_alignment=(0.5, 0.5),
            bboxprops=None,
            arrowprops=None,
            fontsize=None,
            **kwargs):
        """
        *offsetbox* : OffsetBox instance

        *xycoords* : same as Annotation but can be a tuple of two
           strings which are interpreted as x and y coordinates.

        *boxcoords* : similar to textcoords as Annotation but can be a
           tuple of two strings which are interpreted as x and y
           coordinates.

        *box_alignment* : a tuple of two floats giving the horizontal and
           vertical alignment of the offset box w.r.t. the *boxcoords*.
           The lower-left corner is (0, 0) and the upper-right corner
           is (1, 1).

        *arrowprops* : dict of FancyArrowPatch properties, or None for no
           arrow.  The extra key "relpos" (default (0.5, 0.5)) selects the
           arrow's starting point on the box, as fractions of its width
           and height.  "mutation_scale" is in points and defaults to the
           font size; "patchA" defaults to the frame patch.  The dict is
           copied, so the caller's dict is never modified.

        other parameters are identical to that of Annotation.
        """
        self.offsetbox = offsetbox

        # Keep a private copy: "relpos" is popped below, and popping it from
        # the caller's dict would silently drop it for any other annotation
        # built from the same arrowprops.
        self.arrowprops = (dict(arrowprops) if arrowprops is not None
                           else None)

        self.set_fontsize(fontsize)

        if self.arrowprops is not None:
            self._arrow_relpos = self.arrowprops.pop("relpos", (0.5, 0.5))
            self.arrow_patch = FancyArrowPatch((0, 0), (1, 1),
                                               **self.arrowprops)
        else:
            self._arrow_relpos = None
            self.arrow_patch = None

        _AnnotationBase.__init__(self,
                                 xy,
                                 xytext=xybox,
                                 xycoords=xycoords,
                                 textcoords=boxcoords,
                                 annotation_clip=annotation_clip)

        martist.Artist.__init__(self, **kwargs)

        self._box_alignment = box_alignment

        # frame
        self.patch = FancyBboxPatch(
            xy=(0.0, 0.0),
            width=1.,
            height=1.,
            facecolor='w',
            edgecolor='k',
            mutation_scale=self.prop.get_size_in_points(),
            snap=True)
        self.patch.set_boxstyle("square", pad=pad)
        if bboxprops:
            self.patch.set(**bboxprops)
        self._drawFrame = frameon

    def contains(self, event):
        """
        Test whether *event* occurred inside the offsetbox.

        Only the offsetbox is tested.  The arrow patch is deliberately not
        checked, as it can be a line.  Returns ``(hit, info)`` as reported
        by the offsetbox.
        """
        t, tinfo = self.offsetbox.contains(event)
        return t, tinfo

    def get_children(self):
        """Return the offsetbox, the frame patch and, if any, the arrow."""
        children = [self.offsetbox, self.patch]
        if self.arrow_patch is not None:
            children.append(self.arrow_patch)
        return children

    def set_figure(self, fig):
        """Set *fig* on the arrow (if any), the offsetbox and this artist."""
        if self.arrow_patch is not None:
            self.arrow_patch.set_figure(fig)
        self.offsetbox.set_figure(fig)
        martist.Artist.set_figure(self, fig)

    def set_fontsize(self, s=None):
        """
        set fontsize in points; None means rcParams["legend.fontsize"]
        """
        if s is None:
            s = rcParams["legend.fontsize"]

        self.prop = FontProperties(size=s)

    def get_fontsize(self, s=None):
        """
        return fontsize in points (*s* is ignored)
        """
        return self.prop.get_size_in_points()

    def update_positions(self, renderer):
        "Update the pixel positions of the annotated point and the text."
        xy_pixel = self._get_position_xy(renderer)
        self._update_position_xybox(renderer, xy_pixel)

        # The frame scales with the font size.  The arrow's mutation scale
        # is set in _update_position_xybox, which honours an explicit
        # "mutation_scale" in arrowprops, so it must not be overwritten here.
        mutation_scale = renderer.points_to_pixels(self.get_fontsize())
        self.patch.set_mutation_scale(mutation_scale)

    def _update_position_xybox(self, renderer, xy_pixel):
        "Update the pixel positions of the annotation text and the arrow patch."

        x, y = self.xytext
        if isinstance(self.textcoords, tuple):
            # Separate systems for x and y: x comes from the first, y from
            # the second.
            xcoord, ycoord = self.textcoords
            x1, y1 = self._get_xy(renderer, x, y, xcoord)
            x2, y2 = self._get_xy(renderer, x, y, ycoord)
            ox0, oy0 = x1, y2
        else:
            ox0, oy0 = self._get_xy(renderer, x, y, self.textcoords)

        w, h, xd, yd = self.offsetbox.get_extent(renderer)

        # Place the box so the point at fractions (_fw, _fh) of its width
        # and height lands on (ox0, oy0), shifted by the extent's xd/yd.
        _fw, _fh = self._box_alignment
        self.offsetbox.set_offset((ox0 - _fw * w + xd, oy0 - _fh * h + yd))

        # update patch position
        bbox = self.offsetbox.get_window_extent(renderer)
        self.patch.set_bounds(bbox.x0, bbox.y0, bbox.width, bbox.height)

        if self.arrowprops is None or self.arrow_patch is None:
            return

        # The arrow starts at relpos (fractions of the box's width and
        # height) and ends at the annotated point.
        # TODO : Rotation needs to be accounted.
        relpos = self._arrow_relpos
        ox0 = bbox.x0 + bbox.width * relpos[0]
        oy0 = bbox.y0 + bbox.height * relpos[1]
        ox1, oy1 = xy_pixel

        # The arrow is first clipped by patchA and patchB, then shrunk by
        # shrinkA and shrinkB (in points).  patchA defaults to the frame.
        self.arrow_patch.set_positions((ox0, oy0), (ox1, oy1))

        # mutation_scale in arrowprops is in points; default to font size.
        fs = self.prop.get_size_in_points()
        mutation_scale = self.arrowprops.get("mutation_scale", fs)
        self.arrow_patch.set_mutation_scale(
            renderer.points_to_pixels(mutation_scale))

        self.arrow_patch.set_patchA(self.arrowprops.get("patchA", self.patch))

    def draw(self, renderer):
        """
        Draw the arrow (if any), then the frame (if enabled), then the
        offsetbox, to the given *renderer*.
        """
        if renderer is not None:
            self._renderer = renderer
        if not self.get_visible():
            return

        xy_pixel = self._get_position_xy(renderer)

        # Skip drawing entirely when _check_xy rejects the annotated point.
        if not self._check_xy(renderer, xy_pixel):
            return

        self.update_positions(renderer)

        if self.arrow_patch is not None:
            if self.arrow_patch.figure is None and self.figure is not None:
                self.arrow_patch.figure = self.figure
            self.arrow_patch.draw(renderer)

        if self._drawFrame:
            self.patch.draw(renderer)

        self.offsetbox.draw(renderer)
예제 #5
0
class AnchoredCompass(AnchoredOffsetbox):
    """
    AnchoredOffsetbox subclass that draws an E/N compass on the plot.

    The compass orientation is given by *ori*, in degrees E of N of the
    y axis.
    """
    def __init__(self,
                 ax,
                 ori,
                 loc=4,
                 arrow_fraction=0.15,
                 txt1="E",
                 txt2="N",
                 pad=0.3,
                 borderpad=0.5,
                 prop=None,
                 frameon=False):
        """
        ax : axes whose view limits give the arrow origin and length
        ori : compass orientation in degrees (see the class docstring)
        loc : anchor location, passed to AnchoredOffsetbox
        arrow_fraction : length of the arrows as a fraction of the smaller
            of the axes' view-limit width and height
        txt1, txt2 : labels for the first (E) and second (N) arrow
        pad, borderpad, prop, frameon : passed to AnchoredOffsetbox
        """
        self._ax = ax
        self.ori = ori
        self._box = AuxTransformBox(ax.transData)
        self.arrow_fraction = arrow_fraction

        # White stroke around the black arrows and labels -- presumably to
        # keep them legible on dark image regions.
        path_effects = [PathEffects.withStroke(linewidth=3, foreground="w")]
        arrow_kwargs = dict(mutation_scale=14,
                            shrinkA=0,
                            shrinkB=7,
                            arrowstyle="-|>",
                            arrow_transmuter=None,
                            connectionstyle="arc3",
                            connector=None,
                            color="k",
                            path_effects=path_effects)
        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)

        text_kwargs = dict(rotation=0,
                           rotation_mode="anchor",
                           path_effects=path_effects,
                           va="center",
                           ha="center")
        # Placeholder positions; _update_arrow sets the real ones.
        self.txt1 = Text(1, 1, txt1, **text_kwargs)
        self.txt2 = Text(2, 2, txt2, **text_kwargs)

        self._box.add_artist(self.arrow1)
        self._box.add_artist(self.arrow2)
        self._box.add_artist(self.txt1)
        self._box.add_artist(self.txt2)

        AnchoredOffsetbox.__init__(self,
                                   loc,
                                   pad=pad,
                                   borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    @staticmethod
    def _tip(x0, y0, length, angle):
        """Return the point *length* away from (x0, y0) at *angle* degrees."""
        rad = np.deg2rad(angle)
        return x0 + length * np.cos(rad), y0 + length * np.sin(rad)

    def _update_arrow(self, renderer):
        """
        Re-anchor both arrows at (viewLim.x0, viewLim.y0) and place the
        labels at their tips.  *renderer* is unused.
        """
        ax = self._ax

        x0, y0 = ax.viewLim.x0, ax.viewLim.y0
        # E arrow at 180 - ori degrees; N arrow 90 degrees clockwise of it.
        a1, a2 = 180 - self.ori, 180 - self.ori - 90
        d = min(ax.viewLim.width, ax.viewLim.height) * self.arrow_fraction
        tip1 = self._tip(x0, y0, d, a1)
        tip2 = self._tip(x0, y0, d, a2)

        self.arrow1.set_positions((x0, y0), tip1)
        self.arrow2.set_positions((x0, y0), tip2)

        # Labels stay horizontal at the arrow tips.
        self.txt1.set_position(tip1)
        self.txt1.set_rotation(0)
        # NOTE(review): the +20 shift is in data units (the box uses
        # ax.transData), so its on-screen size depends on the data range --
        # confirm this is intended.
        self.txt2.set_position((tip2[0] + 20, tip2[1]))
        self.txt2.set_rotation(0)

    def draw(self, renderer):
        """Update the arrow geometry, then draw the anchored box."""
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
예제 #6
0
class AnchoredCompass(AnchoredOffsetbox):
    """
    Anchored box holding two arrows that point towards East and North.
    """
    def __init__(self,
                 ax,
                 transSky2Pix,
                 loc,
                 arrow_fraction=0.15,
                 txt1="E",
                 txt2="N",
                 delta_a1=0,
                 delta_a2=0,
                 pad=0.1,
                 borderpad=0.5,
                 prop=None,
                 frameon=False,
                 color=None):
        """
        Draw arrows pointing in the directions of E & N.

        ax : axes whose view limits give the arrow origin and length
        transSky2Pix : transform passed to estimate_angle_trans to find the
            directions (in degrees) of the two arrows
        loc : anchor location, passed to AnchoredOffsetbox
        arrow_fraction : length of the arrows as a fraction of the smaller
            of the axes' view-limit width and height
        txt1, txt2 : labels drawn at the tips of the first (E) and second
            (N) arrow
        delta_a1, delta_a2 : extra angles in degrees added to the
            estimated directions
        pad, borderpad : in fraction of the legend font size (or prop)
        prop : font properties, used for the labels and the anchored box
        color : color of the arrows and labels (None for the defaults)
        """

        self._ax = ax
        self._transSky2Pix = transSky2Pix
        self._box = AuxTransformBox(ax.transData)
        self.delta_a1, self.delta_a2 = delta_a1, delta_a2
        self.arrow_fraction = arrow_fraction

        arrow_kwargs = dict(mutation_scale=11,
                            shrinkA=0,
                            shrinkB=5,
                            arrowstyle="->",
                            arrow_transmuter=None,
                            connectionstyle="arc3",
                            connector=None,
                            color=color)

        self.arrow1 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)
        self.arrow2 = FancyArrowPatch(posA=(0, 0), posB=(1, 1),
                                      **arrow_kwargs)

        # Placeholder positions; _update_arrow sets the real ones.
        self.txt1 = Text(1,
                         1,
                         txt1,
                         rotation=0,
                         rotation_mode="anchor",
                         va="center",
                         ha="right",
                         color=color,
                         fontproperties=prop)
        self.txt2 = Text(1,
                         1,
                         txt2,
                         rotation=0,
                         rotation_mode="anchor",
                         va="bottom",
                         ha="center",
                         color=color,
                         fontproperties=prop)

        self._box.add_artist(self.arrow1)
        self._box.add_artist(self.arrow2)

        self._box.add_artist(self.txt1)
        self._box.add_artist(self.txt2)

        AnchoredOffsetbox.__init__(self,
                                   loc,
                                   pad=pad,
                                   borderpad=borderpad,
                                   child=self._box,
                                   prop=prop,
                                   frameon=frameon)

    def set_path_effects(self, path_effects):
        """Apply *path_effects* to both arrows and both labels."""
        for a in [self.arrow1, self.arrow2, self.txt1, self.txt2]:
            a.set_path_effects(path_effects)

    @staticmethod
    def _tip(x0, y0, length, angle):
        """Return the point *length* away from (x0, y0) at *angle* degrees."""
        rad = np.deg2rad(angle)
        return x0 + length * np.cos(rad), y0 + length * np.sin(rad)

    def _update_arrow(self, renderer):
        """
        Re-anchor both arrows at (viewLim.x0, viewLim.y0) and place the
        labels at their tips.  *renderer* is unused.
        """
        ax = self._ax

        x0, y0 = ax.viewLim.x0, ax.viewLim.y0
        a1, a2 = estimate_angle_trans(self._transSky2Pix, x0, y0)
        a1, a2 = a1 + self.delta_a1, a2 + self.delta_a2

        d = min(ax.viewLim.width, ax.viewLim.height) * self.arrow_fraction
        tip1 = self._tip(x0, y0, d, a1)
        tip2 = self._tip(x0, y0, d, a2)

        self.arrow1.set_positions((x0, y0), tip1)
        self.arrow2.set_positions((x0, y0), tip2)

        # Labels sit at the arrow tips, rotated along with the arrows.
        self.txt1.set_position(tip1)
        self.txt1.set_rotation(a1 - 180)
        self.txt2.set_position(tip2)
        self.txt2.set_rotation(a2 - 90)

    def draw(self, renderer):
        """Update the arrow geometry, then draw the anchored box."""
        self._update_arrow(renderer)
        super(AnchoredCompass, self).draw(renderer)
예제 #7
0
class CorrespondenceEditor(object):
    """Interactive editor for point correspondences drawn as arrows on an axes.

    Correspondence ``i`` is an arrow from a source point ``matches[0][i]`` to a
    target point ``matches[1][i]``, both in data coordinates of *ax*.

    Mouse interaction (only events whose ``inaxes`` is *ax* are handled):

    * moving the mouse selects the arrow closest to the pointer (within
      ``pick_radius`` data units) and the nearer of its two endpoints;
    * a left-button press starts dragging the selected endpoint, or, when no
      arrow is selected, creates a zero-length arrow and drags its target;
    * the scroll wheel zooms about the pointer.

    Keyboard shortcuts are listed in ``self.key_handlers``.  Completed edits
    are recorded for undo/redo.  The correspondences can be fitted with a
    ``PolynomialTransform`` (see :meth:`get_warp` / :meth:`get_inverse_warp`).
    """
    ax: plt.Axes
    canvas: plt.FigureCanvasBase

    def __init__(self, ax, on_commit=None,
                 facecolor=None, edgecolor=None,
                 pick_radius=15):
        """Attach the editor to *ax* and start listening to its canvas events.

        :param ax: matplotlib Axes to draw on; handlers are connected to
            ``ax.figure.canvas``.
        :param on_commit: optional callable, invoked as ``on_commit(editor)``
            by :meth:`commit` (space / enter keys).
        :param facecolor: arrow face colour, default RGBA ``(0, 0.5, 1.0, 1.0)``.
        :param edgecolor: arrow edge colour, same default as *facecolor*.
        :param pick_radius: maximum distance (data coordinates) at which moving
            the mouse selects an arrow; also used as the markers' pickradius.
        """
        self.arrowstyle = "-|>,head_width=2.5, head_length=5"
        self.facecolor = facecolor or (0, 0.5, 1.0, 1.0)
        self.edgecolor = edgecolor or (0, 0.5, 1.0, 1.0)
        self.pick_radius = pick_radius
        self.scroll_speed = -0.5

        # Polynomial order used when fitting the warp. Possible semantics:
        #  A number > 1      -- fit a polynomial of that order (capped at the number of matches)
        #  A number in [0,1] -- use that fraction of the number of matches as the order
        self.order = 0.5
        self._warp = None
        self._inverse_warp = None

        self.on_finish = on_commit
        self.ax = ax
        self.canvas = ax.figure.canvas

        # matches[0] holds the source points, matches[1] the corresponding targets
        self.matches = [[], []]

        self._active_point = -1  # 0 = source, 1 = target, anything else = none selected
        self._active_arrow = -1  # index into matches; out of range = none selected

        self._arrow_patches = []
        # Zero-length arrows are invisible, so a marker is also drawn at every tail
        self._arrow_tails = []

        self._arrow_highlight = FancyArrowPatch((0, 0), (0, 0), visible=False,
                                                arrowstyle=self.arrowstyle, color='y', linewidth=3)

        self._arrow_highlight_head = Line2D([0], [0], visible=False,
                                            marker='o', markersize=5, markerfacecolor='y', markeredgecolor='y',
                                            pickradius=self.pick_radius)

        self._arrow_highlight_tail = Line2D([0], [0], visible=False,
                                            marker='o', markersize=5, markerfacecolor='y', markeredgecolor='y',
                                            pickradius=self.pick_radius)

        self._point_highlight = Line2D([0], [0], visible=False,
                                       marker='o', markersize=5, markerfacecolor='r', markeredgecolor='r')

        self._arrow_highlight = self.ax.add_patch(self._arrow_highlight)
        self._arrow_highlight_head = self.ax.add_artist(self._arrow_highlight_head)
        self._arrow_highlight_tail = self.ax.add_artist(self._arrow_highlight_tail)

        self._point_highlight = self.ax.add_artist(self._point_highlight)

        self.cids = []

        self._history = []
        self._future = []
        self.save_state()  # Initialize history (to empty doc)

        self._event = None  # Most recent mouse/key event; supplies default positions
        self._mouse_down = None  # Set to the event that started a drag. None if not dragging.
        self._editing = 0  # Nesting depth of begin_editing()/finish_editing() for compound edits

        self.key_handlers = {
            ' ': self.commit,
            's': partial(self.set_active_point, 0),
            't': partial(self.set_active_point, 1),
            'n': partial(self.new_arrow, selected=True, drag_target=False),
            'delete': self.delete_arrow,
            'e': self.delete_arrow,
            'left': partial(self.nudge, dx=-1, dy=0),
            'right': partial(self.nudge, dx=1, dy=0),
            'up': partial(self.nudge, dx=0, dy=-1),
            'down': partial(self.nudge, dx=0, dy=1),
            'ctrl+z': self.undo,
            'ctrl+Z': self.redo,
            'enter': self.commit,
        }

        # Create the arrows
        self.refresh_arrows()

        self.connect_events()

    def __len__(self):
        """Return the number of correspondences (arrows)."""
        return len(self.matches[0])

    def __getitem__(self, i):
        """Return the ``(source, target)`` pair of arrow *i*."""
        return self.matches[0][i], self.matches[1][i]

    def __iter__(self):
        """Iterate over ``(source, target)`` pairs."""
        yield from zip(self.matches[0], self.matches[1])

    def begin_editing(self):
        """Call before changing the data (calls may be nested)."""
        self._editing += 1

    def finish_editing(self):
        """Call after you are done modifying the data.

        When the outermost edit finishes, the cached warps are invalidated,
        the state is saved to the undo history and the highlights are redrawn.
        """
        assert self._editing > 0
        self._editing -= 1
        if self._editing == 0:
            # The matches changed, so both cached warps are stale
            self._warp = None
            self._inverse_warp = None
            self.save_state()
            self.refresh_highlights()
            self.canvas.draw_idle()

    def get_state(self):
        """Return the live ``[sources, targets]`` lists (not a copy)."""
        return self.matches

    def set_state(self, state):
        """Replace the sources/targets with those in *state* and redraw the arrows."""
        s, t = state
        self.matches[0][:] = s
        self.matches[1][:] = t
        self.refresh_arrows()

    def save_state(self):
        """Push a copy of the current state onto the undo history and clear redo."""
        # Skip duplicates: this can be called several times without a change
        if self._history and self._history[-1] == self.get_state():
            return
        self._history.append(deepcopy(self.get_state()))
        del self._future[:]

    def undo(self):
        """Revert to the previous saved state; the current state goes to the redo stack.

        ``save_state`` runs after every completed edit, so the top of the
        history normally equals the current state. That entry is skipped so
        that undo really steps back. The oldest entry is never removed.
        """
        if not self._history:
            return
        current = deepcopy(self.get_state())
        if self._history[-1] == current:
            if len(self._history) < 2:
                return  # Already at the oldest saved state
            self._history.pop()
        self._future.append(current)
        # Copy so later edits cannot alias the stored history entry
        self.set_state(deepcopy(self._history[-1]))

    def redo(self):
        """Re-apply the most recently undone state and record it in the history."""
        if self._future:
            self.set_state(self._future.pop())
            if not self._history or self.get_state() != self._history[-1]:
                self._history.append(deepcopy(self.get_state()))

    def connect(self, event, handler):
        """Connect *handler* to canvas *event*, remembering the id for disconnecting."""
        self.cids.append(self.canvas.mpl_connect(event, handler))

    def connect_events(self):
        """Connect the mouse, key and scroll handlers to the canvas."""
        self.connect('motion_notify_event', self._on_motion)
        self.connect('button_press_event', self._on_button_press)
        self.connect('button_release_event', self._on_button_release)
        self.connect('key_press_event', self._on_key_press)
        self.connect('scroll_event', self._on_scroll)

    def disconnect_events(self):
        """Disconnect all event handlers."""
        for cid in self.cids:
            self.canvas.mpl_disconnect(cid)
        self.cids = []

    def ignore(self, event):
        """Return True for events that did not happen inside our axes."""
        return event.inaxes != self.ax

    def zoom(self, amount=1, xy=None):
        """Scale the view by ``2 ** amount`` about *xy* (default: the view centre).

        Positive *amount* zooms out, negative zooms in.
        """
        amount = 2 ** amount
        xmin, xmax = self.ax.get_xlim()
        ymin, ymax = self.ax.get_ylim()
        if xy is None:
            x = (xmin + xmax) / 2.
            y = (ymin + ymax) / 2.
        else:
            x, y = xy
        xmin = x + (xmin - x) * amount
        ymin = y + (ymin - y) * amount
        xmax = x + (xmax - x) * amount
        ymax = y + (ymax - y) * amount
        self.ax.set_xbound(xmin, xmax)
        self.ax.set_ybound(ymin, ymax)

    def _on_scroll(self, event: MouseEvent):
        """Zoom about the pointer; ignore scrolls outside our axes."""
        # Outside the axes xdata/ydata are None (or belong to another axes)
        if self.ignore(event):
            return
        self.zoom(event.step * self.scroll_speed, (event.xdata, event.ydata))
        self.canvas.draw_idle()

    def _on_motion(self, event: MouseEvent):
        """Drag the active point while the mouse is down, otherwise update the selection."""
        if self.ignore(event):
            return

        if self._mouse_down:
            self.set_point((event.xdata, event.ydata))
        else:
            self.select_arrow(event)

    def _distance_to_line_segment(self, p0, p1, xy, epsilon=1e-4):
        """Return the Euclidean distance from *xy* to the segment *p0*-*p1*.

        Segments with squared length <= *epsilon* are treated as the point *p0*.
        """
        p0 = np.array(p0)
        p1 = np.array(p1)
        xy = np.array(xy)
        v = p1 - p0
        v2 = v @ v
        if v2 > epsilon:
            q = (v @ (xy - p0)) / v2  # Normalised position of the projection along the segment
            q = np.clip(q, 0, 1)  # Clamp onto the segment
        else:
            q = 0
        p = p0 + v * q  # Closest point on the segment
        return np.linalg.norm(xy - p)

    def select_arrow(self, event: MouseEvent):
        """Select the arrow nearest the pointer (within ``pick_radius``) and its nearer endpoint."""
        min_distance = self.pick_radius
        argmin_distance = -1

        p = np.array((event.xdata, event.ydata))
        for i, (s, t) in enumerate(self):
            d = self._distance_to_line_segment(s, t, p)
            if d < min_distance:
                min_distance = d
                argmin_distance = i

        self.set_active_arrow(argmin_distance)

        if 0 <= argmin_distance:
            s, t = self[argmin_distance]
            d0 = np.linalg.norm(p - s)
            d1 = np.linalg.norm(p - t)
            self.set_active_point(0 if d0 < d1 else 1)

    def _start_drag(self, event: MouseEvent):
        """Begin a compound edit that moves the active point with the mouse."""
        self._mouse_down = event
        self.begin_editing()
        self.set_point((event.xdata, event.ydata))

    def _on_button_press(self, event: MouseEvent):
        """Left button: drag the selected arrow's active point, or create and drag a new arrow."""
        if self.ignore(event):
            return

        self._event = event

        if event.button != 1:
            return

        if self.has_active_arrow():
            # A mouse-up can be missed, so don't start a second nested drag
            if not self._mouse_down:
                self._start_drag(event)
        else:
            self.new_arrow(s=(event.xdata, event.ydata),
                           t=(event.xdata, event.ydata),
                           selected=True,
                           drag_target=True)

    def _on_button_release(self, event: MouseEvent):
        """Finish the drag (and its compound edit) started by a button press."""
        if self.ignore(event):
            return

        if self._mouse_down:
            self._mouse_down = None
            self.finish_editing()

    def _on_key_press(self, event: KeyEvent):
        """Dispatch the key to the matching entry in ``self.key_handlers``."""
        if self.ignore(event):
            return

        self._event = event
        handler = self.key_handlers.get(event.key)
        if handler is not None:
            handler()

    def refresh_arrows(self):
        """Rebuild every arrow patch and tail marker from ``self.matches``."""
        while self._arrow_patches:
            self._arrow_patches.pop().remove()

        # Every arrow also has a point marker at its tail
        while self._arrow_tails:
            self._arrow_tails.pop().remove()

        for s, t in zip(*self.matches):
            self._make_arrow_patch(s, t)

        self.refresh_highlights()
        self.canvas.draw_idle()

    def _set_highlight_arrow(self, s=None, t=None, visible=True):
        """Move (when *s* and *t* are given) and show/hide the highlight arrow and its end markers."""
        if s is not None and t is not None:
            self._arrow_highlight.set_positions(s, t)
            # Line2D.set_data requires sequences, not scalars
            self._arrow_highlight_tail.set_data([s[0]], [s[1]])
            self._arrow_highlight_head.set_data([t[0]], [t[1]])
        self._arrow_highlight.set_visible(visible)
        self._arrow_highlight_head.set_visible(visible)
        self._arrow_highlight_tail.set_visible(visible)

    def refresh_highlights(self):
        """Sync the arrow and point highlights with the current selection and redraw."""
        if self.has_active_arrow():
            self._set_highlight_arrow(self.matches[0][self._active_arrow],
                                      self.matches[1][self._active_arrow],
                                      visible=True)
        else:
            self._set_highlight_arrow(visible=False)

        if self.has_active_arrow() and self.has_active_point():
            x, y = self.matches[self._active_point][self._active_arrow]
            self._point_highlight.set_data([x], [y])
            self._point_highlight.set_visible(True)
        else:
            self._point_highlight.set_visible(False)

        self.canvas.draw_idle()

    def commit(self, _unused=None):
        """Invoke the ``on_commit`` callback (if any) with this editor."""
        if self.on_finish:
            self.on_finish(self)

    def get_active_arrow(self):
        """Return the selected arrow index (may be out of range when nothing is selected)."""
        return self._active_arrow

    def set_active_arrow(self, i):
        """Select arrow *i* and refresh the highlights if the selection changed."""
        if i != self._active_arrow:
            self._active_arrow = i
            self.refresh_highlights()

    def has_active_arrow(self):
        """Return True when the selected arrow index refers to an existing arrow."""
        return 0 <= self._active_arrow < len(self)

    def get_active_point(self):
        """Return the selected endpoint: 0 = source, 1 = target."""
        return self._active_point

    def has_active_point(self):
        """Return True when a source (0) or target (1) endpoint is selected."""
        return 0 <= self._active_point < 2

    def set_active_point(self, i):
        """Select endpoint *i* (0 = source, 1 = target) and refresh the highlights if it changed."""
        if i != self._active_point:
            self._active_point = i
            self.refresh_highlights()

    def _make_arrow_patch(self, s, t):
        """Add an arrow patch from *s* to *t* plus a tail marker, and remember both."""
        a = FancyArrowPatch(s, t,
                            arrowstyle=self.arrowstyle,
                            facecolor=self.facecolor,
                            edgecolor=self.edgecolor)
        a = self.ax.add_patch(a)
        self._arrow_patches.append(a)

        tail = Line2D([s[0]], [s[1]],
                      marker='o',
                      markersize=2.5,
                      markeredgecolor=self.edgecolor,
                      markerfacecolor=self.facecolor,
                      pickradius=self.pick_radius
                      )
        tail = self.ax.add_artist(tail)
        self._arrow_tails.append(tail)

    @mutator
    def new_arrow(self, s=None, t=None, selected=True, drag_target=False):
        """Append a new arrow from *s* to *t*.

        If both ends are omitted, the source is the position of the most recent
        event. A missing end is predicted from the other with the fitted warp.
        When there are no correspondences yet (nothing to fit), the missing end
        is set equal to the given one, giving a zero-length arrow.

        :param selected: make the new arrow active with its target selected.
        :param drag_target: start dragging the target with the latest event.
        """
        if s is None and t is None:
            s = (self._event.xdata, self._event.ydata)

        if len(self) == 0:
            # No matches to estimate a warp from yet
            if s is None:
                s = t
            if t is None:
                t = s
        elif s is None:
            s = self.predict_source(t)
        elif t is None:
            t = self.predict_target(s)

        self.matches[0].append(s)
        self.matches[1].append(t)

        self._make_arrow_patch(s, t)

        # Users expect the new arrow to be selected
        if selected:
            self.set_active_arrow(len(self) - 1)
            self.set_active_point(1)  # Presumably the tail was just placed; the head comes next

        if drag_target:
            self._start_drag(self._event)

    # noinspection PyUnusedLocal
    @mutator
    def set_point(self, xy, point=None, arrow=None):
        """Move endpoint *point* of arrow *arrow* (default: the active ones) to *xy*.

        Does nothing when the arrow index is out of range.
        """
        if arrow is None:
            arrow = self._active_arrow

        if point is None:
            point = self._active_point

        if not 0 <= arrow < len(self):
            return  # No arrows to select yet

        self.matches[point][arrow] = xy

        # Update the plot elements for the arrow
        self._arrow_patches[arrow].set_positions(self.matches[0][arrow], self.matches[1][arrow])

        # And also move the tail marker (Line2D.set_data requires sequences)
        x0, y0 = self.matches[0][arrow]
        self._arrow_tails[arrow].set_data([x0], [y0])

        # Update the highlights if we are moving the selected item
        if arrow == self._active_arrow:
            self.refresh_highlights()

    def ensure_selected_arrow(self):
        """Make sure a valid arrow is selected, creating one if there are none."""
        if len(self) == 0:
            self.new_arrow()

        # The index may be -1 or stale (e.g. after deleting the last arrow)
        if not self.has_active_arrow():
            self.set_active_arrow(len(self) - 1)

    def ensure_selected_point(self):
        """Make sure a valid arrow and one of its endpoints are selected (target by default)."""
        self.ensure_selected_arrow()
        if not self.has_active_point():
            self.set_active_point(1)

    @mutator
    def nudge(self, dx=0, dy=0):
        """Move the active point by (*dx*, *dy*) data units."""
        self.ensure_selected_point()
        x, y = self.matches[self._active_point][self._active_arrow]
        self.set_point((x + dx, y + dy))

    @mutator
    def delete_arrow(self, arrow=None):
        """Delete arrow *arrow* (default: the selected arrow; no-op when there are none)."""
        if arrow is None:
            if len(self) == 0:
                return  # Nothing to delete
            self.ensure_selected_arrow()
            arrow = self._active_arrow

        del self.matches[0][arrow]
        del self.matches[1][arrow]

        # Remove the patch and tail marker from the plot
        self._arrow_patches.pop(arrow).remove()
        self._arrow_tails.pop(arrow).remove()

        # Keep the active index on the same arrow. If the active arrow itself was
        # deleted, the index now refers to the next arrow (or is out of range,
        # meaning "no selection", if it was the last one).
        if self._active_arrow > arrow:
            self._active_arrow -= 1
        self.refresh_highlights()

    def get_sources(self):
        """Return the live list of source points."""
        return self.matches[0]

    def get_targets(self):
        """Return the live list of target points."""
        return self.matches[1]

    def get_transform_order(self):
        """Return the polynomial order for the warp, derived from ``self.order``."""
        if self.order > 1:
            order = min(self.order, len(self))
        else:
            order = round(self.order * len(self))
        return order

    def get_warp(self, recompute=False):
        """Return the (cached) source->target ``PolynomialTransform``."""
        if recompute:
            self._warp = None

        if self._warp is None:
            self._warp = PolynomialTransform()
            self._warp.estimate(np.array(self.get_sources()),
                                np.array(self.get_targets()),
                                self.get_transform_order())

        return self._warp

    def get_inverse_warp(self, recompute=False):
        """Return the (cached) target->source ``PolynomialTransform``."""
        if recompute:
            self._inverse_warp = None

        if self._inverse_warp is None:
            self._inverse_warp = PolynomialTransform()
            self._inverse_warp.estimate(np.array(self.get_targets()),
                                        np.array(self.get_sources()),
                                        self.get_transform_order())

        return self._inverse_warp

    def predict_target(self, s):
        """Map source point *s* through the warp; returns a tuple."""
        t = self.get_warp()(np.array([s]))[0]
        return tuple(t)

    def predict_source(self, t):
        """Map target point *t* through the inverse warp; returns a tuple."""
        s = self.get_inverse_warp()(np.array([t]))[0]
        return tuple(s)
예제 #8
0
class CartPole:
    def __init__(self):
        """Create a CartPole with zeroed time, control input and target position.

        Sets up the simulation state, timing bookkeeping, random-experiment
        options and drawing placeholders. It then calls
        ``init_graphical_elements()`` to create the drawing elements and
        ``set_controller('manual-stabilization')``.
        """

        # Global time of the simulation
        self.time = 0.0

        # CartPole is initialized with state, control input, target position all zero
        # This is however usually changed before running the simulation. Treat it just as placeholders.
        # Container for the augmented state (angle, position and their first and second derivatives) of the cart
        # NOTE(review): this binds the ``s0`` object itself, not a copy. update_state() and
        # cartpole_integration() mutate self.s in place, so s0 (and any other CartPole built
        # from it) changes as well - confirm this sharing is intended.
        self.s = s0  # (s like state)
        # Variables for control input and target position.
        self.u = 0.0  # Physical force acting on the cart
        self.Q = 0.0  # Dimensionless motor power in the range [-1,1] from which force is calculated with Q2u() method
        self.target_position = 0.0

        # Physical parameters of the CartPole (bound by reference, not copied)
        self.p = p_globals

        # region Time scales for simulation step, controller update and saving data
        # See last paragraph of "Time scales" section for explanations
        # ∆t in number of steps (related to simulation time step)
        # is set while setting corresponding dt through @property
        self.dt_controller_number_of_steps = 0
        self.dt_save_number_of_steps = 0

        # Counts time steps from last controller update or saving
        # is set while setting corresponding dt through @property
        self.dt_controller_steps_counter = 0
        self.dt_save_steps_counter = 0

        # Helper variables to set timescales
        self._dt_simulation = None
        self._dt_controller = None
        self._dt_save = None

        self.dt_simulation = None  # s, Update CartPole dynamical state every dt_simulation seconds
        self.dt_controller = None  # s, Recalculate control input every dt_controller_default seconds
        self.dt_save = None  # s,  Save CartPole state every dt_save_default seconds
        # endregion

        # region Variables controlling operation of the program - can be modified directly from CartPole environment
        self.rounding_decimals = 5  # Sets number of digits after the decimal point to save in experiment history for each feature
        self.save_data_in_cart = True  # Decides whether to store whole data of the experiment in dict_history or not
        self.stop_at_90 = False  # If true pole is blocked after reaching the horizontal position
        # endregion

        # region Variables controlling operation of the program - should not be modified directly
        self.save_flag = False  # Signals that the current time step should be saved
        self.csv_filepath = None  # Where to save the experiment history.
        self.controller = None  # Placeholder for the currently used controller function
        self.controller_name = ''  # Placeholder for the currently used controller name
        self.controller_idx = None  # Placeholder for the currently used controller index
        self.controller_names = self.get_available_controller_names(
        )  # list of controllers available in controllers folder
        # endregion

        # region Variables for generating experiments with random target trace
        # Parameters for random trace generation
        # These need to be set, before CartPole can generate random trace and random experiment
        self.track_relative_complexity = None  # randomly placed target points/s
        self.length_of_experiment = None  # seconds, length of the random length trace
        self.interpolation_type = None  # Sets how to interpolate between turning points of random trace
        # Possible choices: '0-derivative-smooth', 'linear', 'previous'
        # '0-derivative-smooth'
        #       -> turning points are connected with smooth interpolation curve having derivative = 0 at each turning p.
        # 'linear' -> turning points are connected with line segments
        # 'previous' -> between two turning points the value of the preceding point is kept constant.
        #                   In this last setting endpoint if set has no visible effect
        #                       (may however appear in the last line of the recording - TODO: not checked)
        self.turning_points_period = None  # How turning points should be distributed
        # Possible choices: 'regular', 'random'
        # Regular means that they are equidistant from each other
        # Random means we pick randomly points at time axis at which we place turning points
        # Where the target position of the random experiment starts and end:
        self.start_random_target_position_at = None
        self.end_random_target_position_at = None
        # Alternatively you can provide a list of target positions.
        # e.g. self.turning_points = [10.0, 0.0, 0.0]
        # If not None this variable has precedence -
        # track_relative_complexity, start/end_random_target_position_at_globals have no effect.
        self.turning_points = None

        self.random_track_f = None  # Function interpolating the random target position between turning points
        self.new_track_generated = False  # Flag informing that a new target position track is generated
        self.t_max_pre = None  # Placeholder for the end time of the generated random experiment
        self.number_of_timesteps_in_random_experiment = None
        self.use_pregenerated_target_position = False  # Informs method performing experiment
        #                                                    not to take target position from environment
        # endregion

        self.dict_history = {}  # Dictionary holding the experiment history

        # region Variables initialization for drawing/animating a CartPole
        # DIMENSIONS OF THE DRAWING ONLY!!!
        # NOTHING TO DO WITH THE SIMULATION AND NOT INTENDED TO BE MANIPULATED BY USER !!!

        # Variable relevant for interactive use of slider
        self.slider_max = 0.0
        self.slider_value = 0.0

        # Parameters needed to display CartPole in GUI
        # They are assigned with values in self.init_elements()
        self.CartLength = None
        self.WheelRadius = None
        self.WheelToMiddle = None
        self.y_plane = None
        self.y_wheel = None
        self.MastHight = None  # For drawing only. For calculation see L
        self.MastThickness = None
        self.HalfLength = None  # Half-length of the track: update_state() bounces the cart when |position| + WheelToMiddle exceeds it

        # Elements of the drawing
        self.Mast = None
        self.Chassis = None
        self.WheelLeft = None
        self.WheelRight = None

        # Arrow indicating acceleration (=motor power)
        self.Acceleration_Arrow = None

        self.y_acceleration_arrow = None
        self.scaling_dx_acceleration_arrow = None
        self.x_acceleration_arrow = None

        # Depending on mode, slider may be displayed either as bar or as an arrow
        self.Slider_Bar = None
        self.Slider_Arrow = None
        self.t2 = None  # An abstract container for the transform rotating the mast

        self.init_graphical_elements(
        )  # Assign proper object to the above variables
        # endregion

        # region Initialize CartPole in manual-stabilization mode
        self.set_controller('manual-stabilization')
        # endregion

    # region 1. Methods related to dynamic evolution of CartPole system

    # This method changes the internal state of the CartPole
    # from a state at time t to a state at t+dt
    # We assume this function is called for the first time to calculate first time step
    # @profile(precision=4)
    def update_state(self):
        """Advance the simulation by one time step of ``self.dt_simulation``.

        In order, this:

        1. advances ``self.time``;
        2. updates ``self.target_position``;
        3. integrates the state (:meth:`cartpole_integration`);
        4. optionally blocks the pole at +/-90 deg;
        5. wraps the angle to +/-pi;
        6. bounces the cart off the track ends;
        7. updates the control input (:meth:`Update_Q`, then ``u`` via ``Q2u``);
        8. recomputes the second derivatives with ``cartpole_ode``;
        9. every ``dt_save_number_of_steps`` calls, records the current step
           in ``self.dict_history``.

        When a pregenerated target track is used and ``self.time`` has reached
        ``self.t_max_pre``, only the time is advanced and the method returns
        early.
        """
        # Update the total time of the simulation
        self.time = self.time + self.dt_simulation

        # Update target position depending on the mode of operation
        if self.use_pregenerated_target_position:
            # Past the end of the pregenerated target track there is nothing to simulate
            if self.time >= self.t_max_pre:
                return
            self.target_position = self.random_track_f(self.time)
            self.slider_value = self.target_position  # Assign target position to slider to display it
        elif self.controller_name == 'manual-stabilization':
            # Target position is not used in this mode;
            # this just fills the corresponding history column with zeros
            self.target_position = 0.0
        else:
            self.target_position = self.slider_value  # Get target position from slider

        # Calculate the next state
        self.cartpole_integration()

        # Block the pole at +/- 90 deg if enabled
        zero_DD = None
        if self.stop_at_90:
            if self.s.angle >= np.pi / 2:
                self.s.angle = np.pi / 2
                self.s.angleD = 0.0
                zero_DD = True  # Zero the angular acceleration after it is recomputed below
            elif self.s.angle <= -np.pi / 2:
                self.s.angle = -np.pi / 2
                self.s.angleD = 0.0
                zero_DD = True  # Zero the angular acceleration after it is recomputed below
            else:
                zero_DD = False

        # Wrap angle to +/-π
        self.s.angle = wrap_angle_rad(self.s.angle)

        # If a wheel of the cart went beyond the track, bounce elastically off an
        # (invisible) border. Reflect only while moving outwards. Otherwise a cart
        # still beyond the border after a bounce would have its velocity flipped
        # back outwards on the next step and could stay stuck outside the track.
        if (abs(self.s.position) + self.WheelToMiddle) > self.HalfLength:
            if self.s.position * self.s.positionD > 0:
                self.s.positionD = -self.s.positionD

        # Determine the dimensionless [-1,1] value of the motor power Q
        self.Update_Q()

        # Convert dimensionless motor power to a physical force acting on the Cart
        self.u = Q2u(self.Q, self.p)

        # Update second derivatives
        self.s.angleDD, self.s.positionDD = cartpole_ode(
            self.p, self.s, self.u)

        if zero_DD:
            self.s.angleDD = 0.0

        # Count steps since the last save; save and reset when the interval elapsed
        self.dt_save_steps_counter += 1
        if self.dt_save_steps_counter == self.dt_save_number_of_steps:
            # One record per saved step. The target_position is not always
            # meaningful; if not, all values in its column are 0.
            record = {
                'time': self.time,
                's.position': self.s.position,
                's.positionD': self.s.positionD,
                's.positionDD': self.s.positionDD,
                's.angle': self.s.angle,
                's.angleD': self.s.angleD,
                's.angleDD': self.s.angleDD,
                'u': self.u,
                'Q': self.Q,
                'target_position': self.target_position,
                's.angle.sin': np.sin(self.s.angle),
                's.angle.cos': np.cos(self.s.angle),
            }
            if self.save_data_in_cart:
                # Accumulate the whole experiment history in the Cart instance
                for key, value in record.items():
                    self.dict_history[key].append(value)
            else:
                # Keep only the current step and signal that it should be saved
                self.dict_history = {key: [value] for key, value in record.items()}
                self.save_flag = True

            self.dt_save_steps_counter = 0

    # A method integrating the cartpole ode over time step dt
    # Currently we use a simple single step Euler stepping
    def cartpole_integration(self):
        """Simple single step integration of CartPole state by dt

        Takes state as SimpleNamespace, but returns as separate variables

        :param s: state of the CartPole (contains: s.position, s.positionD, s.angle and s.angleD)
        :param dt: time step by which the CartPole state should be integrated
        """

        self.s.position = self.s.position + self.s.positionD * self.dt_simulation
        self.s.positionD = self.s.positionD + self.s.positionDD * self.dt_simulation

        self.s.angle = self.s.angle + self.s.angleD * self.dt_simulation
        self.s.angleD = self.s.angleD + self.s.angleDD * self.dt_simulation

    # Determine the dimensionless [-1,1] value of the motor power Q
    # The function loads an external controller from PATH_TO_CONTROLLERS
    # This function should be called for the first time to calculate 0th time step
    # Otherwise it goes out of sync with saving
    def Update_Q(self):
        # Calculate time steps from last update
        # The counter should be initialized at max-1 to start with a control input update
        self.dt_controller_steps_counter += 1

        # If update time interval elapsed update control input and zero the counter
        if self.dt_controller_steps_counter == self.dt_controller_number_of_steps:

            if self.controller_name == 'manual-stabilization':
                # in this case slider corresponds already to the power of the motor
                self.Q = self.slider_value
            else:  # in this case slider gives a target position, lqr regulator
                self.Q = self.controller.step(self.s, self.target_position,
                                              self.time)

            self.dt_controller_steps_counter = 0

    # endregion

    # region 2. Methods related to experiment history as a whole: saving, loading, plotting, resetting

    # This method saves the dictionary keeping the history of simulation to a .csv file
    def save_history_csv(self,
                         csv_name=None,
                         mode='init',
                         length_of_experiment='unknown'):
        """Create, or append to, the .csv file holding the experiment history.

        Args:
            csv_name: Name of the file to create (used with ``mode='init'``).
                If None or empty, the name ``CP_<controller>_<timestamp>.csv``
                is generated. A missing ``.csv`` suffix is appended. If a file
                of that name already exists, an index (``-1``, ``-2``, ...) is
                added so that nothing is overwritten.
            mode: ``'init'`` creates the file in PATH_TO_EXPERIMENT_RECORDINGS
                and writes a commented header plus the column names.
                ``'save online'`` and ``'save offline'`` round
                ``self.dict_history`` to ``self.rounding_decimals`` (the
                rounded arrays replace the dict) and append its rows to
                ``self.csv_filepath``; both reset ``self.save_now`` to False.
            length_of_experiment: Value written into the header (``'init'``).

        Raises:
            ValueError: If ``mode`` is not one of the values above.
        """
        if mode == 'init':

            # Make folder to save data (if not yet existing)
            os.makedirs(PATH_TO_EXPERIMENT_RECORDINGS[:-1], exist_ok=True)

            # Set path where to save the data
            if not csv_name:
                self.csv_filepath = (PATH_TO_EXPERIMENT_RECORDINGS + 'CP_'
                                     + self.controller_name
                                     + datetime.now().strftime('_%Y-%m-%d_%H-%M-%S')
                                     + '.csv')
            else:
                self.csv_filepath = PATH_TO_EXPERIMENT_RECORDINGS + csv_name
                if not csv_name.endswith('.csv'):
                    self.csv_filepath += '.csv'

                # If such file exists, append an index to the end (do not overwrite):
                # name.csv -> name-1.csv -> name-2.csv -> ...
                stem = self.csv_filepath[:-4]
                candidate = self.csv_filepath
                index = 1
                while os.path.isfile(candidate):
                    candidate = '{}-{}.csv'.format(stem, index)
                    index += 1
                self.csv_filepath = candidate

            # Write the header. newline='' is required by the csv module;
            # without it every row is followed by a blank line on Windows.
            with open(self.csv_filepath, "a", newline='') as outfile:
                writer = csv.writer(outfile)

                # Single timestamp so date and time cannot straddle midnight
                now = datetime.now()
                writer.writerow([
                    '# This is CartPole experiment from {} at time {}'.format(
                        now.strftime('%d.%m.%Y'), now.strftime('%H:%M:%S'))
                ])

                # Record the code revision, if it can be determined
                try:
                    git_revision = Repo().head.object.hexsha
                except Exception:  # e.g. not run from inside a git repository
                    git_revision = 'unknown'
                writer.writerow(
                    ['# Done with git-revision: {}'.format(git_revision)])

                writer.writerow(['#'])
                writer.writerow([
                    '# Length of experiment: {} s'.format(
                        str(length_of_experiment))
                ])

                writer.writerow(['#'])
                writer.writerow(['# Time intervals dt:'])
                writer.writerow(
                    ['# Simulation: {} s'.format(str(self.dt_simulation))])
                writer.writerow([
                    '# Controller update: {} s'.format(str(self.dt_controller))
                ])
                writer.writerow(['# Saving: {} s'.format(str(self.dt_save))])

                writer.writerow(['#'])

                writer.writerow(
                    ['# Controller: {}'.format(self.controller_name)])

                writer.writerow(['#'])
                writer.writerow(['# Parameters:'])
                for k in self.p.__dict__:
                    writer.writerow(
                        ['# ' + k + ': ' + str(getattr(self.p, k))])
                writer.writerow(['#'])

                writer.writerow(['# Data:'])
                writer.writerow(self.dict_history.keys())

        elif mode in ('save online', 'save offline'):
            # Round data to a set precision and append the rows
            with open(self.csv_filepath, "a", newline='') as outfile:
                writer = csv.writer(outfile)
                self.dict_history = {
                    key: np.round(value, self.rounding_decimals)
                    for key, value in self.dict_history.items()
                }
                writer.writerows(zip(*self.dict_history.values()))
            self.save_now = False

        else:
            raise ValueError('Unknown mode: {}'.format(mode))

    # load csv file with experiment recording (e.g. for replay)
    def load_history_csv(self, csv_name=None):
        """Load an experiment recording (e.g. for replay) into a DataFrame.

        Args:
            csv_name: File name, with or without the ``.csv`` suffix. It is
                looked up first as given (relative to the current working
                directory), then inside PATH_TO_EXPERIMENT_RECORDINGS. If None
                or empty, the most recently created ``.csv`` file in
                PATH_TO_EXPERIMENT_RECORDINGS is loaded.

        Returns:
            The recording as a ``pandas.DataFrame`` (lines starting with ``#``
            are skipped), or False if no file could be found or read.
        """
        if not csv_name:
            # Get the latest file
            list_of_files = glob.glob(
                os.path.join(PATH_TO_EXPERIMENT_RECORDINGS, '*.csv'))
            try:
                file_path = max(list_of_files, key=os.path.getctime)
            except (ValueError, FileNotFoundError):  # no files / a file vanished
                print('Cannot load: No experiment recording found in data folder '
                      + PATH_TO_EXPERIMENT_RECORDINGS)
                return False
        else:
            filename = csv_name
            if not csv_name.endswith('.csv'):
                filename += '.csv'

            # Check if file found at local starting point or in PATH_TO_EXPERIMENT_RECORDINGS
            if os.path.isfile(filename):
                file_path = filename
            else:
                file_path = os.path.join(PATH_TO_EXPERIMENT_RECORDINGS,
                                         filename)
                if not os.path.isfile(file_path):
                    print(
                        'Cannot load: There is no experiment recording file with name {} at local folder or in {}'
                        .format(filename, PATH_TO_EXPERIMENT_RECORDINGS))
                    return False

        # Get race recording
        print('Loading file {}'.format(file_path))
        try:
            data: pd.DataFrame = pd.read_csv(
                file_path, comment='#')  # skip comment lines starting with #
        except Exception as e:
            print('Cannot load: Caught {} trying to read CSV file {}'.format(
                e, file_path))
            return False

        return data

    # Method plotting the dynamic evolution over time of the CartPole
    # It should be called after an experiment and only if experiment data was saved
    def summary_plots(self):
        """Plot angle, position, motor command and target position over time.

        Intended to be called after an experiment whose data is held in
        ``self.dict_history``. Shows the figure and returns ``(fig, axs)``.
        """
        label_size = 14
        tick_size = 12
        history = self.dict_history
        times = history['time']

        # Four stacked panels sharing the x axis, so zooming one zooms all
        fig, axs = plt.subplots(4, 1, figsize=(16, 9), sharex=True)
        ax_angle, ax_position, ax_motor, ax_target = axs

        # Pole angle, converted from radians to degrees
        ax_angle.set_ylabel("Angle (deg)", fontsize=label_size)
        ax_angle.plot(np.array(times),
                      np.array(history['s.angle']) * 180.0 / np.pi,
                      'b', markersize=12, label='Ground Truth')

        # Cart position
        ax_position.set_ylabel("position (m)", fontsize=label_size)
        ax_position.plot(times, history['s.position'],
                         'g', markersize=12, label='Ground Truth')

        # Motor input command
        ax_motor.set_ylabel("motor (N)", fontsize=label_size)
        ax_motor.plot(times, history['u'], 'r', markersize=12, label='motor')

        # Target position
        ax_target.set_ylabel("position target (m)", fontsize=label_size)
        ax_target.plot(times, history['target_position'], 'k')

        for ax in axs:
            ax.tick_params(axis='both', which='major', labelsize=tick_size)

        ax_target.set_xlabel('Time (s)', fontsize=label_size)

        fig.align_ylabels()

        plt.show()

        return fig, axs

    # endregion

    # region 3. Methods for generating random target position for generation of random experiment

    # Generates a random target position
    # in a form of a function interpolating between turning points
    def Generate_Random_Trace_Function(self):
        """Build ``self.random_track_f``, the target-position trace of a random experiment.

        The trace interpolates between turning points.

        - If ``self.turning_points`` is None or empty,
          ``floor(length_of_experiment * track_relative_complexity)`` random
          values in ``[-0.5, 0.5) * HalfLength`` are drawn. The first/last
          value can be pinned with ``start_random_target_position_at`` /
          ``end_random_target_position_at``. With zero points the trace is
          constantly 0.
        - Otherwise the given turning points (list, tuple or array) are used.

        The turning-point times span ``[0, t_max_pre]``. Interior times are
        random or evenly spaced, depending on ``turning_points_period``.
        ``interpolation_type`` selects the interpolation:
        '0-derivative-smooth', 'linear' or 'previous'. The resulting
        function is clipped to ``+-0.8 * HalfLength``. It accepts scalar or
        array times.

        Sets ``self.t_max_pre``, ``self.random_track_f`` and
        ``self.new_track_generated``.

        Raises:
            NotImplementedError: Unknown ``turning_points_period``.
            ValueError: Unknown ``interpolation_type``.
        """
        # len() instead of `== []` so lists, tuples and numpy arrays all work
        if self.turning_points is None or len(self.turning_points) == 0:

            number_of_turning_points = int(
                np.floor(self.length_of_experiment *
                         self.track_relative_complexity))

            # Uniform random values in [-0.5, 0.5) * HalfLength
            y = 2.0 * (np.random.random(number_of_turning_points) - 0.5)
            y = y * 0.5 * self.HalfLength

            if number_of_turning_points == 0:
                # No turning points: constant target at 0
                y = np.zeros(2)
            elif number_of_turning_points == 1:
                # One turning point: constant target, pinned to the requested
                # start (or, failing that, end) position if given
                if self.start_random_target_position_at is not None:
                    y[0] = self.start_random_target_position_at
                elif self.end_random_target_position_at is not None:
                    y[0] = self.end_random_target_position_at
                y = np.append(y, y[0])
            else:
                if self.start_random_target_position_at is not None:
                    y[0] = self.start_random_target_position_at
                if self.end_random_target_position_at is not None:
                    y[-1] = self.end_random_target_position_at

        else:
            number_of_turning_points = len(self.turning_points)
            if number_of_turning_points == 1:
                y = np.array([self.turning_points[0], self.turning_points[0]])
            else:
                y = np.array(self.turning_points)

        number_of_timesteps = np.ceil(self.length_of_experiment /
                                      self.dt_simulation)
        self.t_max_pre = number_of_timesteps * self.dt_simulation

        # Number of turning-point times strictly between the two end points
        random_samples = max(number_of_turning_points - 2, 0)

        if self.turning_points_period == 'random':
            t_init = np.sort(
                np.random.uniform(self.dt_simulation,
                                  self.t_max_pre - self.dt_simulation,
                                  random_samples))
            t_init = np.insert(t_init, 0, 0.0)
            t_init = np.append(t_init, self.t_max_pre)
        elif self.turning_points_period == 'regular':
            t_init = np.linspace(0,
                                 self.t_max_pre,
                                 num=random_samples + 2,
                                 endpoint=True)
        else:
            raise NotImplementedError(
                'There is no mode corresponding to this value of turning_points_period variable'
            )

        if self.interpolation_type == '0-derivative-smooth':
            # Piecewise polynomial with zero derivative at every turning point
            yder = [[y_i, 0] for y_i in y]
            random_track_f = BPoly.from_derivatives(t_init, yder)
        elif self.interpolation_type in ('linear', 'previous'):
            random_track_f = interp1d(t_init, y, kind=self.interpolation_type)
        else:
            raise ValueError('Unknown interpolation type.')

        # Truncate the target position to at most 80% of the track half-length
        def random_track_f_truncated(time):
            # HalfLength is read at call time, as before, so later changes apply
            limit = 0.8 * self.HalfLength
            return np.clip(random_track_f(time), -limit, limit)

        self.random_track_f = random_track_f_truncated

        self.new_track_generated = True

    # Prepare CartPole Instance to perform an experiment with random target position trace
    def setup_cartpole_random_experiment(
            self,

            # Initial state
            s0=None,
            controller=None,
            dt_simulation=None,
            dt_controller=None,
            dt_save=None,

            # Settings related to random trace generation
            track_relative_complexity=None,
            length_of_experiment=None,
            interpolation_type=None,
            turning_points_period=None,
            start_random_target_position_at=None,
            end_random_target_position_at=None,
            turning_points=None):
        """Prepare the CartPole instance for an experiment with a random target-position trace.

        Steps, in order:

        1. Store the time steps.
        2. Optionally set the controller.
        3. Store the initial state ``s0`` and the trace-generation settings.
        4. Generate the trace (Generate_Random_Trace_Function).
        5. Evaluate the target position at t = 0.
        6. Compute the initial control input (Update_Q).
        7. Reset the state and history to t = 0 via
           ``set_cartpole_state_at_t0(reset_mode=2, ...)``.

        Args:
            s0: Initial CartPole state.
            controller: Controller name passed to set_controller. If None, the
                current controller is kept (you may want to set it before
                calling this function so it is not reloaded every time).
            dt_simulation, dt_controller, dt_save: Time steps in seconds.
            track_relative_complexity, length_of_experiment,
            interpolation_type, turning_points_period,
            start_random_target_position_at, end_random_target_position_at,
            turning_points: Trace-generation settings, stored on the instance
                and used by Generate_Random_Trace_Function.
        """
        # Set time scales:
        self.dt_simulation = dt_simulation
        self.dt_controller = dt_controller
        self.dt_save = dt_save

        # Set CartPole in the right (automatic control) mode
        if controller is not None:
            self.set_controller(controller)

        # Set initial state
        self.s = s0

        self.track_relative_complexity = track_relative_complexity
        self.length_of_experiment = length_of_experiment
        self.interpolation_type = interpolation_type
        self.turning_points_period = turning_points_period
        self.start_random_target_position_at = start_random_target_position_at
        self.end_random_target_position_at = end_random_target_position_at
        self.turning_points = turning_points

        self.Generate_Random_Trace_Function()
        self.use_pregenerated_target_position = 1

        self.number_of_timesteps_in_random_experiment = int(
            np.ceil(self.length_of_experiment / self.dt_simulation))

        # The experiment starts at t = 0: evaluate the target there (and let the
        # controller update below see t = 0) instead of at whatever value
        # self.time currently holds.
        self.time = 0.0
        self.target_position = self.random_track_f(self.time)

        # Make already in the first timestep Q appropriate to the initial state, target position and controller
        self.Update_Q()

        self.set_cartpole_state_at_t0(reset_mode=2,
                                      s=self.s,
                                      Q=self.Q,
                                      target_position=self.target_position)

    # Runs a random experiment with parameters set with setup_cartpole_random_experiment
    # And saves the experiment recording to csv file
    # @profile(precision=4)
    def run_cartpole_random_experiment(self, csv=None, save_mode='offline'):
        """
        Run a random CartPole experiment prepared with setup_cartpole_random_experiment.

        The history is written to a .csv file via save_history_csv.

        - ``save_mode='offline'``: the data are written after the run and
          summary plots are shown.
        - ``save_mode='online'``: the t = 0 data are written right away, and
          further data whenever ``self.save_flag`` is set during the run.

        The loop stops early if ``abs(self.s.position)`` exceeds 45.0.

        Args:
            csv: Name of the .csv file (None or '' -> auto-generated). Note
                that this parameter shadows the ``csv`` module inside this
                method.
            save_mode: 'offline' or 'online'.

        Returns:
            pandas.DataFrame built from ``self.dict_history`` right after the
            simulation loop.

        Raises:
            ValueError: Unknown ``save_mode``.
            Exception: If the simulated time exceeds ``self.t_max_pre``.
        """
        if save_mode == 'offline':
            self.save_data_in_cart = True
        elif save_mode == 'online':
            self.save_data_in_cart = False
        else:
            raise ValueError('Unknown save mode value')

        # Create csv file for saving
        self.save_history_csv(csv_name=csv,
                              mode='init',
                              length_of_experiment=self.length_of_experiment)

        # Save 0th timestep
        if save_mode == 'online':
            self.save_history_csv(csv_name=csv, mode='save online')

        # Run the CartPole experiment for the prescribed number of time steps
        for _ in trange(self.number_of_timesteps_in_random_experiment):

            # Raise an error if it runs already too long (should stop before)
            if self.time > self.t_max_pre:
                raise Exception(
                    'ERROR: It seems the experiment is running too long...')

            self.update_state()

            # Additional option to stop the experiment; report before leaving
            # the loop so the message is actually printed
            if abs(self.s.position) > 45.0:
                print('Cart went out of safety boundaries')
                break

            if save_mode == 'online' and self.save_flag:
                self.save_history_csv(csv_name=csv, mode='save online')
                self.save_flag = False

        data = pd.DataFrame(self.dict_history)

        if save_mode == 'offline':
            self.save_history_csv(csv_name=csv, mode='save offline')
            self.summary_plots()

        # Set CartPole state - the only use is to make sure that experiment history is discarded
        # Maybe you can delete this line
        self.set_cartpole_state_at_t0(reset_mode=0)

        return data

    # endregion

    # region 4. Methods "Get, set, reset"

    # Method returns the list of controllers available in the PATH_TO_CONTROLLERS folder
    def get_available_controller_names(self):
        """
        Return the names of the controllers available in the PATH_TO_CONTROLLERS folder.

        Every file ``controller_<name>.py`` found there contributes ``<name>``,
        with underscores replaced by dashes. The names are sorted and preceded
        by the always-present 'manual-stabilization'.
        """
        prefix, suffix = 'controller_', '.py'

        found = []
        for path in glob.glob(PATH_TO_CONTROLLERS + prefix + '*' + suffix):
            stem = os.path.basename(path)[len(prefix):-len(suffix)]
            found.append(stem.replace('_', '-'))

        return ['manual-stabilization'] + list(np.sort(found))

    # Set the controller of CartPole
    def set_controller(self, controller_name=None, controller_idx=None):
        """
        Set a new controller as the current controller of the CartPole instance.

        The controller is given either by its name or by its index in
        ``self.controller_names`` (see get_available_controller_names).
        Exactly one of the two must be provided.

        - 'manual-stabilization': ``self.controller`` is set to None.
        - Any other name: the class ``controller_<name>`` (dashes replaced by
          underscores) is imported from the module of the same name inside
          PATH_TO_CONTROLLERS and instantiated without arguments.

        The GUI slider range is updated accordingly.

        Raises:
            ValueError: If neither or both arguments are given, or the name is
                not in ``self.controller_names``.
        """
        import importlib

        # Check if the proper information was provided: either controller_name or controller_idx
        if controller_name is None and controller_idx is None:
            raise ValueError(
                'You have to specify either controller_name or controller_idx to set a new controller. '
                'You have specified none of the two.')
        if controller_name is not None and controller_idx is not None:
            raise ValueError(
                'You have to specify either controller_name or controller_idx to set a new controller. '
                'You have specified both.')

        # If controller name provided get controller index and vice versa
        if controller_name is not None:
            try:
                controller_idx = self.controller_names.index(controller_name)
            except ValueError:
                raise ValueError(
                    '{} is not in list. \n In list are: {}'.format(
                        controller_name, self.controller_names)) from None
        else:
            controller_name = self.controller_names[controller_idx]

        # save controller name and index to variables in the CartPole namespace
        self.controller_name = controller_name
        self.controller_idx = controller_idx

        # Load controller
        if self.controller_name == 'manual-stabilization':
            self.controller = None
        else:
            controller_full_name = 'controller_' + self.controller_name.replace(
                '-', '_')
            # Folder path such as './Controllers/' -> dotted package prefix
            # 'Controllers.'; both '/' and '\' separators are converted.
            path_import = PATH_TO_CONTROLLERS[2:].replace('/', '.').replace(
                '\\', '.')
            # importlib instead of exec()/eval(): exec() cannot reliably bind
            # names in a function's local scope (see PEP 667, Python 3.13).
            module = importlib.import_module(path_import + controller_full_name)
            self.controller = getattr(module, controller_full_name)()

        # Set the maximal allowed value of the slider - relevant only for GUI
        if self.controller_name == 'manual-stabilization':
            self.slider_max = 1.0
            self.Slider_Arrow.set_positions((0, 0), (0, 0))
        else:
            self.slider_max = self.p.TrackHalfLength
            self.Slider_Bar.set_width(0.0)

    # This method resets the internal state of the CartPole instance
    # The starting state (for t = 0) may be
    # all zeros (reset_mode = 0)
    # set in this function (reset_mode = 1)
    # provided by the user (reset_mode = 2) by giving s, Q and target_position
    def set_cartpole_state_at_t0(self,
                                 reset_mode=1,
                                 s=None,
                                 Q=None,
                                 target_position=None):
        """Reset the time to 0, set the initial CartPole state and restart the history.

        Args:
            reset_mode:
                0 - position, angle, their first and second derivatives, Q,
                    u and the target position are all set to 0.
                1 - user-editable preset: cart at position 0 with zero
                    velocities, a random angle (drawn in degrees, converted to
                    radians), target position taken from ``self.slider_value``
                    and Q = 0. u is then computed with Q2u, and the second
                    derivatives with cartpole_ode.
                2 - the state ``s``, ``Q`` and ``target_position`` are provided
                    by the caller. u and the second derivatives are computed
                    from them with Q2u and cartpole_ode.
                Any other value only resets the time and the history; the
                current state is kept.
            s: Initial state object (reset_mode 2). It is stored as ``self.s``
                without copying, and its ``angleDD`` / ``positionDD``
                attributes are overwritten.
            Q: Initial control input (reset_mode 2).
            target_position: Initial target position (reset_mode 2).

        Raises:
            ValueError: reset_mode 2 with any of s, Q or target_position None.

        Afterwards ``self.dict_history`` holds a single entry per key: the
        t = 0 time, state, u, Q, target position and sin/cos of the angle.
        """
        self.time = 0.0
        if reset_mode == 0:  # Don't change it
            self.s.position = self.s.positionD = self.s.positionDD = 0.0
            self.s.angle = self.s.angleD = self.s.angleDD = 0.0
            self.Q = self.u = 0.0
            # NOTE(review): the slider value is read from self.slider_value in
            # mode 1 and elsewhere; confirm that writing self.slider is intended.
            self.slider = self.target_position = 0.0

        elif reset_mode == 1:  # You may change this but be careful with other users. Better use 2
            # You can change here with which initial parameters you wish to start the simulation
            self.s.position = 0.0
            self.s.positionD = 0.0
            # NOTE(review): (2*r - 1) is the usual mapping of a uniform r in
            # [0, 1) to [-1, 1); with np.random.normal() the angle is instead
            # normally distributed around -1 degree. Confirm which was intended.
            self.s.angle = (2.0 * np.random.normal() -
                            1.0) * np.pi / 180.0  # np.pi/2.0 #
            self.s.angleD = 0.0  # 1.0
            self.target_position = self.slider_value

            self.Q = 0.0

            self.u = Q2u(self.Q, self.p)
            self.s.angleDD, self.s.positionDD = cartpole_ode(
                self.p, self.s, self.u)

        elif reset_mode == 2:  # Don't change it
            if (s is not None) and (Q is not None) and (target_position
                                                        is not None):
                self.s = s
                self.Q = Q
                self.slider = self.target_position = target_position

                self.u = Q2u(self.Q, self.p)  # Calculate CURRENT control input
                self.s.angleDD, self.s.positionDD = cartpole_ode(
                    self.p, self.s,
                    self.u)  # Calculate CURRENT second derivatives
            else:
                raise ValueError(
                    's, Q or target position not provided for initial state')

        # Reset the dict keeping the experiment history and save the state for t = 0
        self.dict_history = {
            'time': [self.time],
            's.position': [self.s.position],
            's.positionD': [self.s.positionD],
            's.positionDD': [self.s.positionDD],
            's.angle': [self.s.angle],
            's.angleD': [self.s.angleD],
            's.angleDD': [self.s.angleDD],
            'u': [self.u],
            'Q': [self.Q],
            'target_position': [self.target_position],
            's.angle.sin': [np.sin(self.s.angle)],
            's.angle.cos': [np.cos(self.s.angle)]
        }

    # region Get and set timescales

    # Makes sure that when dt is updated also related variables are updated

    @property
    def dt_simulation(self):
        """Simulation time step in seconds."""
        return self._dt_simulation

    @dt_simulation.setter
    def dt_simulation(self, value):
        """Store the simulation step and recompute the dependent step counts.

        Each count is ``rint(dt_other / dt_simulation)`` as an int32, with 0
        replaced by 1. The controller counter is initialized at its maximal
        value (count - 1) so that the run starts with an update; the saving
        counter starts at 0.
        """
        self._dt_simulation = value

        if self._dt_controller is not None:
            steps = np.rint(self._dt_controller / value).astype(np.int32)
            self.dt_controller_number_of_steps = 1 if steps == 0 else steps
            self.dt_controller_steps_counter = self.dt_controller_number_of_steps - 1

        if self._dt_save is not None:
            steps = np.rint(self._dt_save / value).astype(np.int32)
            self.dt_save_number_of_steps = 1 if steps == 0 else steps
            self.dt_save_steps_counter = 0

    @property
    def dt_controller(self):
        """Controller update interval in seconds."""
        return self._dt_controller

    @dt_controller.setter
    def dt_controller(self, value):
        """Store the controller interval and recompute how many simulation steps it spans."""
        self._dt_controller = value
        if self._dt_simulation is None:
            return

        steps = np.rint(value / self._dt_simulation).astype(np.int32)
        if steps == 0:
            steps = 1
        self.dt_controller_number_of_steps = steps
        # Counter starts at its maximal value so that the run starts with an update
        self.dt_controller_steps_counter = steps - 1

    @property
    def dt_save(self):
        """Interval between saved samples in seconds."""
        return self._dt_save

    @dt_save.setter
    def dt_save(self, value):
        """Store the saving interval and recompute how many simulation steps it spans."""
        self._dt_save = value
        if self._dt_simulation is None:
            return

        steps = np.rint(value / self._dt_simulation).astype(np.int32)
        if steps == 0:
            steps = 1
        self.dt_save_number_of_steps = steps
        # Counter starts at 0 - the 0th step is saved manually
        self.dt_save_steps_counter = 0

    # endregion

    # endregion

    # region 5. Methods needed to display CartPole in GUI
    """
    This section contains methods related to displaying the CartPole in the GUI of the simulator.
    One could consider moving these functions outside of the CartPole class and coupling them more tightly
    with the GUI of the simulator.
    We nevertheless keep them as part of the CartPole class because they rely on CartPole variables.
    """

    # This method initializes CartPole elements to be plotted in CartPole GUI
    def init_graphical_elements(self):
        """Create the drawing geometry and the matplotlib patches of the CartPole.

        Reads ``self.s.position``, ``self.Q`` and ``self.slider_value``. The
        patches are created here but not added to any axes.
        """
        # Geometry of the drawing
        self.CartLength = 10.0
        self.WheelRadius = 0.5
        self.WheelToMiddle = 4.0
        self.y_plane = 0.0
        self.y_wheel = self.y_plane + self.WheelRadius
        self.MastHight = 10.0  # For drawing only. For calculation see L
        self.MastThickness = 0.05
        self.HalfLength = 50.0  # Half of the track length

        # Acceleration arrow: starts at the cart, length proportional to Q
        self.y_acceleration_arrow = 1.5 * self.WheelRadius
        self.scaling_dx_acceleration_arrow = 20.0
        self.x_acceleration_arrow = (self.s.position
                                     + self.scaling_dx_acceleration_arrow * self.Q)

        cart_x = self.s.position
        mast_base_y = 1.25 * self.WheelRadius

        self.Mast = FancyBboxPatch(
            xy=(cart_x - self.MastThickness / 2.0, mast_base_y),
            width=self.MastThickness,
            height=self.MastHight,
            fc='g')

        self.Chassis = FancyBboxPatch(
            (cart_x - self.CartLength / 2.0, self.WheelRadius),
            self.CartLength,
            1 * self.WheelRadius,
            fc='r')

        def make_wheel(center_x):
            # Yellow wheel with a thick black rim
            return Circle((center_x, self.y_wheel),
                          radius=self.WheelRadius,
                          fc='y',
                          ec='k',
                          lw=5)

        self.WheelLeft = make_wheel(cart_x - self.WheelToMiddle)
        self.WheelRight = make_wheel(cart_x + self.WheelToMiddle)

        arrow_y = self.y_acceleration_arrow
        self.Acceleration_Arrow = FancyArrowPatch(
            (cart_x, arrow_y),
            (self.x_acceleration_arrow, arrow_y),
            arrowstyle='simple',
            mutation_scale=10,
            facecolor='gold',
            edgecolor='orange')

        slider_x = self.slider_value
        self.Slider_Arrow = FancyArrowPatch((slider_x, 0),
                                            (slider_x, 0),
                                            arrowstyle='fancy',
                                            mutation_scale=50)
        self.Slider_Bar = Rectangle((0.0, 0.0), slider_x, 1.0)

        # An abstract container for the transform rotating the mast
        self.t2 = transforms.Affine2D().rotate(0.0)

    # This method accepts the mouse position and updates the slider value accordingly
    # The mouse position has to be captured by a function not included in this class
    def update_slider(self, mouse_position):
        """Set ``self.slider_value`` to ``mouse_position`` saturated to ``+-self.slider_max``."""
        upper = self.slider_max
        lower = -self.slider_max
        if mouse_position > upper:
            clamped = upper
        elif mouse_position < lower:
            clamped = lower
        else:
            clamped = mouse_position
        self.slider_value = clamped

    # This method draws elements and set properties of the CartPole figure
    # which do not change at every frame of the animation
    def draw_constant_elements(self, fig, AxCart, AxSlider):
        """Draw the parts of the CartPole figure that stay fixed during the animation.

        Clears both axes, draws the track, and sets limits, y-tick removal and
        scaling. Returns ``(fig, AxCart, AxSlider)``.
        """
        # Delete all elements of the Figure
        for ax in (AxCart, AxSlider):
            ax.clear()

        # Upper chart: cart picture
        cart_x_extent = self.HalfLength * 1.1
        AxCart.set_xlim((-cart_x_extent, cart_x_extent))
        AxCart.set_ylim((-1.0, 15.0))
        # NullLocator disables the ticks on the y axis
        AxCart.yaxis.set_major_locator(plt.NullLocator())

        # Track
        AxCart.add_patch(Rectangle((-self.HalfLength, -1.0),
                                   2 * self.HalfLength,
                                   1.0,
                                   fc='brown'))

        # Invisible (white) point above the mast, so that the axes are drawn
        # high enough for the mast
        AxCart.add_patch(Rectangle((0, self.MastHight + 2.0),
                                   self.MastThickness,
                                   0.0001,
                                   fc='w',
                                   ec='w'))
        AxCart.axis('scaled')

        # Lower chart: slider, with x limits set from the slider range
        slider_x_extent = self.slider_max * 1.1
        AxSlider.set(xlim=(-slider_x_extent, slider_x_extent))
        AxSlider.yaxis.set_major_locator(plt.NullLocator())
        AxSlider.set_aspect("auto")

        return fig, AxCart, AxSlider

    # This method updates the elements of the Cart Figure which change at every frame.
    # Note that these elements are not plotted directly by this method
    # but rather returned as objects which can be used by another function
    # e.g. animation function from matplotlib package
    def update_drawing(self):
        """Move the changing CartPole patches to match the current state.

        Updates the acceleration arrow, the mast position and rotation
        transform (``self.t2``), the chassis, the wheels and the slider.
        Returns ``(Mast, t2, Chassis, WheelRight, WheelLeft, Slider_Bar,
        Slider_Arrow, Acceleration_Arrow)``; drawing them is left to the caller.

        Raises:
            ValueError: If ANGLE_CONVENTION is neither 'CLOCK-NEG' nor 'CLOCK-POS'.
        """
        cart_x = self.s.position

        # Acceleration arrow, with length proportional to Q
        self.x_acceleration_arrow = cart_x + self.scaling_dx_acceleration_arrow * self.Q
        arrow_y = self.y_acceleration_arrow
        self.Acceleration_Arrow.set_positions((cart_x, arrow_y),
                                              (self.x_acceleration_arrow, arrow_y))

        # Mast: shift it, then rotate it about its base point
        mast_x = cart_x - self.MastThickness / 2.0
        mast_y = 1.25 * self.WheelRadius
        self.Mast.set_x(mast_x)
        if ANGLE_CONVENTION == 'CLOCK-NEG':
            theta = self.s.angle
        elif ANGLE_CONVENTION == 'CLOCK-POS':
            theta = -self.s.angle
        else:
            raise ValueError('Unknown angle convention')
        self.t2 = (transforms.Affine2D().translate(-mast_x, -mast_y)
                   + transforms.Affine2D().rotate(theta)
                   + transforms.Affine2D().translate(mast_x, mast_y))

        # Chassis and wheels follow the cart
        self.Chassis.set_x(cart_x - self.CartLength / 2.0)
        self.WheelLeft.center = (cart_x - self.WheelToMiddle, self.y_wheel)
        self.WheelRight.center = (cart_x + self.WheelToMiddle, self.y_wheel)

        # Slider: a bar for manual stabilization, an arrow otherwise
        slider_x = self.slider_value
        if self.controller_name == 'manual-stabilization':
            self.Slider_Bar.set_width(slider_x)
        else:
            self.Slider_Arrow.set_positions((slider_x, 0), (slider_x, 1.0))

        return (self.Mast, self.t2, self.Chassis, self.WheelRight, self.WheelLeft,
                self.Slider_Bar, self.Slider_Arrow, self.Acceleration_Arrow)

    # A function redrawing the changing elements of the Figure
    def run_animation(self, fig):
        """Create the FuncAnimation that redraws the changing CartPole elements.

        ``fig`` must provide the axes attributes ``AxCart`` and ``AxSlider``.
        Returns the ``animation.FuncAnimation`` object (300 frames, 10 ms
        interval, blitting, repeating).
        """
        def moving_artists():
            # Artists redrawn on every frame (order matters for blitting output)
            return (self.Mast, self.Chassis, self.WheelLeft, self.WheelRight,
                    self.Slider_Bar, self.Slider_Arrow, self.Acceleration_Arrow)

        def init_frame():
            # Adding variable elements to the Figure
            for patch in (self.Mast, self.Chassis, self.WheelLeft,
                          self.WheelRight, self.Acceleration_Arrow):
                fig.AxCart.add_patch(patch)
            for patch in (self.Slider_Bar, self.Slider_Arrow):
                fig.AxSlider.add_patch(patch)
            return moving_artists()

        def update_frame(frame_index):
            self.update_drawing()
            # The mast rotation must be composed with the data transform of the axes
            self.t2 = self.t2 + fig.AxCart.transData
            self.Mast.set_transform(self.t2)
            return moving_artists()

        return animation.FuncAnimation(fig,
                                       update_frame,
                                       init_func=init_frame,
                                       frames=300,
                                       interval=10,
                                       blit=True,
                                       repeat=True)
class EnvironmentReader(Reader):
    """
    Animated view of the recorded environment log.

    :meth:`read` pulls the log rows that belong to the next iteration and
    :meth:`draw` blits the path travelled so far, the velocity arrow, the
    direction of travel and the chased target onto the axes.
    """

    def __init__(self, vl, room_shape, start_iter):
        """
        :param vl: log indexed by column name; this class reads the
            'Iter num', 'xs', 'ys', 'vxs', 'vys', 'txs' and 'tys' columns.
        :param room_shape: ((xmin, xmax), (ymin, ymax)) extent of the room,
            used for the axis limits and for the room centre.
        :param start_iter: iteration number the first :meth:`read` looks up.
        """
        # NOTE(review): ``pos`` is neither a parameter nor a local here; it must
        # be a module-level name defined elsewhere in the file -- confirm.
        super(EnvironmentReader, self).__init__(win_size=[700, 700],
                                                win_loc=pos,
                                                title='Environment')
        self.vl = vl
        self.cur_iter = start_iter

        # index of the first log row that read() has not returned yet
        self.cur_i = 0
        self.max_iter = np.max(vl['Iter num'])
        self.maxx = room_shape[0][1]
        self.maxy = room_shape[1][1]
        self.cntr_x = np.mean(room_shape[0])
        self.cntr_y = np.mean(room_shape[1])

        # every position drawn so far, used for the path line
        self.x_hist = []
        self.y_hist = []

        self.canvas.mpl_connect('resize_event', lambda x: self.update_background())

        self.pos, = self.ax.plot([], [], color='g', animated=True)
        self.vel = Arrow([0, 0], [1, 0], arrowstyle='-|>, head_length=3, head_width=3',
                         animated=True, linewidth=4)
        self.ax.add_patch(self.vel)
        #self.radius, = self.ax.plot([],[],color='r',animated=True)
        self.clockcounter = self.ax.text(self.maxx, room_shape[1][0],
                                         '', ha='right', size='large',
                                         animated=True)
        self.iters = self.ax.text(self.maxx - 1, room_shape[1][0] + 3,
                                  '', ha='right', animated=True)
        self.target = Circle([0, 0], 0, animated=True, color='r')
        self.target.set_radius(11)
        self.ax.add_artist(self.target)

        self.ax.set_xlim(room_shape[0])
        self.ax.set_ylim(room_shape[1])

        self.draw_plc_flds()

        self.draw_HMM_sectors()
        self.predicted_counter = self.ax.text(self.maxx, room_shape[1][0] + 10, '',
                                              ha='right', size='large', animated=True)
        self.prediction_hist = None
        #self.cur_sec = Rectangle([0,0],self.HMM.dx,self.HMM.dy,fill=True,
        #                         color='k',animated=True)
        #self.ax_env.add_patch(self.cur_sec)

        self.canvas.draw()

    def read(self, iteration=None):
        ''' Return the environment data corresponding to
            the next iteration.

            Note that 0 or multiple data points can be
            associated with a single iteration number.

            All of the environment data before self.cur_i
            has already been recorded.

            :param iteration: if given, read this iteration instead of
                self.cur_iter.
            :return: (xs, ys, vxs, vys) slices of the log rows up to and
                including the last row of the iteration, or a list of four
                NaNs when no remaining row has that iteration number.
                self.cur_iter is advanced by one in both cases.'''
        if iteration is not None:
            self.cur_iter = iteration
        try:
            last_match = np.nonzero(self.vl['Iter num'][self.cur_i:] == self.cur_iter)[0][-1]
        except IndexError:
            # There is no remaining row with that iteration number
            self.cur_iter += 1
            return [np.nan, np.nan, np.nan, np.nan]
        cur_j = 1 + self.cur_i + last_match
        cur_x = self.vl['xs'][self.cur_i:cur_j]
        cur_y = self.vl['ys'][self.cur_i:cur_j]
        cur_vx = self.vl['vxs'][self.cur_i:cur_j]
        cur_vy = self.vl['vys'][self.cur_i:cur_j]
        self.cur_iter += 1
        self.cur_i = cur_j

        return (cur_x, cur_y, cur_vx, cur_vy)

    def draw_HMM_sectors(self):
        """Draw the HMM sector grid -- currently disabled.

        The early ``return`` switches this off; the body below refers to
        ``self.HMM`` and ``self.ax_env``, which this class never sets.
        """
        return
        for i in range(self.HMM.cll_p_sd):
            curx = self.HMM.xrange[0] + i * self.HMM.dx
            cury = self.HMM.yrange[0] + i * self.HMM.dy
            self.ax_env.plot([curx, curx], [0, self.maxy], 'k')
            self.ax_env.plot([0, self.maxx], [cury, cury], 'k')

    def draw_plc_flds(self):
        """Draw the place fields of each context -- currently disabled.

        The early ``return`` switches this off; the body below refers to
        ``self.WR`` and ``self.ax_env``, which this class never sets.
        """
        return
        clrs = ['b', 'g']
        legend_widgets = []
        for i, plc_flds in enumerate(self.WR.contexts.values()):
            added = False
            for plc_fld in plc_flds:
                circ = Circle([plc_fld.x, plc_fld.y], plc_fld.r, color=clrs[i])
                self.ax_env.add_patch(circ)

                # one legend entry per context
                if not added:
                    legend_widgets.append(circ)
                    added = True
        self.ax.legend(legend_widgets, ('counterclockwise', 'clockwise'), loc='lower right')

    def draw(self, xs, ys, vxs, vys):
        ''' Display the environment data to the window.

            The environment data are inputs, normally the output of
            :meth:`read`. Nothing is drawn when any x position is NaN, which
            is what read() returns for an iteration without data. '''

        if np.any(np.isnan(xs)):
            return

        # Display how far along in the animation we are
        self.iters.set_text('%i/%i' % (self.cur_iter, self.max_iter))

        # Begin the drawing process
        self.canvas.restore_region(self.background)

        # Calculate and draw rat's physical position (latest sample of the batch)
        x = xs[-1]
        y = ys[-1]
        vx = vxs[-1]
        vy = vys[-1]
        self.x_hist.extend(xs)
        self.y_hist.extend(ys)
        self.pos.set_data(self.x_hist, self.y_hist)

        # Adjust velocity vector; a zero velocity keeps the previous arrow
        if vx != 0 or vy != 0:
            self.vel.set_positions([x, y], [x + vx, y + vy])

        # Adjust radius line
        #self.radius.set_data([self.cntr_x,x],[self.cntr_y,y])

        # Calculate physical orientation, display it, and compare with
        #  virmenLog's orientation assessment
        cur_orientation = self.is_clockwise(x, y, vx, vy)
        if cur_orientation == 1:
            self.clockcounter.set_text('Clockwise')
        else:
            self.clockcounter.set_text('Counterclockwise')

        # Adjust the location of the rat's chased target
        # NOTE(review): after read(), self.cur_i already points one row past
        # the rows just returned, so this shows the target of the next log
        # row and indexes past the end after the final read() -- confirm
        # whether self.cur_i - 1 was intended.
        target_x = self.vl['txs'][self.cur_i]
        target_y = self.vl['tys'][self.cur_i]
        self.target.center = [target_x, target_y]

        # Make a context prediction (best-effort)
        try:
            self._make_prediction(cur_orientation)
        except Exception:
            # ``_make_prediction`` is not defined in this class (presumably
            # provided by a base class or subclass -- confirm); a failed
            # prediction must not stop the animation.
            pass

        # Update the drawing window
        for itm in [self.pos, self.vel,  #self.radius,
                    self.clockcounter, self.iters,
                    self.predicted_counter, self.target]:
            self.ax.draw_artist(itm)
        self.canvas.blit(self.ax.bbox)

    def is_clockwise(self, x, y, vx, vy):
        ''' Determines if motion is clockwise around the centre of the room
            (the midpoint of room_shape, stored in cntr_x / cntr_y).

            Returns 1 when the z-component of (position - centre) x velocity
            is positive and -1 otherwise, i.e. output mapped to {-1,1} to
            conform with vl['Task'] labels.'''
        # NOTE(review): with the y axis pointing up, a positive cross product
        # is counterclockwise motion; the 1 == 'Clockwise' mapping presumably
        # follows the log's convention -- confirm.
        cross_prod = (x - self.cntr_x) * vy - (y - self.cntr_y) * vx
        clockwise = 2 * (cross_prod > 0) - 1
        return clockwise
class ThreeCompVisualisation:
    """
    Basis to visualise power flow within the hydraulics model as an animation or simulation
    """

    def __init__(self, agent: ThreeCompHydAgent,
                 axis: plt.axis = None,
                 animated: bool = False,
                 detail_annotations: bool = False,
                 basic_annotations: bool = True,
                 black_and_white: bool = False):
        """
        Whole visualisation setup using given agent's parameters
        :param agent: The agent to be visualised
        :param axis: If set, the visualisation will be drawn using the provided axis object.
        Useful for animations or to display multiple models in one plot.
        If not set, a new figure is created, displayed with plt.show() and closed afterwards.
        :param animated: If true the visualisation is set up to deal with frame updates.
        See animation script for more details.
        :param detail_annotations: If true, tank distances and sizes are annotated as well.
        :param basic_annotations: If true, U, LF, and LS are visible
        :param black_and_white: If true, the visualisation is in black and white
        :raises UserWarning: if both animated and detail_annotations are set. This is
        checked first, before any figure is created or rcParams are modified.
        """
        # Validate up front. Previously this was only checked after a figure had
        # been created and both layouts built (the detailed layout also enables
        # rcParams['text.usetex'] globally), leaving that state behind.
        if detail_annotations and animated:
            raise UserWarning("Detailed annotations and animation cannot be combined")

        # matplotlib fontsize (global setting)
        rcParams['font.size'] = 10

        # plot if no axis was assigned
        if axis is None:
            fig = plt.figure(figsize=(8, 5))
            self._ax1 = fig.add_subplot(1, 1, 1)
        else:
            fig = None
            self._ax1 = axis

        if black_and_white:
            self.__u_color = (0.7, 0.7, 0.7)
            self.__lf_color = (0.5, 0.5, 0.5)
            self.__ls_color = (0.3, 0.3, 0.3)
            self.__ann_color = (0, 0, 0)
            self.__p_color = (0.5, 0.5, 0.5)
        else:
            self.__u_color = "tab:cyan"
            self.__lf_color = "tab:orange"
            self.__ls_color = "tab:red"
            self.__ann_color = "tab:blue"
            self.__p_color = "tab:green"

        # basic parameters for setup
        self._animated = animated
        self.__detail_annotations = detail_annotations
        self._agent = agent
        self.__offset = 0.2  # vertical gap between the axis bottom and the LF tank bottom

        # U tank with three stripes
        self.__width_u = 0.3
        self._u = None
        self._u1 = None
        self._u2 = None
        self._r1 = None  # line marking flow from U to LF
        self._ann_u = None  # U annotation

        # LF tank
        self._lf = None
        self._h = None  # fill state
        self._ann_lf = None  # annotation

        # LS tank
        self._ls = None
        self._g = None
        self._ann_ls = None  # annotation LS
        self._r2 = None  # line marking flow from LS to LF

        # finish the basic layout
        self.__set_basic_layout()
        self.update_basic_layout(agent)

        # now the animation components
        if self._animated:
            # U flow
            self._arr_u_flow = None
            self._ann_u_flow = None

            # flow out of tap
            self._arr_power_flow = None
            self._ann_power_flow = None

            # LS flow (R2)
            self._arr_r2_l_pos = None
            self._arr_r2_flow = None
            self._ann_r2_flow = None

            # time information annotation
            self._ann_time = None

            self.__set_animation_layout()
            # the animation layout creates the time annotation but does not add it
            self._ax1.add_artist(self._ann_time)

        # basic annotations are U, LF, and LS
        if not basic_annotations:
            self.hide_basic_annotations()

        # add layout for detailed annotations
        # detail annotation add greek letters for distances and positions
        if self.__detail_annotations:
            self.__set_detailed_annotations_layout()
        # the detailed view needs extra room on the right for the unit-height marker
        self._ax1.set_xlim(0, 1.05 if self.__detail_annotations else 1.0)
        self._ax1.set_ylim(0, 1.2)

        self._ax1.set_axis_off()

        # display plot if no axis object was assigned
        if fig is not None:
            plt.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
            plt.show()
            plt.close(fig)

    def __set_detailed_annotations_layout(self):
        """
        Adds components required for a detailed
        annotations view with denoted positions and distances.

        It adds:

        * flow arrows with symbolic labels ($M_{U}$, $p$, $M_{LS}$, $M_{LF}$);
        * distance markers for $\\phi$, $h$, $g$, $\\theta$, $\\gamma$ and the unit
          height. Each marker is a centred label with arrows to both ends of the
          distance.
        * dashed lines at ``offset`` and ``1 + offset``;
        * fixed illustrative fill heights for LF (0.45) and LS (0.3 * LS height).

        Side effect: enables ``rcParams['text.usetex']`` globally.
        """

        u_width = self.__width_u
        ls_left = self._ls.get_x()
        ls_width = self._ls.get_width()
        lf_left = self._lf.get_x()
        lf_width = self._lf.get_width()
        ls_height = self._ls.get_height()
        # horizontal centre of the LF label; the tap arrow and h marker align to it
        lf_label_x = self._ann_lf.get_position()[0]
        ls_center_x = ls_left + ls_width / 2

        # some offset to the bottom
        offset = self.__offset
        phi_o = self._agent.phi + offset
        gamma_o = self._agent.gamma + offset

        # global switch: text is rendered with LaTeX from here on
        rcParams['text.usetex'] = True

        # U -> LF flow arrow and its label
        self._ann_u_flow = Text(text="$M_{U}$", ha='right', fontsize="xx-large", x=u_width + 0.09,
                                y=phi_o - 0.08)
        # not added to the axis (see the commented-out add_artist call below)
        ann_p_ae = Text(text="$p_{U}$", ha='right', fontsize="xx-large", x=u_width + 0.07,
                        y=phi_o + 0.03)
        self._arr_u_flow = FancyArrowPatch((u_width, phi_o),
                                           (u_width + 0.1, phi_o),
                                           arrowstyle='-|>',
                                           mutation_scale=30,
                                           lw=2,
                                           color=self.__u_color)
        # phi: from the bottom line up to the U outlet
        # (raw strings avoid the invalid "\p" escape sequence warning)
        self.__annotate_distance(r'$\phi$', u_width / 2,
                                 phi_o, offset,
                                 (phi_o - offset) / 2 + offset - 0.015)

        # power flow out of the LF tap
        self._ann_power_flow = Text(text="$p$", ha='center', fontsize="xx-large", x=lf_label_x,
                                    y=offset - 0.06)
        self._arr_power_flow = FancyArrowPatch((lf_label_x, offset - 0.078),
                                               (lf_label_x, 0.0),
                                               arrowstyle='-|>',
                                               mutation_scale=30,
                                               lw=2,
                                               color=self.__p_color)

        # h: from the LF top down to the illustrated LF fill level
        self.__annotate_distance('$h$', lf_label_x + 0.07,
                                 1 + offset, 1 + offset - 0.55,
                                 1 + offset - 0.30)
        self._h.update(dict(xy=(lf_left, offset),
                            width=lf_width,
                            height=1 - 0.55,
                            color=self.__lf_color))

        # g: from the LS top down to the illustrated LS fill level
        self.__annotate_distance('$g$', ls_left + ls_width + 0.02,
                                 ls_height + gamma_o, ls_height * 0.3 + gamma_o,
                                 ls_height * 0.61 + gamma_o)
        self._g.update(dict(xy=(ls_left, gamma_o),
                            width=ls_width,
                            height=ls_height * 0.3,
                            color=self.__ls_color))

        # LS <-> LF flow labels; ann_p_an is not added to the axis (see below)
        ann_p_an = Text(text="$p_{L}$", ha='left', usetex=True, fontsize="xx-large", x=ls_left - 0.06,
                        y=gamma_o + 0.11)
        ann_arr_flow = Text(text="$M_{LS}$", ha='left', usetex=True, fontsize="xx-large", x=ls_left - 0.09,
                            y=gamma_o + 0.03)
        self._ann_r2_flow = Text(text="$M_{LF}$", ha='left', usetex=True, fontsize="xx-large", x=ls_left - 0.04,
                                 y=gamma_o - 0.07)

        # double-headed: flow between LS and LF can go either way
        self._arr_r2_flow = FancyArrowPatch((ls_left - 0.1, gamma_o),
                                            (ls_left, gamma_o),
                                            arrowstyle='<|-|>',
                                            mutation_scale=30,
                                            lw=2,
                                            color=self.__ls_color)

        # theta: from the top line down to the top of LS
        theta_label_y = 1 - (1 - (ls_height + gamma_o - offset)) / 2 + offset - 0.015
        self.__annotate_distance(r'$\theta$', ls_center_x,
                                 1 + offset, ls_height + gamma_o,
                                 theta_label_y)

        # gamma: from the bottom line up to the bottom of LS
        self.__annotate_distance(r'$\gamma$', ls_center_x,
                                 offset, gamma_o,
                                 (gamma_o - offset) / 2 + offset - 0.015)

        # unit height between the two dashed lines, right of the tanks
        self.__annotate_distance('$1$', 1.05,
                                 0 + offset, 1 + offset,
                                 0.5 + offset)

        self._ax1.add_artist(ann_arr_flow)
        # self._ax1.add_artist(ann_p_an)
        # self._ax1.add_artist(ann_p_ae)
        self._ax1.axhline(offset, linestyle='--', color=self.__ann_color)
        self._ax1.axhline(1 + offset - 0.001, linestyle='--', color=self.__ann_color)
        self._ax1.add_artist(self._ann_power_flow)
        self._ax1.add_artist(self._arr_power_flow)
        self._ax1.add_artist(self._arr_u_flow)
        self._ax1.add_artist(self._ann_u_flow)
        self._ax1.add_artist(self._arr_r2_flow)
        self._ax1.add_artist(self._ann_r2_flow)

    def __annotate_distance(self, label, x, y_first, y_second, y_label):
        """
        Marks a vertical distance: draws ``label`` centred at (x, y_label) with
        one arrow to (x, y_first) and then one arrow to (x, y_second).
        """
        for y_target in (y_first, y_second):
            self._ax1.annotate(label,
                               xy=(x, y_target),
                               xytext=(x, y_label),
                               ha='center',
                               fontsize="xx-large",
                               arrowprops=dict(arrowstyle='-|>',
                                               ls='-',
                                               fc=self.__ann_color)
                               )

    def __set_animation_layout(self):
        """
        Creates the flow arrows and flow labels that are updated every frame,
        plus the time label, and adds the arrows and flow labels to the axis.
        The time label ``self._ann_time`` is only created here, not added.

        All arrows start with ``mutation_scale=0``; their size is set per frame.
        """
        base = self.__offset
        u_right = self.__width_u
        phi_level = self._agent.phi + base
        gamma_level = self._agent.gamma + base
        tap_x = self._ann_lf.get_position()[0]
        ls_x = self._ls.get_x()

        # U -> LF flow (R1), starting at the right edge of the U tank
        self._arr_u_flow = FancyArrowPatch((u_right, phi_level),
                                           (u_right + 0.1, phi_level),
                                           arrowstyle='simple',
                                           mutation_scale=0,
                                           ec='white',
                                           fc=self.__u_color)
        self._ann_u_flow = Text(text="flow: ", ha='right', fontsize="large",
                                x=u_right, y=phi_level - 0.05)

        # Tap flow (power) leaving the bottom of LF
        self._arr_power_flow = FancyArrowPatch((tap_x, base - 0.05),
                                               (tap_x, 0.0),
                                               arrowstyle='simple',
                                               mutation_scale=0,
                                               ec='white',
                                               color=self.__p_color)
        self._ann_power_flow = Text(text="flow: ", ha='center', fontsize="large",
                                    x=tap_x, y=base - 0.05)

        # LS flow (R2); both endpoints are stored so the arrow can be flipped
        self._arr_r2_l_pos = [(ls_x, gamma_level), (ls_x - 0.1, gamma_level)]
        start, end = self._arr_r2_l_pos
        self._arr_r2_flow = FancyArrowPatch(start,
                                            end,
                                            arrowstyle='simple',
                                            mutation_scale=0,
                                            ec='white',
                                            color=self.__ls_color)
        self._ann_r2_flow = Text(text="flow: ", ha='left', fontsize="large",
                                 x=ls_x, y=gamma_level - 0.05)

        # information annotation
        self._ann_time = Text(x=1, y=0.9 + base, ha="right")

        for artist in (self._ann_power_flow, self._arr_power_flow,
                       self._arr_u_flow, self._ann_u_flow,
                       self._arr_r2_flow, self._ann_r2_flow):
            self._ax1.add_artist(artist)

    def __set_basic_layout(self):
        """
        Creates the static artists of the model and adds them to the axis:

        * the U tank as three stripes of increasing opacity;
        * the LF and LS vessel outlines and their fill rectangles;
        * the U->LF and LS->LF connector lines;
        * the U, LF and LS labels.

        The geometry follows the same rules as update_basic_layout.
        """

        # get sizes from agent
        lf = self._agent.lf
        ls = self._agent.ls
        ls_height = self._agent.height_ls

        # u_left is 0
        u_width = self.__width_u

        # Split the available width so that LF width : LS width equals
        # lf * ls_height : ls. This is the same formula as update_basic_layout.
        # The former "/ ls" denominator could make LF wider than the available
        # space and leave LS with a negative width.
        lf_left = u_width + 0.1
        lf_width = ((lf * ls_height) * (1 - lf_left - 0.1)) / (ls + lf * ls_height)
        ls_left = lf_left + lf_width + 0.1
        ls_width = 1 - ls_left

        # some offset to the bottom
        offset = self.__offset
        phi_o = self._agent.phi + offset
        gamma_o = self._agent.gamma + offset

        # U tank
        self._u = Rectangle((0.0, phi_o), 0.05, 1 - self._agent.phi, color=self.__u_color, alpha=0.3)
        self._u1 = Rectangle((0.05, phi_o), 0.05, 1 - self._agent.phi, color=self.__u_color, alpha=0.6)
        self._u2 = Rectangle((0.1, phi_o), u_width - 0.1, 1 - self._agent.phi, color=self.__u_color)
        self._r1 = Line2D([u_width, u_width + 0.1],
                          [phi_o, phi_o],
                          color=self.__u_color)
        self._ann_u = Text(text="$U$", ha='center', fontsize="xx-large",
                           x=u_width / 2,
                           y=((1 - self._agent.phi) / 2) + phi_o - 0.02)

        # LF vessel: outline plus fill rectangle
        self._lf = Rectangle((lf_left, offset), lf_width, 1, fill=False, ec="black")
        self._h = Rectangle((lf_left, offset), lf_width, 1, color=self.__lf_color)
        self._ann_lf = Text(text="$LF$", ha='center', fontsize="xx-large",
                            x=lf_left + (lf_width / 2),
                            y=offset + 0.5 - 0.02)

        # LS vessel: outline plus fill rectangle
        self._ls = Rectangle((ls_left, gamma_o), ls_width, ls_height, fill=False, ec="black")
        self._g = Rectangle((ls_left, gamma_o), ls_width, ls_height, color=self.__ls_color)
        self._r2 = Line2D([ls_left, ls_left - 0.1],
                          [gamma_o, gamma_o],
                          color=self.__ls_color)
        self._ann_ls = Text(text="$LS$", ha='center', fontsize="xx-large",
                            x=ls_left + (ls_width / 2),
                            y=gamma_o + (ls_height / 2) - 0.02)

        # the basic layout
        self._ax1.add_line(self._r1)
        self._ax1.add_line(self._r2)
        self._ax1.add_artist(self._u)
        self._ax1.add_artist(self._u1)
        self._ax1.add_artist(self._u2)
        self._ax1.add_artist(self._lf)
        self._ax1.add_artist(self._ls)
        self._ax1.add_artist(self._h)
        self._ax1.add_artist(self._g)
        self._ax1.add_artist(self._ann_u)
        self._ax1.add_artist(self._ann_lf)
        self._ax1.add_artist(self._ann_ls)

    def update_basic_layout(self, agent):
        """
        Re-positions and re-sizes the tanks, connector lines and labels for
        ``agent`` and makes it the visualised agent.

        The widths of LF and LS are chosen so that LF width : LS width equals
        ``lf * height_ls : ls``. The fill heights come from ``agent.get_h()``
        and ``agent.get_g()``.
        :param agent: agent to be visualised
        """
        self._agent = agent

        ls_height = agent.height_ls
        lf_volume = agent.lf * ls_height

        # the U tank starts at x = 0
        u_width = self.__width_u

        # split the remaining width between LF and LS, keeping the size ratio
        lf_left = u_width + 0.1
        available = 1 - lf_left - 0.1
        lf_width = (lf_volume * available) / (agent.ls + lf_volume)
        ls_left = lf_left + lf_width + 0.1
        ls_width = 1 - ls_left

        # everything sits above a fixed bottom offset
        offset = self.__offset
        phi_o = agent.phi + offset
        gamma_o = agent.gamma + offset
        u_height = 1 - agent.phi

        # U tank: three stripes, then the connector line and the label
        for stripe, left, width in ((self._u, 0.0, 0.05),
                                    (self._u1, 0.05, 0.05),
                                    (self._u2, 0.1, u_width - 0.1)):
            stripe.set_bounds(left, phi_o, width, u_height)
        self._r1.set_xdata([u_width, u_width + 0.1])
        self._r1.set_ydata([phi_o, phi_o])
        self._ann_u.set_position(xy=(u_width / 2, (u_height / 2) + phi_o - 0.02))

        # LF vessel: outline, fill and label
        for rect in (self._lf, self._h):
            rect.set_bounds(lf_left, offset, lf_width, 1)
        self._ann_lf.set_position(xy=(lf_left + (lf_width / 2),
                                      offset + 0.5 - 0.02))

        # LS vessel: outline, fill, connector line and label
        for rect in (self._ls, self._g):
            rect.set_bounds(ls_left, gamma_o, ls_width, ls_height)
        self._r2.set_xdata([ls_left, ls_left - 0.1])
        self._r2.set_ydata([gamma_o, gamma_o])
        self._ann_ls.set_position(xy=(ls_left + (ls_width / 2), gamma_o + (ls_height / 2) - 0.02))

        # fill levels
        self._h.set_height(1 - agent.get_h())
        self._g.set_height(ls_height - agent.get_g())

    def hide_basic_annotations(self):
        """
        Blanks the U, LF and LS tank labels. The text artists stay on the axis
        with empty text.
        """
        for label in (self._ann_u, self._ann_lf, self._ann_ls):
            label.set_text("")

    def update_animation_data(self, frame_number):
        """
        For animations and simulations.
        The function to call at each frame. It advances the agent by one step,
        then updates the time label, the three flow arrows and their labels,
        and the LF/LS fill levels.
        :param frame_number: frame number has to be taken because of parent class method (unused)
        :return: a list of every artist modified in this frame. With blitting,
            artists that are not returned are not redrawn.
        :raises UserWarning: if the visualisation was not created with animated=True
        """

        if not self._animated:
            raise UserWarning("Animation flag has to be enabled in order to use this function")

        # perform one step
        cur_time = self._agent.get_time()
        power = self._agent.perform_one_step()

        # draw some information
        self._ann_time.set_text("agent \n time: {}".format(int(cur_time)))

        # power arrow; arrow size grows logarithmically with the value
        self._ann_power_flow.set_text("power: {}".format(round(power)))
        self._arr_power_flow.set_mutation_scale(math.log(power + 1) * 10)

        # oxygen arrow (U -> LF flow, scaled by agent.hz)
        p_u = round(self._agent.get_p_u() * self._agent.hz, 1)
        max_str = "(MAX)" if p_u == self._agent.m_u else ""
        self._ann_u_flow.set_text("flow: {} {}".format(p_u, max_str))
        self._arr_u_flow.set_mutation_scale(math.log(p_u + 1) * 10)

        # lactate arrow; its direction follows the sign of the flow
        p_g = round(self._agent.get_p_l() * self._agent.hz, 1)
        if p_g < 0:
            # NOTE(review): p_g is negative here, so this only matches if m_lf is
            # stored as a negative value -- confirm the intended comparison.
            max_str = "(MAX)" if p_g == self._agent.m_lf else ""
            self._arr_r2_flow.set_positions(self._arr_r2_l_pos[1], self._arr_r2_l_pos[0])
        else:
            max_str = "(MAX)" if p_g == self._agent.m_ls else ""
            self._arr_r2_flow.set_positions(self._arr_r2_l_pos[0], self._arr_r2_l_pos[1])
        self._ann_r2_flow.set_text("flow: {} {}".format(p_g, max_str))
        self._arr_r2_flow.set_mutation_scale(math.log(abs(p_g) + 1) * 10)

        # update levels
        self._h.set_height(1 - self._agent.get_h())
        self._g.set_height(self._agent.height_ls - self._agent.get_g())

        # List of artists to be drawn. It includes the U and LS/LF flow arrows
        # and labels updated above; previously they were missing, so a blitted
        # animation never redrew them.
        return [self._ann_time,
                self._ann_power_flow,
                self._arr_power_flow,
                self._g,
                self._h,
                self._ann_u_flow,
                self._arr_u_flow,
                self._ann_r2_flow,
                self._arr_r2_flow]