Ejemplo n.º 1
0
 def update(self, _):
     """Render the next recorded frame of the creep animation.

     ``self.index`` is first clamped to the last entry of
     ``rad_creeps_history``. Radial creeps of that frame are labelled with the
     length of their history and their label, directional creeps with their
     health. Both scatter plots are moved to the frame's positions and the
     frame counter is advanced.

     Returns a 1-tuple holding the radial scatter artist.
     """
     self.clear_annotations()

     frame_count = len(self.rad_creeps_history)
     if self.index >= frame_count:
         self.index = frame_count - 1

     rad_x, rad_y = [], []
     for creep in self.rad_creeps_history[self.index]:
         rad_x.append(creep.x)
         rad_y.append(creep.y)
         caption = "%i %s" % (len(creep.history), creep.label)
         self.add_annotation(
             text.Annotation(caption, xy=(creep.x, creep.y),
                             xytext=(-35, -5), textcoords='offset points'))

     dir_x, dir_y = [], []
     for creep in self.dir_creeps_history[self.index]:
         dir_x.append(creep.x)
         dir_y.append(creep.y)
         caption = "%.2f" % creep.health
         self.add_annotation(
             text.Annotation(caption, xy=(creep.x, creep.y),
                             xytext=(10, 0), textcoords='offset points'))

     self.ax.set_title("Frame %i" % self.index)
     self.sc_r.set_offsets(np.column_stack((rad_x, rad_y)))
     self.sc_d.set_offsets(np.column_stack((dir_x, dir_y)))
     self.index += 1
     # NOTE(review): only sc_r is returned; if the animation blits, sc_d and
     # the annotations may not be redrawn -- confirm this is intended.
     return (self.sc_r,)
Ejemplo n.º 2
0
    def testCorrectSettingOfMultipleAnnotationPoints(self):
        '''
        Test correct setting of annotations if mouse hovers over multiple scattered points
        '''
        # create mock values
        mocked_ax = plt.gca()
        mocked_ax.set_xlim((0, 1))
        mocked_ax.set_ylim((0, 1))
        mocked_scatter = mocked_ax.scatter(
            pd.Series([.5, .5, .5], dtype=float),
            pd.Series([.5, .5, .5], dtype=float))
        # 3 annotations on same point
        mocked_annotation_points_list = [
            txt.Annotation('dummy-annotation1', (.5, .5), visible=False),
            txt.Annotation('dummy-annotation2', (.5, .5), visible=False),
            txt.Annotation('dummy-annotation3', (.5, .5), visible=False)
        ]
        mocked_last_hov_index = -1
        # Derive the event positions (display pixels) from data coordinates
        # instead of hard-coding pixels, so the test does not depend on the
        # figure size, DPI or axes layout in use.
        on_x, on_y = mocked_ax.transData.transform((.5, .5))
        off_x, off_y = mocked_ax.transData.transform((.9, .1))
        mocked_mouse_event_on = bb.MouseEvent('mocked-mouse-event-on',
                                              plt.gcf().canvas, on_x,
                                              on_y)  # on point (.5|.5)
        mocked_mouse_event_off = bb.MouseEvent('mocked-mouse-event-off',
                                               plt.gcf().canvas, off_x,
                                               off_y)  # off point (.5|.5)

        # create object
        main_sequence = createUUT()
        with patch.object(main_sequence, '_annotation_points',
                          mocked_annotation_points_list):
            with patch.object(main_sequence, '_last_hov_anno_index',
                              mocked_last_hov_index):
                # call function on point
                main_sequence._annotate_point(mocked_mouse_event_on,
                                              mocked_scatter)

                # assert last hovered index
                self.assertEqual(0, main_sequence._last_hov_anno_index)
                self.assertTrue(
                    main_sequence._annotation_points[0].get_visible())

                # call function off point
                main_sequence._annotate_point(mocked_mouse_event_off,
                                              mocked_scatter)

                # assert last hovered index (no change in index off point)
                self.assertEqual(0, main_sequence._last_hov_anno_index)
                self.assertFalse(
                    main_sequence._annotation_points[0].get_visible())

                # call function on point again (annotation should change)
                main_sequence._annotate_point(mocked_mouse_event_on,
                                              mocked_scatter)

                # assert last hovered index change
                self.assertEqual(1, main_sequence._last_hov_anno_index)
                self.assertTrue(
                    main_sequence._annotation_points[1].get_visible())
Ejemplo n.º 3
0
 def update(self, _):
     """Render frame ``self.index`` of the creep animation.

     Every creep carries its own ``x_history``/``y_history``/``health_history``
     and a ``birth_count`` (the frame it appears in). A creep is drawn from
     frame ``birth_count`` up to ``birth_count + len(health_history)``; for
     the step right after its recorded history, its current ``x``/``y``/
     ``health`` attributes are used. Frames past ``self.end_index`` repeat
     the last step. Labels are cycled from the module-level ``rad_labels``
     and ``dir_labels`` lists.

     Returns a 1-tuple holding the radial scatter artist.
     """
     def frame_state(creep):
         # (x, y, health) of ``creep`` at frame ``self.index``, or None when
         # the creep is not alive in that frame.
         was_born = self.index >= creep.birth_count
         has_died = (self.index >
                     len(creep.health_history) + creep.birth_count)
         if not was_born or has_died:
             return None
         i = self.index - creep.birth_count
         if i == len(creep.x_history):
             # One step past the recorded history: use the live attributes.
             return creep.x, creep.y, creep.health
         return creep.x_history[i], creep.y_history[i], creep.health_history[i]

     xs_r, ys_r = [], []
     xs_d, ys_d = [], []
     self.clear_annotations()
     if self.index > self.end_index:
         self.index = self.end_index  # keep animating last step
     for creep_id, creep in enumerate(self.rad_creeps_history):
         state = frame_state(creep)
         if state is None:
             continue
         x_coord, y_coord, health = state
         xs_r.append(x_coord)
         ys_r.append(y_coord)
         creep_label = rad_labels[creep_id % len(rad_labels)]
         annotation = text.Annotation("%.2f %s" % (health, creep_label),
                                      xy=(x_coord, y_coord),
                                      xytext=(-35, -5),
                                      textcoords='offset points')
         self.add_annotation(annotation)
     for creep_id, creep in enumerate(self.dir_creeps_history):
         state = frame_state(creep)
         if state is None:
             continue
         x_coord, y_coord, health = state
         xs_d.append(x_coord)
         ys_d.append(y_coord)
         creep_label = dir_labels[creep_id % len(dir_labels)]
         annotation = text.Annotation("%s %.2f" % (creep_label, health),
                                      xy=(x_coord, y_coord),
                                      xytext=(10, -5),
                                      textcoords='offset points')
         self.add_annotation(annotation)
     self.ax.set_title("Frame %i" % self.index)
     self.sc_r.set_offsets(np.column_stack((xs_r, ys_r)))
     self.sc_d.set_offsets(np.column_stack((xs_d, ys_d)))
     self.index += 1
     # NOTE(review): only sc_r is returned; if the animation blits, sc_d and
     # the annotations may not be redrawn -- confirm this is intended.
     return (self.sc_r,)
Ejemplo n.º 4
0
    def draw_arrow(self, offset, y=1, linespacing=1, *args, **kwargs):
        """
        draw_arrow(self, offset, y=1, linespacing=1, *args, **kwargs)

        Draw a short, right-pointing arrow as a visual annotation.

        The arrow is drawn in axes coordinates from ``(offset, y - linespacing / 2)``
        to ``(offset + 0.025, y - linespacing / 2)`` and is not clipped to the axes.
        Keyword arguments are passed as the ``arrowprops`` of the
        `matplotlib.text.Annotation <https://matplotlib.org/3.2.2/tutorials/text/annotations.html>`_
        that draws the arrow; positional ``*args`` are accepted but ignored.

        :param offset: The x-offset (axes coordinates) where the arrow starts.
        :type offset: float
        :param y: The y-position of the annotation.
        :type y: float
        :param linespacing: The linespacing of the accompanying text annotation.
        :type linespacing: float

        :return: The drawn arrow.
        :rtype: :class:`matplotlib.text.Annotation`
        """

        axes = self.drawable.axes

        arrow_y = y - linespacing / 2.
        # An Annotation with empty text draws only its arrow, which runs from
        # ``xytext`` (tail) to ``xy`` (head).
        arrow = text.Annotation('', xy=(offset + 0.025, arrow_y),
                                xytext=(offset, arrow_y),
                                xycoords=axes.transAxes, textcoords=axes.transAxes,
                                arrowprops=kwargs)
        # Let the arrow extend outside the axes area.
        arrow.set_clip_on(False)
        axes.add_artist(arrow)

        return arrow
Ejemplo n.º 5
0
    def testCallbackConnectionToMotionEvent(self, mocked_ms_anno_func,
                                            mocked_con_func):
        '''
        Check that defining the motion callback registers a 'motion_notify_event'
        handler on the canvas which forwards to _annotate_point
        '''
        # the patch decorators must have swapped in the mocks
        self.assertIs(MainSequence._annotate_point, mocked_ms_anno_func)
        self.assertIs(bb.FigureCanvasBase.mpl_connect, mocked_con_func)

        # fixtures: an empty scatter plot and one hidden annotation
        empty_x = pd.Series(dtype=float)
        empty_y = pd.Series(dtype=float)
        scatter = plt.gca().scatter(empty_x, empty_y)
        annotations = [
            txt.Annotation('dummy-annotation1', (.5, .5), visible=False)
        ]

        uut = createUUT()
        with patch.object(uut, '_annotation_points', annotations):
            uut._define_motion_annotation_callback(scatter)

            # exactly one connection was made ...
            mocked_con_func.assert_called_once()
            positional, _keywords = mocked_con_func.call_args
            event_name, handler = positional
            # ... and it is for mouse motion
            self.assertEqual('motion_notify_event', event_name)

            # firing the registered handler must reach _annotate_point
            handler(None)
            mocked_ms_anno_func.assert_called_once_with(None, scatter)
Ejemplo n.º 6
0
    def testCorrectFunctionCallsIfSinglePointSelected(self,
                                                      mocked_txt_get_vis_func,
                                                      mocked_txt_set_vis_func,
                                                      mocked_coll_cont_func):
        '''
        Test correct function calls if mouse hovers over a single scattered point
        '''
        # assert mocks
        self.assertIs(txt.Text.get_visible, mocked_txt_get_vis_func)
        self.assertIs(txt.Text.set_visible, mocked_txt_set_vis_func)
        self.assertIs(coll.PathCollection.contains, mocked_coll_cont_func)

        # create mock values
        mocked_ax = plt.gca()
        mocked_ax.set_xlim((0, 1))
        mocked_ax.set_ylim((0, 1))
        mocked_scatter = mocked_ax.scatter(pd.Series([.5], dtype=float),
                                           pd.Series([.5], dtype=float))
        mocked_coll_cont_func.return_value = True, {
            'ind': np.array([0], dtype=int)
        }
        mocked_annotation_points_list = [
            txt.Annotation('dummy-annotation1', (.5, .5), visible=False),
            txt.Annotation('dummy-annotation2', (.1, .5), visible=False),
            txt.Annotation('dummy-annotation3', (.7, .2), visible=False)
        ]
        mocked_mouse_event = bb.MouseEvent('mocked-mouse-event',
                                           plt.gcf().canvas, 322,
                                           242)  # on point (.5|.5)

        # create object
        main_sequence = createUUT()
        with patch.object(main_sequence, '_annotation_points',
                          mocked_annotation_points_list):
            # Building the fixtures above already goes through the patched
            # Text methods (``Annotation(..., visible=False)`` calls
            # ``set_visible``). Reset them so the assertions below only see
            # calls made by ``_annotate_point``.
            mocked_txt_get_vis_func.reset_mock()
            mocked_txt_set_vis_func.reset_mock()

            # call function to test
            main_sequence._annotate_point(mocked_mouse_event, mocked_scatter)

            # assert function calls
            mocked_coll_cont_func.assert_called_once()
            # TODO(review): the original intent was one get_visible() call per
            # annotation point (3); pin the exact count once confirmed.
            mocked_txt_get_vis_func.assert_called()
            mocked_txt_set_vis_func.assert_called()
Ejemplo n.º 7
0
 def _init_offsetText(self, direction):
     # Create the (initially empty) offset-text annotation for ``direction``.
     # Anchor point and alignment are looked up in ``_offsetText_pos``; the
     # text sits at that axes-fraction point with zero point offset and is
     # coloured like the x tick labels.
     anchor_x, anchor_y, v_align, h_align = self._offsetText_pos[direction]
     self.offsetText = label = mtext.Annotation(
         "",
         xy=(anchor_x, anchor_y), xycoords="axes fraction",
         xytext=(0, 0), textcoords="offset points",
         color=rcParams['xtick.color'],
         horizontalalignment=h_align, verticalalignment=v_align,
     )
     label.set_transform(IdentityTransform())
     self.axes._set_artist_props(label)
Ejemplo n.º 8
0
    def on_mplg_frames_mousemove(self, event):
        """Annotate a hovered trajectory point with its video frame number.

        For every trajectory line under the mouse pointer, the frame number of
        the hovered point is looked up in the session's trajectories table,
        filtered by the trajectory name and trial number currently shown in
        the UI, and drawn as an arrow annotation offset by -40 in y (data
        coordinates) from the point.

        :param event: matplotlib mouse-motion event of the frames canvas.
        """
        if not self.trajectory_lines:
            return
        for line in self.trajectory_lines:
            line.set_pickradius(5)
            is_hit, hit_info = line.contains(event)
            if not is_hit:
                continue
            axes = self.ui.mplg_frames.all_sp_axes[0]  # type: matplotlib.axes.Axes
            xy = line.get_xydata()
            point_ind = hit_info['ind'][0]

            traj_name = self.ui.cB_trajectory_name.currentText()
            trial_num = int(self.ui.label_trial_number_int.text())
            analysis_path = os.path.join(self.session, self.analysis_folder)
            # analysis_file_name carries its own leading separator, so it is
            # appended by plain concatenation. The context manager closes the
            # store even if reading the table fails.
            with pd.HDFStore(analysis_path + self.analysis_file_name,
                             mode='r') as session_hdf5:
                trajectories = session_hdf5[self.trajectories_key]  # type: pd.DataFrame
            # DataFrame.sort() was removed from pandas; sort_values() replaces it.
            trajectories = trajectories.sort_values(by=self.frame_traj_point,
                                                    ascending=True)
            # Apply both filters with one mask instead of chaining a second
            # boolean indexer built on the unfiltered frame (which pandas has
            # to reindex, with a warning).
            in_selection = (
                (trajectories[self.name_traj_point] == traj_name) &
                (trajectories[self.trial_traj_point] == trial_num))
            frames = trajectories.loc[in_selection,
                                      self.frame_traj_point].tolist()
            frame = frames[point_ind]
            # NOTE(review): a new annotation is added on every hover and the
            # previous one is never removed -- confirm this is intended.
            self.annotation = mpt.Annotation(
                str(frame),
                xy=tuple(xy[point_ind]),
                xytext=(xy[point_ind][0], xy[point_ind][1] - 40),
                xycoords='data',
                textcoords='data',
                horizontalalignment="left",
                arrowprops=dict(arrowstyle="simple",
                                connectionstyle="arc3,rad=-0.2"),
                bbox=dict(boxstyle="round",
                          facecolor="w",
                          edgecolor="0.5",
                          alpha=0.9))
            axes.add_artist(self.annotation)
            self.ui.mplg_frames.canvas.draw()
Ejemplo n.º 9
0
    def on_click(event):
        """Align the loadings plot ``xa`` with the scores plot ``ax``.

        Copies the camera angles of ``ax`` onto ``xa`` and re-creates the
        labels in ``mg.labels`` at their newly projected 2D positions. Relies
        on ``ax``, ``xa``, ``P_merged``, ``plot_a``/``plot_b``/``plot_c`` and
        ``variables`` from the enclosing scope.

        :param event: The matplotlib event that triggered the call (unused).
        """
        # First copy the camera angle of the scores plot.
        azim, elev = ax.azim, ax.elev
        xa.view_init(elev=elev, azim=azim)

        # Then move the labels. They are 2D annotations placed at projected
        # coordinates, so the old ones are removed and new ones are created
        # at the positions projected with the new camera.
        X_, Y_, _ = proj3d.proj_transform(P_merged[plot_a], P_merged[plot_b],
                                          P_merged[plot_c], xa.get_proj())
        new_labels = []
        for i, old_label in enumerate(mg.labels):
            old_label.remove()

            new_label = txt.Annotation(variables[i], xycoords='data',
                                       xy=(X_[i], Y_[i]))
            xa.add_artist(new_label)
            new_labels.append(new_label)

        mg.labels = new_labels
Ejemplo n.º 10
0
    def __init__(self):
        """Set up name constants, build the Qt main window and run the GUI.

        NOTE(review): ``app.exec_()`` at the end enters the Qt event loop, so
        this constructor does not return while the application is running.
        """
        # paths
        # The file names start with a backslash because they are appended to
        # a folder path by string concatenation (analysis_file_name is, in
        # on_mplg_frames_mousemove); presumably the same holds for the others.
        self.front_video_path = r"\front_video.avi"
        self.front_counter_path = r"\front_counter.csv"
        self.top_video_path = r"\top_video.avi"
        self.adc_path = r"\adc.bin"
        self.sync_path = r"\sync.bin"
        self.analysis_folder = "Analysis"
        self.analysis_file_name = r"\session.hdf5"

        # session.hdf5 structure: keys of the tables stored in the HDF5 file
        self.fronttime_key = 'video/front/time'
        self.fronttrials_key = 'video/front/trials'
        self.toptime_key = 'video/top/time'
        self.paw_events_key = 'task/events/paws'
        self.good_trials_key = 'task/events/trials'
        self.trajectories_key = 'task/trajectories'

        # columns of trial_start_stop_info in session.hdf5
        self.trials_info_start_frame = "start frame"
        self.trials_info_stop_frame = "stop frame"
        self.trials_info_start_frame_time = "start frame time"
        self.trials_info_end_frame_time = "end frame time"
        self.trials_info_trial_duration = "trial duration"

        # columns of paw_events in session.hdf5
        self.blpaw = 'back left paw'
        self.brpaw = 'back right paw'
        self.flpaw = 'front left paw'
        self.frpaw = 'front right paw'
        self.trial_paw_event = 'trial of event'
        self.time_paw_event = 'time of event'

        # columns of trajectories in session.hdf5
        self.name_traj_point = 'name of trajectory point'
        self.trial_traj_point = 'trial of trajectory point'
        self.frame_traj_point = 'frame of trajectory point'
        self.time_traj_point = 'time of trajectory point'
        self.x_traj_point = 'X of trajectory point'
        self.y_traj_point = 'Y of trajectory point'

        # instance variables
        # The Qt application must exist before any widget is created.
        app = QtGui.QApplication(sys.argv)
        window = QtGui.QMainWindow()
        self.ui = Ui_RatShuttlingPawEventsGenerator()
        self.ui.setupUi(window)
        self.rec_freq = 8000  # presumably the recording sample rate in Hz -- confirm
        self.ss_freq = 100  # presumably a (sub)sampling frequency in Hz -- confirm
        self.session = ""
        self.front_video = ""
        self.top_video = ""
        self.corrected_frame_numbers = []
        self.cam_shutter_closing_samples = []
        self.analysis_exists = False
        # NOTE(review): this stores the bound method ``value`` itself, not the
        # LED's current value (no call parentheses) -- confirm intended.
        self.data_loaded = self.ui.qled_data_loaded.value
        self.t = RunVideoThread(self.ui)
        self.trajectory_lines = []
        # Placeholder; replaced by a new annotation whenever a trajectory
        # point is hovered. It is not added to any axes here.
        self.annotation = mpt.Annotation("", xy=(-1, -1))

        self.connect_slots()

        window.show()
        # Runs the Qt event loop; returns only when the application exits.
        app.exec_()
Ejemplo n.º 11
0
    def testCorrectSettingOfMultipleAnnotationPointsOfAnotherGroup(self):
        '''
        Test correct setting of annotations if mouse hovers from one group of annotation points to another
        '''
        # create mock values
        mocked_ax = plt.gca()
        mocked_ax.set_xlim((0, 1))
        mocked_ax.set_ylim((0, 1))
        mocked_scatter = mocked_ax.scatter(
            pd.Series([.5, .5, .25, .25], dtype=float),
            pd.Series([.5, .5, .25, .25], dtype=float))
        # 2 annotations on same point
        mocked_annotation_points_list = [
            txt.Annotation('dummy-annotation1_1', (.5, .5), visible=False),
            txt.Annotation('dummy-annotation1_2', (.5, .5), visible=False),
            txt.Annotation('dummy-annotation2_1', (.25, .25), visible=False),
            txt.Annotation('dummy-annotation2_2', (.25, .25), visible=False),
        ]
        mocked_last_hov_index = -1
        # Derive the event positions (display pixels) from data coordinates
        # instead of hard-coding pixels, so the test does not depend on the
        # figure size, DPI or axes layout in use.
        on_1_x, on_1_y = mocked_ax.transData.transform((.5, .5))
        on_2_x, on_2_y = mocked_ax.transData.transform((.25, .25))
        off_x, off_y = mocked_ax.transData.transform((.9, .1))
        mocked_mouse_event_on_1 = bb.MouseEvent('mocked-mouse-event-on-1',
                                                plt.gcf().canvas, on_1_x,
                                                on_1_y)  # on point (.5|.5)
        mocked_mouse_event_on_2 = bb.MouseEvent('mocked-mouse-event-on-2',
                                                plt.gcf().canvas, on_2_x,
                                                on_2_y)  # on point (.25|.25)
        mocked_mouse_event_off = bb.MouseEvent('mocked-mouse-event-off',
                                               plt.gcf().canvas, off_x,
                                               off_y)  # off points

        # create object
        main_sequence = createUUT()
        with patch.object(main_sequence, '_annotation_points',
                          mocked_annotation_points_list):
            with patch.object(main_sequence, '_last_hov_anno_index',
                              mocked_last_hov_index):
                # call function on point 1
                main_sequence._annotate_point(mocked_mouse_event_on_1,
                                              mocked_scatter)

                # assert last hovered index
                self.assertEqual(0, main_sequence._last_hov_anno_index)
                self.assertTrue(
                    main_sequence._annotation_points[0].get_visible())

                # call function off point
                main_sequence._annotate_point(mocked_mouse_event_off,
                                              mocked_scatter)

                # assert last hovered index (no change in index off point)
                self.assertEqual(0, main_sequence._last_hov_anno_index)
                self.assertFalse(
                    main_sequence._annotation_points[0].get_visible())

                # call function on point 2
                main_sequence._annotate_point(mocked_mouse_event_on_2,
                                              mocked_scatter)

                # assert last hovered index change
                self.assertEqual(2, main_sequence._last_hov_anno_index)
                self.assertTrue(
                    main_sequence._annotation_points[2].get_visible())
import pandas as pd
import matplotlib
import matplotlib.lines as mpl_lines
import matplotlib.text as mpl_text
import matplotlib.pyplot as plt

print(matplotlib.get_backend())

# Three integer columns (a, b, c) over the index 10, 20, 30.
df = pd.DataFrame({'a': [1, 4, 7], 'b': [2, 5, 8], 'c': [3, 6, 9]},
                  index=[10, 20, 30])
axes = df.plot()

# A black vertical segment at x=20 (from y=0 to y=10), with a text label
# placed just to its right.
vert_line = mpl_lines.Line2D([20, 20], [0, 10], lw=2, color='black',
                             axes=axes)
label = mpl_text.Annotation('vert_line', xy=(20.2, 2))
for artist in (vert_line, label):
    axes.add_artist(artist)

# plt.show() is used rather than Figure.show() so that the script waits for
# the window instead of running straight to its end.
plt.show()
Ejemplo n.º 13
0
def plot_3D(plot_this, R_2, graphtitle, observations, variables,
            P_merged, Z_merged, colours=False, legend=None, height=None):
    """Show linked 3D scatter plots of scores (left) and loadings (right).

    Rotating the scores plot rotates the loadings plot too, and the optional
    loadings labels are re-projected to follow the rotation.

    :param plot_this: Comma-separated, 1-based component numbers, e.g.
        ``"1,2,3"``. An optional fourth item other than ``'risimif'`` turns
        on the text labels of the loadings points.
    :param R_2: Indexed by component; the axis-label percentage of component
        *k* (0-based) is ``int((R_2[k + 1] - R_2[k]) * 100)`` -- presumably
        cumulative R^2 values, confirm.
    :param graphtitle: Prefix of the figure title.
    :param observations: Not used by this function.
    :param variables: Label texts of the loadings points, one per point.
    :param P_merged: Loadings, indexed by component.
    :param Z_merged: Scores, indexed by component.
    :param colours: Optional colour spec for the score points; if falsy the
        points are blue.
    :param legend: Optional image drawn onto the figure with ``figimage``.
    :param height: Subtracted from the figure's top edge to position
        ``legend``; required when ``legend`` is given.
    """
    # For "linking" the two plots. See comment below.
    def on_click(event):
        """Copy the camera angles of ``ax`` to ``xa`` and re-create the
        loadings labels at their newly projected positions."""
        # First copy the camera angle of the scores plot.
        azim, elev = ax.azim, ax.elev
        xa.view_init(elev=elev, azim=azim)

        # Then move the labels. They are 2D annotations placed at projected
        # coordinates, so the old ones are removed and new ones are created
        # at the positions projected with the new camera.
        X_, Y_, _ = proj3d.proj_transform(P_merged[plot_a], P_merged[plot_b],
                                          P_merged[plot_c], xa.get_proj())
        new_labels = []
        for i, old_label in enumerate(mg.labels):
            old_label.remove()

            new_label = txt.Annotation(variables[i], xycoords='data',
                                       xy=(X_[i], Y_[i]))
            xa.add_artist(new_label)
            new_labels.append(new_label)

        mg.labels = new_labels

    def label_pc_axes(axes3d):
        # Label each axis with its component number and its percentage.
        axes3d.set_xlabel('Score PC-%s (%s %%)' % (plot_a + 1, percentage_a))
        axes3d.set_ylabel('Score PC-%s (%s %%)' % (plot_b + 1, percentage_b))
        axes3d.set_zlabel('Score PC-%s (%s %%)' % (plot_c + 1, percentage_c))

    plt.close()

    plot_these = plot_this.split(',')
    plot_a = int(plot_these[0]) - 1
    plot_b = int(plot_these[1]) - 1
    plot_c = int(plot_these[2]) - 1

    # A fourth item in the input turns on the labels of the loadings points,
    # unless it is the literal 'risimif' (which used to leave plot_label
    # unbound and crash below).
    plot_label = len(plot_these) > 3 and plot_these[3] != 'risimif'

    percentage_a = int((R_2[plot_a + 1] - R_2[plot_a]) * 100)
    percentage_b = int((R_2[plot_b + 1] - R_2[plot_b]) * 100)
    percentage_c = int((R_2[plot_c + 1] - R_2[plot_c]) * 100)

    fig = plt.figure()
    # ax will be the scores.
    ax = fig.add_subplot(121, projection='3d')
    label_pc_axes(ax)

    # xa will be the loadings
    xa = fig.add_subplot(122, projection='3d')
    label_pc_axes(xa)
    hint = "\n(Click + hold + drag on Scores-graph \nto align data in both graphs)"
    plt.title(graphtitle + hint)

    if colours:
        ax.scatter(Z_merged[plot_a], Z_merged[plot_b], Z_merged[plot_c],
                   c=colours, s=100)
        xa.scatter(P_merged[plot_a], P_merged[plot_b], P_merged[plot_c],
                   c='blue', s=100)
    else:
        ax.scatter(Z_merged[plot_a],
                   Z_merged[plot_b],
                   Z_merged[plot_c],
                   c='blue')
        xa.scatter(P_merged[plot_a],
                   P_merged[plot_b],
                   P_merged[plot_c],
                   c='blue')

    # Always start from an empty label list so on_click never touches labels
    # left over from an earlier call (or fails when none were ever created).
    mg.labels = []
    if plot_label:
        # How to get labels into 3D-plot and make these updatable when the
        # plot is rotated, is a bit complicated. I've pieced solutions
        # together from here:
        # http://stackoverflow.com/questions/12903538/label-3dplot-points-update
        # and here:
        # http://stackoverflow.com/questions/12222397/ ...
        # ... python-and-remove-annotation-from-figure
        # I also could not make it work without the use of my_globals

        # Transform the coordinates to get the initial 2D-projection
        X_, Y_, _ = proj3d.proj_transform(P_merged[plot_a], P_merged[plot_b],
                                          P_merged[plot_c], xa.get_proj())

        for i, label_text in enumerate(variables):
            # The label is an Annotation added as an artist, so that
            # on_click can later remove it with .remove() and replace it.
            label = txt.Annotation(label_text, xycoords='data',
                                   xy=(X_[i], Y_[i]))
            xa.add_artist(label)

            mg.labels.append(label)

    # ``legend != None`` compares element-wise for an image array and then
    # fails on truth testing; an identity check is what is meant.
    if legend is not None:
        fig.figimage(legend, 0, fig.bbox.ymax - height, zorder=10)

    # Here I "link" the two subplots with each other so that if I
    # move one plot the other is moved, too.
    # See here: http://stackoverflow.com/questions/23424282/ ...
    # ... how-to-get-azimuth-and-elevation-from-a-matplotlib-figure
    fig.canvas.mpl_connect('motion_notify_event', on_click)

    plt.show()
Ejemplo n.º 14
0
    def animate(
        self,
        pltdata={},
        interval=None,
        draw_fbd=False,
        data_stretch=False,
        figsize=(8, 4.5),
        blit=True,
    ):
        if interval is None:
            interval = self.dt * 1000

        if pltdata:
            fig, ax = plt.subplots(nrows=2, figsize=(figsize[0], figsize[1] * 2))
            ax0, ax1 = ax[0], ax[1]
        else:
            fig, ax0 = plt.subplots(figsize=figsize)

        if data_stretch:
            yls, yus = [], []
            for reskey in pltdata.keys():
                yls.append(self.data[reskey].values.min())
                yus.append(self.data[reskey].values.max())
            yl = min(yls)
            yu = max(yus)
            xl = self.data[reskey].index.min()
            xu = self.data[reskey].index.max()

            ax1.set_ylim((yl, yu))
            ax1.set_xlim((xl, xu))

        # axis setup
        ax0.set_xlim(self.xmin - self.pend.l * 2, self.xmax + self.pend.l * 2)
        ax0.set_ylim(self.ymin, self.ymax)
        ax0.set_aspect("equal")
        n_frames = (
            np.floor(len(self.data.index.values.tolist()) / self.speed).astype(int) - 10
        )
        # Initialize objects
        cart, mass, line = self._draw_objs()
        # Line for external force
        ext_force = patches.FancyArrow(0, 0, 1, 1, ec="red")
        # Line for control force
        ctrl_force = patches.FancyArrow(0, 0, 1, 1, ec="blue")

        if draw_fbd:
            pRx_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            pRy_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            pG_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            cRx_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            cRy_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            cG_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            cN_f = patches.FancyArrow(0, 0, 1, 1, ec="k", zorder=4)
            fbd_draws = (pRx_f, pRy_f, pG_f, cRx_f, cRy_f, cG_f, cN_f)

        ground = patches.Rectangle((-1000, -2000), 2000, 2000, fc="grey")
        # ground
        ground.set_zorder(-1)
        # Time text
        time_text = text.Annotation("", (4, 28), xycoords="axes points")

        plots = []
        for name, attrs in pltdata.items():
            if attrs["type"] == "line":
                (plot,) = ax1.plot(
                    [],
                    [],
                    label=attrs["label"],
                    linestyle=attrs["linestyle"],
                    color=attrs["color"],
                )
                plots.append(plot)
            elif attrs["type"] == "scatter":
                plot = ax1.scatter(
                    [],
                    [],
                    label=attrs["label"],
                    c=attrs["color"],
                    edgecolors=None,
                    marker=".",
                )
                plots.append(plot)
            else:
                raise ValueError("Wrong type or no type given.")

        def _init():
            plist = []
            ax0.add_patch(cart)
            ax0.add_patch(mass)
            ax0.add_artist(line)
            ax0.add_patch(ext_force)
            ax0.add_patch(ctrl_force)
            ax0.add_patch(ground)
            ax0.add_artist(time_text)
            if draw_fbd:
                for fbd in fbd_draws:
                    ax0.add_patch(fbd)
                plist.extend(fbd_draws)
            plist = [ground, cart, mass, line, ext_force, ctrl_force, time_text]
            plist.extend(plots)
            return plist

        def _animate(i):

            i = np.floor(i * self.speed).astype(int)

            retobjs = []
            # limits for y-axis

            if pltdata:
                scyall = [0]
                for (name, attrs), sc in zip(pltdata.items(), plots):
                    l = max(i - attrs["plotpoints"], 0)
                    scx = self.data.index[l:i]
                    scy = self.data[name].values[l:i]
                    if attrs["type"] == "scatter":
                        sc.set_offsets(np.column_stack([scx, scy]))
                    elif attrs["type"] == "line":
                        sc.set_data(scx, scy)
                    scyall.extend(list(scy))
                retobjs.extend(plots)
                yl = min(-0.1, min(scyall))
                yu = max(0.1, max(scyall))
                xl = self.data.index[
                    max(i - max([p["plotpoints"] for p in pltdata.values()]), 0)
                ]
                xu = self.data.index[i] + 1e-5
                if not data_stretch:
                    ax1.set_ylim((yl, yu))
                    ax1.set_xlim((xl, xu))
                ax1.legend(loc=2)

            # draw cart
            state_xi = list(self.data[("state", "x")].values)[i]
            state_ti = list(self.data[("state", "t")].values)[i]
            self._draw_cart(cart, mass, line, state_xi, state_ti)
            # external force
            self.draw_force(
                ext_force,
                list(self.data[("forces", "forces")].values)[i],
                state_xi,
                0.6,
            )
            self.draw_force(
                ctrl_force,
                list(self.data[("control action", "control action")].values)[i],
                state_xi,
                0.5,
            )
            time_text.set_text(r"t=" + str(round(self.data.index[i], 3)))

            # fbds
            if draw_fbd:
                # pend x
                px = state_xi - self.pend.l * np.sin(state_ti)
                # pend y
                py = self.cart_h + self.pend.l * np.cos(state_ti)
                # draw reaction force (pendulum)
                self.draw_pend_fbd(
                    pRx_f,
                    self.data[("forces", "pRx")].values[i],
                    np.array([1, 0]),
                    np.array([px, py]),
                )
                self.draw_pend_fbd(
                    pRy_f,
                    self.data[("forces", "pRy")].values[i],
                    np.array([0, 1]),
                    np.array([px, py]),
                )
                self.draw_pend_fbd(
                    pG_f,
                    self.data[("forces", "pG")].values[i],
                    np.array([0, 1]),
                    np.array([px, py]),
                )
                retobjs.extend((pRx_f, pRy_f, pG_f))
                # draw reaction force (cart)
                cx, cy = state_xi, self.cart_h
                self.draw_cart_fbd(
                    cRx_f,
                    self.data[("forces", "cRx")].values[i],
                    np.array([1, 0]),
                    np.array([cx, cy]),
                )
                self.draw_cart_fbd(
                    cRy_f,
                    self.data[("forces", "cRy")].values[i],
                    np.array([0, 1]),
                    np.array([cx, cy]),
                )
                self.draw_cart_fbd(
                    cG_f,
                    self.data[("forces", "cG")].values[i],
                    np.array([0, 1]),
                    np.array([cx, cy]),
                )
                self.draw_cart_fbd(
                    cN_f,
                    self.data[("forces", "cN")].values[i],
                    np.array([0, 1]),
                    np.array([cx, cy]),
                )

                retobjs.extend((cRx_f, cRy_f, cG_f, cN_f))

            retobjs.extend([ground, cart, mass, line, ext_force, ctrl_force, time_text])
            return retobjs

        anim_running = True

        def onClick(event):
            nonlocal anim_running
            if anim_running:
                anim.event_source.stop()
                anim_running = False
            else:
                anim.event_source.start()
                anim_running = True

        fig.canvas.mpl_connect("button_press_event", onClick)
        anim = FuncAnimation(
            fig,
            _animate,
            frames=n_frames,
            init_func=_init,
            blit=blit,
            interval=interval,
        )

        return anim
Ejemplo n.º 15
0
    def display_viz(self, data=None):
        '''
        Display (show) the animated visualization. This function calls plt.show()

        Builds a figure containing the cart, the pendulum rod and mass, arrows
        for the disturbance force and the control force, and text readouts of
        angle, position and time, then animates them with FuncAnimation
        (blitting, 16 ms interval). Only every ``self.frameskip``-th sample of
        the data is shown. Clicking anywhere on the figure pauses/resumes the
        animation. If ``self.save`` is truthy, the animation is also written to
        ``./video.mp4`` (30 fps) before the window is shown.

        Parameters
        ----------
        data : DataFrame-like, optional
            Simulation results indexed by time, providing the columns ``'x'``,
            ``'theta'``, ``'forces'`` and ``'control action'`` (each read via
            ``.values.tolist()``). When omitted, ``self.data`` is used if the
            instance has it; otherwise the module-level name ``data`` is
            looked up, which is what earlier versions of this method did.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If no data is passed and none is found on the instance or at
            module level.
        '''
        # The original body referenced a free name `data` that is never bound
        # in this method, so it only worked when a module-level `data` existed.
        # Prefer an explicit argument, then the instance attribute, then that
        # module-level name (kept for backward compatibility).
        if data is None:
            data = getattr(self, 'data', None)
        if data is None:
            data = globals().get('data')
        if data is None:
            raise ValueError(
                "display_viz needs simulation data: pass `data` or set "
                "`self.data`")

        # axis setup
        viz = plt.figure(figsize=self.viz_window_size)
        ax = plt.axes()
        plt.axis('scaled')
        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(-self.pendulum.l - 1,
                    self.pendulum.l + self.cart_height + 1)

        # add elements
        cart = patches.Rectangle(
            (-self.cart_display_width * 0.5, self.cart_height),
            width=self.cart_display_width,
            height=-self.cart_height,
            ec='black',
            fc='seagreen')
        mass = patches.Circle((0, 0),
                              radius=self.pend_radius,
                              fc='skyblue',
                              ec='black')
        line = patches.FancyArrow(0, 0, 1, 1)
        force = patches.FancyArrow(0, 0, 1, 1, ec='red')
        ctrl_force = patches.FancyArrow(0, 0, 1, 1, ec='blue')
        ground = patches.Rectangle((-1000, -2000), 2000, 2000, fc='lightgrey')
        ground.set_zorder(-1)

        # text readouts, stacked in the lower-left corner of the axes
        angle_text = text.Annotation('', (4, 4), xycoords='axes points')
        x_text = text.Annotation('', (4, 16), xycoords='axes points')
        time_text = text.Annotation('', (4, 28), xycoords='axes points')

        # Every artist redrawn on each frame (required for blitting).
        artists = [
            ground, cart, mass, line, force, ctrl_force, angle_text, x_text,
            time_text
        ]

        def init():
            '''
            Attach all patches and text artists to the axes.
            '''
            for patch in (cart, mass, line, force, ctrl_force, ground):
                ax.add_patch(patch)
            for label in (angle_text, x_text, time_text):
                ax.add_artist(label)
            return list(artists)

        # matplotlib animate doesn't play nice with dataframes :(
        # Convert each column to a plain list once, keeping every
        # `self.frameskip`-th sample.
        step = self.frameskip
        animate_x = data['x'].values.tolist()[::step]
        animate_theta = data['theta'].values.tolist()[::step]
        animate_force = data['forces'].values.tolist()[::step]
        animate_cforce = data['control action'].values.tolist()[::step]
        animate_times = data.index.values.tolist()[::step]
        frames = len(animate_times)

        def set_force_arrow(arrow, value, x, height):
            '''
            Draw `arrow` horizontally at `height`, leaving the side of the cart
            (centred at `x`) that `value` points toward.

            The arrow length is sqrt(0.1 * |value|) and its line width is
            sqrt(|value|). It is hidden when `value` is neither positive nor
            negative (zero, or NaN, which fails both comparisons).
            '''
            if value > 0.0:
                direction = 1.0
            elif value < 0.0:
                direction = -1.0
            else:
                arrow.set_visible(False)
                return
            magnitude = np.abs(value)
            start_x = x + direction * .5 * self.cart_display_width
            end_x = start_x + direction * np.sqrt(.1 * magnitude)
            arrow.set_xy(((start_x, height), (end_x, height)))
            arrow.set_linewidth(np.sqrt(magnitude))
            arrow.set_visible(True)

        def animate(i):
            '''
            Update every artist for animation frame `i`.
            '''
            # The drawn position is the negated data position. This is kept
            # from the original; NOTE(review): confirm the sign convention.
            x = -animate_x[i]
            th = animate_theta[i]  # angle

            # disturbance force near mid-cart, control force near the top
            set_force_arrow(force, animate_force[i], x,
                            .5 * self.cart_height)
            set_force_arrow(ctrl_force, animate_cforce[i], x,
                            0.9 * self.cart_height)

            # display cart/pend
            # True cart x, y is centered at the point where the line connects. But matplotlib draws
            # rectangles from the corner. So we have to add/subtract half the cart width and the full cart
            # height in order to display the cart properly.
            cartxy_true = (x, self.cart_height)
            massxy = (x + self.pendulum.l * np.sin(th),
                      self.cart_height + self.pendulum.l * np.cos(th))
            line.set_xy((massxy, cartxy_true))
            cartxy_visible = (x - self.cart_display_width * .5,
                              self.cart_height)
            mass.set_center(massxy)
            cart.set_xy(cartxy_visible)

            # display text
            angle_text.set_text(r"$\theta=$" + str(round(animate_theta[i], 3)))
            x_text.set_text(r"$x=$" + str(round(animate_x[i], 3)))
            time_text.set_text(r"t=" + str(round(animate_times[i], 3)))
            return list(artists)

        anim_running = True
        # Matplotlib requires a live reference to the Animation object for as
        # long as it should run. This local stays alive until plt.show()
        # returns, instead of relying on the click callback's closure.
        animation = FuncAnimation(viz,
                                  animate,
                                  frames,
                                  init_func=init,
                                  blit=True,
                                  interval=16)

        def onClick(event):
            '''
            Toggle pause/resume of the animation on any mouse click.
            '''
            nonlocal anim_running
            if anim_running:
                animation.event_source.stop()
                anim_running = False
            else:
                animation.event_source.start()
                anim_running = True

        viz.canvas.mpl_connect('button_press_event', onClick)
        if self.save:
            animation.save('./video.mp4', fps=30, bitrate=1000)

        plt.show()