Esempio n. 1
0
class Video:
    """Thin wrapper that records the real-time plot into a video file.

    The actual encoding is delegated to an ``FFMpegWriter``.
    """

    def __init__(self, name='recording.mp4', fps=10):
        """Create the underlying writer.

        Args:
            name: output path handed to the writer by :meth:`setup`.
            fps: frames per second passed to ``FFMpegWriter``.
        """
        ffmpeg_writer = FFMpegWriter(fps=fps)
        self.name = name
        self.writer = ffmpeg_writer

    def setup(self, fig):
        """Bind ``fig`` to the writer, targeting the file ``self.name``."""
        output_path = self.name
        self.writer.setup(fig, output_path)
Esempio n. 2
0
class Canvas:
    """Square matplotlib view with optional FFmpeg recording and a text report.

    The visible region spans ``[-view_port, view_port]`` on both axes.
    """

    def __init__(self, fig, axes):
        """
        Args:
            fig: matplotlib figure that owns ``axes``.
            axes: axes that is drawn into, reported on, and recorded.
        """
        self.fig = fig
        self.axes = axes
        self.fps = 10
        self.report_text = None
        self.writer = None
        # Sets ``self.view_port``, the axis limits, and redraws the canvas.
        self.set_view_port(80)
        # Snapshot of the drawn axes area (e.g. for blitting). It used to be
        # stored in an unused local variable.
        self.background = fig.canvas.copy_from_bbox(axes.bbox)

    def set_view_port(self, view_port):
        """Show ``[-view_port, view_port]`` on both axes and redraw."""
        self.view_port = view_port
        self.axes.set_xlim([-view_port, view_port])
        self.axes.set_ylim([-view_port, view_port])
        self.fig.canvas.draw()

    def saving(self, dpi=100, fps=10, filename="writer_test.mp4"):
        """Start recording the figure to a video file.

        Args:
            dpi: resolution of the recorded frames.
            fps: frame rate of the output video.
            filename: output path (defaults to the previously hard-coded
                ``"writer_test.mp4"``).
        """
        self.fps = fps
        self.writer = FFMpegWriter(fps=self.fps)
        self.writer.setup(self.fig, filename, dpi)

    def _require_writer(self):
        """Raise ValueError unless :meth:`saving` has been called."""
        if self.writer is None:
            raise ValueError(
                "writer has not been initialised, try calling Canvas.saving()")

    def write(self):
        """Append the current figure as one frame of the video.

        Raises:
            ValueError: if :meth:`saving` has not been called.
        """
        self._require_writer()
        self.writer.grab_frame()

    def finish(self):
        """Finalise the video file and detach the writer.

        Detaching makes a later :meth:`write` fail with a clear ValueError
        instead of writing to an already-finished writer.

        Raises:
            ValueError: if :meth:`saving` has not been called.
        """
        self._require_writer()
        self.writer.finish()
        self.writer = None

    def report(self, messages):
        """Show ``"name: value"`` lines in the top-left corner of the axes.

        The first call creates one text artist per message. Later calls update
        those artists in place, and create new ones for any extra messages.

        Args:
            messages: sequence of ``(name, value)`` pairs; ``value`` is
                formatted with ``{:.0f}``.
        """
        if self.report_text is None:
            self.report_text = []
        x_pos = -self.view_port * 0.95
        top = self.view_port * 0.95
        line_step = self.view_port * 0.05
        for i_msg, msg in enumerate(messages):
            label = "{}: {:.0f}".format(msg[0], msg[1])
            if i_msg < len(self.report_text):
                self.report_text[i_msg].set_text(label)
            else:
                # Draw on this canvas' axes, not on pyplot's "current" axes.
                y_pos = top - i_msg * line_step
                self.report_text.append(self.axes.text(x_pos, y_pos, label))
Esempio n. 3
0
class SaveAnimation:
    """Record successive states of a matplotlib figure into a movie file.

    The movie is complete only after :meth:`save` has been called.
    """

    def __init__(self, figNumber, outfile, fps=15, dpi=100, metadata=None):
        """Open an FFmpeg writer on figure ``figNumber``.

        Args:
            figNumber: figure number passed to ``plt.figure``.
            outfile: path of the movie file to write.
            fps: frame rate of the movie (previously fixed at 15).
            dpi: resolution of the recorded frames (previously fixed at 100).
            metadata: movie metadata dict; ``None`` uses the previous fixed
                title/artist/comment values.
        """
        self.fig = plt.figure(figNumber)
        if metadata is None:
            metadata = dict(title='Movie Test', artist='Matplotlib',
                            comment='Movie support!')
        self.writer = FFMpegWriter(fps=fps, metadata=metadata)
        self.writer.setup(self.fig, outfile, dpi=dpi)

    def update(self):
        """Append the figure's current state as one movie frame."""
        self.writer.grab_frame()

    def save(self):
        """Finalise the movie file."""
        self.writer.finish()
Esempio n. 4
0
def generate_fig5():
    """Plot mean win rate over sample size and SGD steps as a 3D wireframe.

    The still figure is saved with ``common.save_next_fig``. A movie of the
    wireframe rotating twice around the vertical axis is written to
    ``figures/part{PART_NUM}/movie.mp4``.
    """
    for dfs in DATA_FILES:
        common.fill_data(dfs)

    X = np.array(SAMPLE_SIZES)
    Y = np.array([i.episodes for i in DATA_FILES[0][0].data])
    # Z_avg[x, y]: win rate at checkpoint y averaged over the data files in
    # DATA_FILES[x] (one group per sample size).
    Z_avg = np.zeros((X.shape[0], Y.shape[0]))
    for x, dfs in enumerate(DATA_FILES):
        for y in range(len(dfs[0].data)):
            Z_avg[x, y] = np.mean([df.data[y].win_rate for df in dfs])

    with style.context("ggplot"):
        fig = Figure()
        FigureCanvas(fig)  # attach a canvas so the figure can be rendered
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_wireframe(X[:, None], Y[None, :], Z_avg, rstride=3, cstride=8)
        ax.set_xlabel("Sample Size")
        ax.set_ylabel("SGD steps")
        ax.set_zlabel("Win Rate")
        ax.view_init(40, -60)
        ax.ticklabel_format(style="sci", scilimits=(-2, 2))

        common.save_next_fig(PART_NUM, fig)

        writer = FFMpegWriter(fps=20)
        writer.setup(fig, "figures/part{}/movie.mp4".format(PART_NUM))
        try:
            # Two full turns starting at azimuth -60 (the still's viewpoint).
            # The first iteration already captures that view, so no separate
            # initial frame is grabbed (it would only duplicate it).
            for azim in range(-60, 360 * 2 - 60):
                ax.view_init(40, azim)
                writer.grab_frame()
        finally:
            # Finalise the movie even if grabbing a frame fails.
            writer.finish()
class Visualizer(object):
    """Wrap a psycholab env and visualise its map and the players' returns."""

    UNIT_SIZE = 15
    ROWS = 2

    def __init__(self,
                 env,
                 fps=1000,
                 by_episode=False,
                 save_video=False,
                 directory=''):
        """
        Args:
            env: environment to wrap. It must expose ``num_players``,
                ``players_order``, ``players``, ``render``, ``reset`` and
                ``step``.
            fps: display rate; every :meth:`step` sleeps ``1 / fps`` seconds.
            by_episode: if True, redraw only when an episode ends.
            save_video: if True, record every redraw to ``game.mp4`` inside
                ``directory``.
            directory: folder for the recorded video.
        """
        import os

        self.env = env
        self.fps = fps
        self.by_episode = by_episode
        # Return accumulated by each player during the current episode.
        self.rewards = np.zeros(self.env.num_players)
        # One entry per player plus a last one for the average return. Each
        # entry is [episode indices, raw returns, smoothed returns].
        self.rewards_data = [[[], [], []]
                             for _ in range(self.env.num_players + 1)]

        # Subplot 211 = env map, subplot 212 = players' returns.
        self.fig = plt.figure(figsize=(self.UNIT_SIZE,
                                       self.UNIT_SIZE * self.ROWS))

        self.save_video = save_video
        if self.save_video:
            self.dir = directory
            self.metadata = dict(title='Movie Test',
                                 artist='Matplotlib',
                                 comment='Movie support!')
            self.writer = FFMpegWriter(fps=10, metadata=self.metadata)
            # os.path.join also handles a directory without a trailing
            # separator (plain concatenation produced "<dir>game.mp4").
            self.writer.setup(self.fig, os.path.join(self.dir, 'game.mp4'),
                              300)

        # Player colours are lightened towards white and mapped to [0, 1]
        # (assuming channels are given in 0-255).
        self.average_color = (0, 0, 0)
        self.players_colors = [
            [(255 + c) / (2 * 255) for c in self.env.players[name].color]
            for name in self.env.players_order
        ]

        self._init_game()
        self._init_rewards()

        self._eps = 0.1
        self._min = 0
        self._max = 0

        self.env_episodes = 0
        self.freq = 1 / self.fps

        self.fig.tight_layout()
        plt.ion()
        plt.show(block=False)

    @property
    def steps(self):
        """Number of steps, as reported by the wrapped env."""
        return self.env.steps

    @property
    def num_players(self):
        """Number of players in the wrapped env."""
        return self.env.num_players

    @property
    def num_actions(self):
        """Number of actions in the wrapped env."""
        return self.env.num_actions

    @property
    def num_states(self):
        """Number of states in the wrapped env."""
        return self.env.num_states

    def _init_game(self):
        """Initialise the image of the current game map (top subplot)."""
        self.game_axes = plt.subplot(211)
        self.game_image = self.game_axes.imshow(self.env.render(),
                                                interpolation='nearest')

    def _init_rewards(self):
        """Initialise the return curves (bottom subplot).

        There is one curve per player, followed by a black curve for the
        average.
        """
        self.reward_axes = plt.subplot(212)
        self.reward_axes.title.set_text('players returns')
        self.reward_axes.set_xlabel('env episodes')
        self.reward_plots = [
            self.reward_axes.plot([], [], lw=2, color=player_color)
            for player_color in self.players_colors
        ]
        # Average reward plot:
        self.reward_plots.append(
            self.reward_axes.plot([], [], lw=2, color=self.average_color))

    def _smooth_data(self, data, percentage=10):
        """Smooth ``data`` with a trailing moving average.

        The window covers ``percentage`` percent of the episodes seen so far,
        plus one. Positions before the first full window stay at 0.

        Args:
            data: per-episode values (expected to hold ``self.env_episodes``
                entries).
            percentage: window size as a percentage of ``self.env_episodes``.

        Returns:
            np.ndarray of length ``self.env_episodes``.
        """
        # The old ``episodes / percentage`` formula only equalled a percentage
        # for the default of 10. This version gives the same window for 10.
        win_size = int(self.env_episodes * percentage / 100) + 1
        average = pd.Series(np.array(data)).rolling(
            window=win_size).mean().iloc[win_size - 1:].values
        smoothed = np.zeros(self.env_episodes)
        if len(average):
            # Guard: with a window larger than the data, ``average`` is empty
            # and ``smoothed[-0:]`` would address the whole array.
            smoothed[-len(average):] = average
        return smoothed

    def _update(self, done, infos):
        """Redraw the map and, at the end of an episode, the return curves.

        Args:
            done: whether the episode just ended.
            infos: shown as the title of the map subplot.
        """
        # The game map:
        self.game_axes.title.set_text(infos)
        self.game_image.set_data(self.env.render())

        if done:
            self.env_episodes += 1

            # Per-player returns followed by their average.
            episode_returns = list(self.rewards) + [self.rewards.mean()]
            for reward_data, reward in zip(self.rewards_data,
                                           episode_returns):
                reward_data[0].append(self.env_episodes)
                reward_data[1].append(reward)
                reward_data[2] = self._smooth_data(reward_data[1])

            # Update min and max values for reward plots:
            big_array = np.stack([data[2] for data in self.rewards_data])
            self._max = np.max(big_array)
            self._min = np.min(big_array)

            for player_plot, reward_data in zip(self.reward_plots,
                                                self.rewards_data):
                player_plot[0].set_xdata(reward_data[0])
                player_plot[0].set_ydata(reward_data[2])  # smoothed values
            # The limits are the same for every curve, so set them once.
            self.reward_axes.set_xlim(0, self.env_episodes + 1)
            self.reward_axes.set_ylim(self._min - 1, self._max + 1)

            self.rewards *= 0

        self.fig.tight_layout()
        self.fig.canvas.draw()

        if self.save_video:
            self.writer.grab_frame()

    def reset(self):
        """Clear the per-episode returns and reset the env.

        Returns:
            The initial observation from the wrapped env.
        """
        self.rewards *= 0
        return self.env.reset()

    def step(self, action):
        """Do an environment step, update the visualisation and throttle.

        Returns:
            The ``(observation, rewards, done, infos)`` tuple from the env.
        """
        observation, rewards, done, infos = self.env.step(action)
        self.rewards += rewards
        # In by-episode mode, only redraw once the episode has ended.
        if done or not self.by_episode:
            self._update(done, infos)

        time.sleep(self.freq)
        return observation, rewards, done, infos

    def obs2state(self, obs):
        """Delegate observation-to-state conversion to the env."""
        return self.env.obs2state(obs)

    def finish(self):
        """Finalise the recorded video, if any."""
        if self.save_video:
            self.writer.finish()
Esempio n. 6
0
class HighScoreOptimiserPlot(object):
    """Animate a high-score optimiser run in a 2-D parameter plane.

    Each frame shows:

    * every model evaluated so far (black);
    * the current chain members, with recent ones emphasised;
    * 1000 samples drawn from the active sampler phase (green).

    Frames can be written to a movie and/or shown interactively.
    """

    def __init__(self, optimiser, problem, history, xpar_name, ypar_name,
                 movie_filename):
        """
        Args:
            optimiser: optimiser providing ``chains``, ``nbootstrap`` and
                ``get_sampler_phase``.
            problem: problem providing the parameters and model extraction.
            history: run history; ``history.models`` holds one model per row.
            xpar_name: name of the parameter shown on the x axis.
            ypar_name: name of the parameter shown on the y axis.
            movie_filename: movie output path; falsy to skip movie writing.
        """
        self.optimiser = optimiser
        self.problem = problem
        self.chains = optimiser.chains(problem, history)
        self.history = history
        self.xpar_name = xpar_name
        self.ypar_name = ypar_name
        self.fontsize = 10.
        self.movie_filename = movie_filename
        # Set to True before start() to display frames interactively.
        self.show = False
        self.iiter = 0
        self.iiter_last_draw = 0
        # Artists removed and recreated on every frame.
        self._volatile = []
        # Indices of 100-model blocks that are fully drawn and kept.
        self._blocks_complete = set()

    def start(self):
        """Create the figure and axes and, if requested, the movie writer."""
        nfx = 1
        nfy = 1

        problem = self.problem

        ixpar = problem.name_to_index(self.xpar_name)
        iypar = problem.name_to_index(self.ypar_name)

        mpl_init(fontsize=self.fontsize)
        fig = plt.figure(figsize=(9.6, 5.4))
        labelpos = mpl_margins(fig,
                               nw=nfx,
                               nh=nfy,
                               w=7.,
                               h=5.,
                               wspace=7.,
                               hspace=2.,
                               units=self.fontsize)

        xpar = problem.parameters[ixpar]
        ypar = problem.parameters[iypar]

        # An equal aspect ratio only makes sense when both axes share a unit.
        if xpar.unit == ypar.unit:
            axes = fig.add_subplot(nfy, nfx, 1, aspect=1.0)
        else:
            axes = fig.add_subplot(nfy, nfx, 1)

        labelpos(axes, 2.5, 2.0)

        axes.set_xlabel(xpar.get_label())
        axes.set_ylabel(ypar.get_label())

        axes.get_xaxis().set_major_locator(plt.MaxNLocator(4))
        axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

        # Faint cross-hair at the reference model.
        xref = problem.get_reference_model()
        axes.axvline(xpar.scaled(xref[ixpar]), color='black', alpha=0.3)
        axes.axhline(ypar.scaled(xref[iypar]), color='black', alpha=0.3)

        self.fig = fig
        self.xpar = xpar
        self.ypar = ypar
        self.axes = axes
        self.ixpar = ixpar
        self.iypar = iypar
        from matplotlib import colors
        # Random colours, one per index 0..nbootstrap (indexed by chain in
        # draw_frame); index 0 is forced to black.
        n = self.optimiser.nbootstrap + 1
        hsv = num.vstack((num.random.uniform(0., 1., n),
                          num.random.uniform(0.5, 0.9, n),
                          num.repeat(0.7, n))).T

        self.bcolors = colors.hsv_to_rgb(hsv[num.newaxis, :, :])[0, :, :]
        self.bcolors[0, :] = [0., 0., 0.]

        bounds = self.problem.get_combined_bounds()

        from grond import plot
        self.xlim = plot.fixlim(*xpar.scaled(bounds[ixpar]))
        self.ylim = plot.fixlim(*ypar.scaled(bounds[iypar]))

        self.set_limits()

        from matplotlib.colors import LinearSegmentedColormap

        self.cmap = LinearSegmentedColormap.from_list('probability',
                                                      [(1.0, 1.0, 1.0),
                                                       (0.5, 0.9, 0.6)])

        self.writer = None
        if self.movie_filename:
            from matplotlib.animation import FFMpegWriter

            metadata = dict(title=problem.name, artist='Grond')

            self.writer = FFMpegWriter(fps=30,
                                       metadata=metadata,
                                       codec='libx264',
                                       bitrate=200000,
                                       extra_args=[
                                           '-pix_fmt', 'yuv420p', '-profile:v',
                                           'baseline', '-level', '3', '-an'
                                       ])

            self.writer.setup(self.fig, self.movie_filename, dpi=200)

        if self.show:
            plt.ion()
            plt.show()

    def set_limits(self):
        """Fix the axis limits to the problem's combined parameter bounds."""
        self.axes.autoscale(False)
        self.axes.set_xlim(*self.xlim)
        self.axes.set_ylim(*self.ylim)

    def draw_frame(self):
        """Draw the state after iteration ``self.iiter``; grab a movie frame."""

        self.chains.goto(self.iiter + 1)
        msize = 15.

        for artist in self._volatile:
            artist.remove()

        self._volatile[:] = []

        nblocks = self.iiter // 100 + 1

        models = self.history.models[:self.iiter + 1]

        # Scatter all models so far in blocks of 100. Complete blocks are drawn
        # once and kept; the trailing partial block is redrawn every frame.
        for iblock in range(nblocks):
            if iblock in self._blocks_complete:
                continue

            models_add = self.history.models[
                iblock * 100:min((iblock + 1) * 100, self.iiter + 1)]

            fx = self.problem.extract(models_add, self.ixpar)
            fy = self.problem.extract(models_add, self.iypar)
            collection = self.axes.scatter(self.xpar.scaled(fx),
                                           self.ypar.scaled(fy),
                                           color='black',
                                           s=msize * 0.15,
                                           alpha=0.2,
                                           edgecolors='none')

            if models_add.shape[0] != 100:
                self._volatile.append(collection)
            else:
                self._blocks_complete.add(iblock)

        # Current chain members. Models from the last ``nfade`` iterations get
        # age-dependent marker sizes; older ones use the base size.
        for ichain in range(self.chains.nchains):

            iiters = self.chains.indices(ichain)
            fx = self.problem.extract(models[iiters, :], self.ixpar)
            fy = self.problem.extract(models[iiters, :], self.iypar)

            nfade = 20
            t1 = num.maximum(0.0, iiters - (models.shape[0] - nfade)) / nfade
            factors = num.sqrt(1.0 - t1) * (1.0 + 15. * t1**2)

            msizes = msize * factors

            paths = self.axes.scatter(self.xpar.scaled(fx),
                                      self.ypar.scaled(fy),
                                      color=self.bcolors[ichain],
                                      s=msizes,
                                      alpha=0.5,
                                      edgecolors='none')

            self._volatile.append(paths)

        phase, iiter_phase = self.optimiser.get_sampler_phase(self.iiter)

        # Samples from the active sampler phase. The count is named
        # ``nsamples`` rather than ``np``, which reads like the numpy module.
        nsamples = 1000
        models_prob = num.zeros((nsamples, self.problem.nparameters))
        for ip in range(nsamples):
            models_prob[ip, :] = phase.get_sample(self.problem, iiter_phase,
                                                  self.chains)

        fx = self.problem.extract(models_prob, self.ixpar)
        fy = self.problem.extract(models_prob, self.iypar)

        if False:
            # Disabled alternative: a 2-D histogram of the samples instead of
            # a scatter plot.

            bounds = self.problem.get_combined_bounds()

            nx = 20
            ny = 20
            x_edges = num.linspace(bounds[self.ixpar][0],
                                   bounds[self.ixpar][1], nx)
            y_edges = num.linspace(bounds[self.iypar][0],
                                   bounds[self.iypar][1], ny)

            p, _, _ = num.histogram2d(fx, fy, bins=(x_edges, y_edges))
            x, y = num.meshgrid(x_edges, y_edges)

            artist = self.axes.pcolormesh(self.xpar.scaled(x),
                                          self.ypar.scaled(y),
                                          p,
                                          cmap=self.cmap,
                                          zorder=-1)

            self._volatile.append(artist)

        else:
            collection = self.axes.scatter(self.xpar.scaled(fx),
                                           self.ypar.scaled(fy),
                                           color='green',
                                           s=msize * 0.15,
                                           alpha=0.2,
                                           edgecolors='none')

            self._volatile.append(collection)

        # NOTE(review): the frame is grabbed before the iteration annotation
        # below is added, so movies do not contain it. Confirm this is intended.
        if self.writer:
            self.writer.grab_frame()

        artist = self.axes.annotate(
            '%i (%s)' % (self.iiter + 1, phase.__class__.__name__),
            xy=(0., 1.),
            xycoords='axes fraction',
            xytext=(self.fontsize / 2., -self.fontsize / 2.),
            textcoords='offset points',
            ha='left',
            va='top',
            fontsize=self.fontsize,
            fontstyle='normal')

        self._volatile.append(artist)

        if self.show:
            plt.draw()

        self.iiter_last_draw = self.iiter + 1

    def finish(self):
        """Finalise the movie (if any) and show the final plot (if enabled)."""
        if self.writer:
            self.writer.finish()

        if self.show:
            # NOTE(review): interactive mode is still on here, so this show()
            # may not block. Confirm whether ioff() was meant to come first.
            plt.show()
            plt.ioff()

    def render(self):
        """Draw every model of the history as one frame each.

        The movie writer is finalised even if drawing a frame fails.
        """
        self.start()
        try:
            while self.iiter < self.history.nmodels:
                logger.info('rendering frame %i/%i',
                            self.iiter + 1, self.history.nmodels)
                self.draw_frame()
                self.iiter += 1
        finally:
            self.finish()
Esempio n. 7
0
    trace_vid_plt[i], = trace_ax.plot(np.arange(trace_len),
                                      np.arange(trace_len),
                                      '-',
                                      label=cell_label)
    hexcol = trace_vid_plt[i].get_c()
    contour_colours[i] = np.array(mcolors.hex2color(hexcol)) * 255

# Vertical range of the trace panel: from the 5th percentile of the selected
# rows of C up to their 99.5th percentile plus 2.0 per selected cell.
# NOTE(review): the per-cell headroom presumably leaves room for traces that
# are offset vertically elsewhere in the script. Confirm.
ca_max = np.quantile(C[selected_cells,:], 0.995) + len(selected_cells) * 2.0
ca_min = np.quantile(C[selected_cells,:], 0.05)
# NOTE(review): plt.ylim acts on pyplot's current axes. Confirm that this is
# trace_ax at this point.
plt.ylim(ca_min, ca_max)
trace_ax.legend(loc='upper left', fontsize='x-small')
# Hide the y axis and the axes background patch of the trace panel.
trace_ax.axes.get_yaxis().set_visible(False)
trace_ax.patch.set_visible(False)

# Open the movie writer on `fig` at 500 dpi.
writer.setup(fig, video_traces_outputfile, dpi=500)
print('Output file: ' + video_traces_outputfile)
while has_frame:
    if row_idx >= len(tracking_df):
        break

    if frame_idx == tracking_df.loc[row_idx, 'frame']:
        pos_x = tracking_df.loc[row_idx, 'smooth_x']
        pos_y = tracking_df.loc[row_idx, 'smooth_y']
        if pos_x >= 0 and pos_y >= 0:
            valid_pos[pos_idx] = (int(pos_x), int(pos_y))
            pos_idx += 1
        row_idx += 1

    for i in range(1, pos_idx):
        cv2.line(frame, valid_pos[i-1], valid_pos[i], (255, 0, 0), 2)
Esempio n. 8
0
    '/Users/cusgadmin/Documents/UCB/Academics/SSastry/Multi_agent_competition/'
)

# Load the trained GAIL policy and the evaluation environment, then render the
# first game.
print(colored('Testing learnt policy from model file {} for {} games!'.\
  format(args.model,args.num_test),'red'))
start_time = time.time()
model = GAIL.load(args.model)
env = gym.make('gym_pursuitevasion_small:pursuitevasion_small-v0')
g = 1  # 1-based index of the game currently being played
obs = env.reset(ep=g)
e_win_games = int(0)  # count of finished games for which `e_win` was truthy
env.render(mode='human', highlight=True, ep=g)
if args.save:
    # Record the rendered game window to test_game.mp4 (5 fps, 300 dpi),
    # starting with the initial state.
    metadata = dict(title='Game')
    writer = FFMpegWriter(fps=5, metadata=metadata)
    writer.setup(env.window.fig, "test_game.mp4", 300)
    writer.grab_frame()
while True:
    action, _states = model.predict(obs)
    obs, rewards, done, e_win = env.step(action)
    env.render(mode='human', highlight=True, ep=g)
    if args.save:
        writer.grab_frame()
    if done:
        g += 1
        obs = env.reset(ep=g)
        if g % 100 == 0:
            print('Playing game {}'.format(g))
        if e_win:
            e_win_games += 1
        if g > args.num_test:
Esempio n. 9
0
class Movie:
    """Class for creating movies from matplotlib figures using ffmpeg

    Note:
        Internally, this class uses :class:`matplotlib.animation.FFMpegWriter`.
        Note that the `ffmpeg` program needs to be installed in a system path,
        so that `matplotlib` can find it.


    Warning:
        The movie is only fully written after the :meth:`save` method has been called.
        To aid with this, it is best practice to use a contextmanager:

        .. code-block:: python

            with Movie("output.mp4") as movie:
                movie.add_figure()
    """
    def __init__(self,
                 filename: str,
                 framerate: float = 30,
                 dpi: float = None,
                 **kwargs):
        r"""
        Args:
            filename (str):
                The filename where the movie is stored. The suffix of this path
                also determines the default movie codec.
            framerate (float):
                The number of frames per second, which determines how fast the
                movie will appear to run.
            dpi (float):
                The resolution of the resulting movie
            \**kwargs:
                Additional parameters are used to initialize
                :class:`matplotlib.animation.FFMpegWriter`. Here, we can for instance
                set the bit rate of the resulting video using the `bitrate` parameter.

        Raises:
            RuntimeError: if ffmpeg is not available.
            OSError: if the folder of `filename` does not exist.
        """
        self.filename = str(filename)
        self.framerate = framerate
        self.dpi = dpi
        self.kwargs = kwargs

        # check whether ffmpeg is available
        if not self.is_available():
            raise RuntimeError(
                "FFMpegWriter is not available. This is most likely because a suitable "
                "installation of FFMpeg was not found. See ffmpeg.org for how to "
                "install it properly on your system.")

        # check whether the path to which the movie is written is available
        folder = pathlib.Path(self.filename).parent
        if not folder.exists() or not folder.is_dir():
            raise OSError(f"Folder `{folder}` does not exist")

        self._writer = None

    @classmethod
    def is_available(cls) -> bool:
        """check whether the movie infrastructure is available

        Returns:
            bool: True if movies can be created
        """
        from matplotlib.animation import FFMpegWriter

        return FFMpegWriter.isAvailable()  # type: ignore

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._end()
        return False

    def _end(self):
        """finish the movie, if one is being written, and drop the writer

        The writer is dropped even if finishing fails. Otherwise a second call
        (e.g. ``__exit__`` after a failed :meth:`save`) would try to finish the
        broken writer again and mask the original error.
        """
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.finish()

    def add_figure(self, fig=None):
        """adds the figure `fig` as a frame to the current movie

        Args:
            fig (:class:`~matplotlib.figures.Figure`):
                The plot figure that is added to the movie; defaults to the
                current pyplot figure.
        """
        if fig is None:
            import matplotlib.pyplot as plt

            fig = plt.gcf()

        if self._writer is None:
            # initialize a new writer; only keep it once `setup` succeeded, so
            # a failed setup is retried on the next call instead of leaving a
            # half-initialized writer behind
            from matplotlib.animation import FFMpegWriter

            writer = FFMpegWriter(self.framerate, **self.kwargs)
            writer.setup(fig, self.filename, dpi=self.dpi)
            self._writer = writer

        else:
            # update the figure reference on a given writer, since it might have
            # changed from the last call. In particular, this will happen when
            # figures are shown using the `inline` backend.
            self._writer.fig = fig

        # we need to impose a white background to get reasonable antialiasing
        self._writer.grab_frame(facecolor="white")

    def save(self):
        """convert the recorded images to a movie using ffmpeg"""
        self._end()
Esempio n. 10
0
class Movie:
    """ Class for creating movies from matplotlib figures using ffmpeg

    Note:
        Internally, this class uses :class:`matplotlib.animation.FFMpegWriter`.
        Note that the `ffmpeg` program needs to be installed in a system path,
        so that `matplotlib` can find it. The movie is only complete after
        :meth:`save` is called or the `with` block is left.
    """
    def __init__(self,
                 filename: str,
                 framerate: float = 30,
                 dpi: float = None,
                 **kwargs):
        r"""
        Args:
            filename (str):
                The filename where the movie is stored. The suffix of this path
                also determines the default movie codec.
            framerate (float):
                The number of frames per second, which determines how fast the
                movie will appear to run.
            dpi (float):
                The resolution of the resulting movie
            \**kwargs:
                Additional parameters are used to initialize
                :class:`matplotlib.animation.FFMpegWriter`.

        Raises:
            RuntimeError: if ffmpeg is not available.
        """
        self.filename = str(filename)
        self.framerate = framerate
        self.dpi = dpi
        self.kwargs = kwargs

        # test whether ffmpeg is available (shared with `is_available`)
        if not self.is_available():
            raise RuntimeError('FFMpegWriter is not available. This is most '
                               'likely because a suitable installation of '
                               'FFMpeg was not found. See ffmpeg.org for how '
                               'to install it properly on your system.')

        self._writer = None

    @classmethod
    def is_available(cls) -> bool:
        """ check whether the movie infrastructure is available

        Returns:
            bool: True if movies can be created
        """
        from matplotlib.animation import FFMpegWriter
        return FFMpegWriter.isAvailable()  # type: ignore

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._end()
        return False

    def _end(self):
        """ finish the movie, if one is being written, and drop the writer

        The writer is dropped even if finishing fails, so a second call does
        not try to finish the broken writer again and mask the original error.
        """
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.finish()

    def add_figure(self, fig=None):
        """ adds the figure `fig` as a frame to the current movie

        Args:
            fig (:class:`~matplotlib.figures.Figure`):
                The plot figure that is added to the movie; defaults to the
                current pyplot figure.
        """
        if fig is None:
            import matplotlib.pyplot as plt
            fig = plt.gcf()

        if self._writer is None:
            # initialize a new writer; only keep it once `setup` succeeded
            from matplotlib.animation import FFMpegWriter
            writer = FFMpegWriter(self.framerate, **self.kwargs)
            writer.setup(fig, self.filename, dpi=self.dpi)
            self._writer = writer

        else:
            # update the figure reference on a given writer, since it might have
            # changed from the last call. In particular, this will happen when
            # figures are shown using the `inline` backend.
            self._writer.fig = fig

        self._writer.grab_frame()

    def save(self):
        """ convert the recorded images to a movie using ffmpeg """
        self._end()
Esempio n. 11
0
class Space:  # a rectangle populated by Body's
    """A rectangle, spanning x in [-W, W] and y in [-H, H], populated by Body's.

    Uses module-level settings defined elsewhere in this file: ``visual``
    (draw with matplotlib and record ``<name>.mp4``), ``loginfo`` and
    ``logerror`` (console logging), ``shoul`` and ``showconn`` (also draw the
    Souls / the connection graph), the sizes ``W``, ``H`` and ``SS``, the
    speeds ``vN``/``wN`` and the default ``room`` polygon.
    """

    def __init__(self, name, T=RT, R=1, limits=''):
        """Create the Space and, when ``visual``, its figure and movie writer.

        Args:
            name: name of the Space; also used in the window title and as the
                movie file name ``<name>.mp4``.
            T: update cycle time (s).
            R: communication radius.
            limits: str possibly containing 'v' (top and bottom borders)
                and/or 'h' (left and right borders).
        """
        self.name = name
        if visual:
            diag = np.sqrt(W**2 + H**2)
            # W:H proportion, SS inches diagonal (resizeable)
            self.fig = plt.figure(figsize=(SS * W / diag, SS * H / diag),
                                  frameon=False)
            # FigureCanvasBase.set_window_title was deprecated in Matplotlib
            # 3.4 and removed in 3.6; the figure manager provides it.
            self.fig.canvas.manager.set_window_title(
                'small_worl2d ' + self.name + ' (' + str(W) + ':' + str(H) +
                ')')
            self.ax = self.fig.add_axes([0, 0, 1, 1])  # full window
            self.ax.set_facecolor('w')  # white background
            # x coords of everything visible in the space range from -W to W,
            # and y coords from -H to H
            self.ax.set_xlim(-W, W)
            self.ax.set_ylim(-H, H)
            self.movie_writer = FFMpegWriter(fps=10)
            self.movie_writer.setup(self.fig, self.name + '.mp4', self.fig.dpi)
        self.bodies = []  # list of Body's in this Space
        # distances between self.bodies centroids, in a symmetric np matrix
        self.dist = np.zeros((0, 0))
        # communication radius (for now, communication is omnidirectional and
        # with the same R for every type of AniBody)
        self.R = R
        # conn[i] is the set of j's such that (i, j) is an edge; i and j are
        # AniBody's of the same type in self.bodies
        self.conn = {}
        # plotted connection lines (each entry is what ax.plot returned)
        self.conngraph = []
        self.time = time()  # time (s) last updated
        self.t0 = self.time
        # Update cycle time (s). The "right" value depends on the number of
        # moving bodies, the window size and events, and the performance of
        # the Hw, Sw and code (developed for readability first, profiling TBD)
        self.T = T
        self.updates = -1
        self.avgT = 0  # for timing statistics
        self.limits = limits  # str possibly containing v and/or h
        # Limits are implemented as Obstacles with a near-infinite enclosing
        # radius. Without limits, moving bodies that cross the top appear at
        # the bottom, etc.; in that case avoid placing obstacles, etc. near
        # the missing limit.
        # Creation and drawing of the borders:
        if 'v' in limits:
            self.bodies.append(
                Obstacle('top',
                         pos=-1,
                         area=-1,
                         fc='k',
                         vertices=[(-W, H - 0.1), (W, H - 0.1), (W, 1e12),
                                   (-W, 1e12)]))
            self.bodies.append(
                Obstacle('bottom',
                         pos=-1,
                         area=-1,
                         fc='k',
                         vertices=[(-W, -H + 0.1), (W, -H + 0.1), (W, -1e12),
                                   (-W, -1e12)]))
        if 'h' in limits:
            self.bodies.append(
                Obstacle('left',
                         pos=-1,
                         area=-1,
                         fc='k',
                         vertices=[(-W + 0.1, -H), (-W + 0.1, H), (-1e12, H),
                                   (-1e12, -H)]))
            self.bodies.append(
                Obstacle('right',
                         pos=-1,
                         area=-1,
                         fc='k',
                         vertices=[(W - 0.1, -H), (W - 0.1, H), (1e12, H),
                                   (1e12, -H)]))
        if visual:
            self.redraw()  # required to init axes
            plt.pause(0.001)
            if loginfo:
                print('Init ' + name + ' with mpl version ' + mpl.__version__)
                for b in self.bodies:
                    print(repr(b))

    ## Closing functions

    def has_been_closed(self):
        """ Returns True when the figure where self is drawn is not active """
        if not visual:
            return False
        manager = self.ax.figure.canvas.manager
        return manager not in plt._pylab_helpers.Gcf.figs.values()

    def close(self):
        """ Prints closing/timing information and finishes the movie (if visual) """
        print(self.name + ' closed')
        if loginfo:
            for b in self.bodies:
                print('Body ' + b.name +
                      ' updated each avg {0:1.2f} s'.format(b.avgT))
                if isinstance(b, AniBody):
                    for x in b.souls:
                        print('Soul ' + str(x) +
                              ' updated each avg {0:1.2f} s (T = {1:1.2f} s)'.
                              format(x.avgT, x.T))
            if visual:
                print('Space ' + self.name +
                      ' redrawn each avg {0:1.2f} s (T ={1:1.2f} s)'.format(
                          self.avgT, self.T))
        if visual:
            self.movie_writer.finish()

    ## Drawing functions

    def redraw(self):
        """ Draws all the self.bodies whose time is more recent than self.time

        The patch that draws each Body and Soul is stored within it, so
        redrawing is a matter of removing the old patch and adding a new one.
        While the figure is open, a movie frame is grabbed too.
        In ROS, this is called in the space node, just frequently enough.
        """
        t = time()
        for b in self.bodies:
            if b.time > self.time:
                if b.pp is not None:
                    b.pp.remove()
                b.pp = patches.Polygon(b.vertices,
                                       fc=b.fc)  # the Body is dense
                self.ax.add_patch(b.pp)
                if isinstance(b, AniBody) and shoul:
                    for s in b.souls:
                        if s.time > self.time:
                            if s.pp is not None:
                                s.pp.remove()
                            if s.vertices is None:
                                s.pp = None
                            else:
                                s.pp = patches.Polygon(
                                    s.vertices,
                                    fill=False,
                                    ec=b.fc,
                                    lw=0.5,
                                    ls=':')  # the Soul is ethereal
                                self.ax.add_patch(s.pp)
        if showconn:
            # remove the previously plotted connection lines
            while self.conngraph:
                self.conngraph.pop()[0].remove()
            for i in list(self.conn):
                (xi, yi) = (self.bodies[i].pos.x, self.bodies[i].pos.y)
                for j in list(self.conn[i]):
                    (xj, yj) = (self.bodies[j].pos.x, self.bodies[j].pos.y)
                    self.conngraph.append(
                        self.ax.plot([xi, xj], [yi, yj],
                                     color=[0.7, 0.7, 0.7],
                                     lw=0.3,
                                     ls=':'))
        plt.draw()
        if not self.has_been_closed():
            self.movie_writer.grab_frame()
            plt.pause(0.001)
        self.updates += 1
        if self.updates > 0:
            # running average of the time between redraws
            self.avgT = ((self.updates - 1) * self.avgT +
                         (t - self.time)) / self.updates
        self.time = t

    def flash(self, bodies):
        """ Useful for debugging: makes the patches of bodies blink 10 times """
        for _ in range(10):
            for b in bodies:
                if b.pp is not None:
                    b.pp.remove()
            plt.draw()
            plt.pause(0.02)
            for b in bodies:
                if b.pp is not None:
                    self.ax.add_patch(b.pp)
            plt.draw()
            plt.pause(0.02)

    ## Body's management functions

    def bodindex(self, name):
        """ Returns the index in self.bodies of the Body named so (ValueError if none) """
        return [b.name for b in self.bodies].index(name)

    def typindices(self, type):
        """ Returns the indices in self.bodies of the Body's of the type """
        return [i for i, b in enumerate(self.bodies) if isinstance(b, type)]

    def update_dist(self):
        """ Updates the matrix of dist between all Body's in self.bodies """
        bodies = self.bodies
        for i, bi in enumerate(bodies):
            for j in range(i + 1, len(bodies)):
                self.dist[i, j] = self.dist[j, i] = bi.pos.distance(
                    bodies[j].pos)

    def update_conn(self):
        """ Updates self.conn between AniBody's of the same type in self.bodies

        Two of them are connected when closer than self.R and no Obstacle
        blocks the segment between their centroids.
        """
        self.conn = {
            i: set()
            for i, b in enumerate(self.bodies) if isinstance(b, AniBody)
        }
        for i in list(self.conn):
            type_i = type(self.bodies[i])
            for j in range(i + 1, len(self.bodies)):
                # i and j are kin and near, and can see each other
                if (isinstance(self.bodies[j], type_i)
                        and self.dist[i, j] < self.R
                        and self._line_of_sight(i, j)):
                    self.conn[i].add(j)
                    self.conn[j].add(i)

    def graph(self, type):
        """ Returns the graph (dictionary of connections) formed by the AniBody's of type """
        return {
            i: self.conn[i]
            for i, b in enumerate(self.bodies) if isinstance(b, type)
        }

    def remobodies(self, ko):
        """ Removes all the self.bodies whose indices are in list ko

        Their patches (and, with shoul, their Souls' patches) are removed from
        the plot, and their rows and columns are dropped from self.dist.
        """
        if len(ko) == 0:
            return
        ko = sorted(set(ko))  ## needs to be ordered
        ko_set = set(ko)
        ok = [i for i in range(len(self.bodies)) if i not in ko_set]
        if visual:
            for i in ko:
                b = self.bodies[i]
                if b.pp is not None:
                    b.pp.remove()  # remove the Body patch in the plot
                if isinstance(b, AniBody) and shoul:
                    for s in b.souls:
                        if s.pp is not None:
                            s.pp.remove()  # remove the Soul patch in the plot
        # After popping i, every j > i would shift one position, so pop in
        # decreasing index order. Popping must happen whatever logerror is,
        # otherwise self.bodies and self.dist get out of sync.
        for i in reversed(ko):
            removed = self.bodies.pop(i)
            if logerror:
                print('Removed ' + removed.name)
        self.dist = self.dist[ok, :][:, ok]  # remove rows and columns

    def spawn_bodies(self,
                     nm=50,
                     nk=0,
                     ns=0,
                     nf=0,
                     nn=0,
                     nO=0,
                     nMO=0,
                     room=room):
        """ A convenience SCRIPT for spawning many self.bodies randomly, and initializing self.dist and the plot

        Typically, obstacles and nests won't be random, they will be a part of the definition of a case;
        their creation should be done BEFORE calling this function, so they're taken into account for the initialization.
        Some details, e.g., initial and maximum velocities, might be changed, in their new=Body(...) call.
        Random positions are drawn in the bounding box of room; a Body is kept only if it fits in room.
        """
        # Bounding box of the room polygon
        xs = [vertex[0] for vertex in room]
        ys = [vertex[1] for vertex in room]
        left, right = min(xs), max(xs)
        bottom, top = min(ys), max(ys)

        def populate(count, make):
            """Append `count` Body's built by make(i, pos, th), retrying random
            poses until each one fits in room."""
            i = 0
            while i < count:
                pos = (uniform(left, right), uniform(bottom, top))
                th = uniform(-np.pi, np.pi)
                new = make(i, pos, th)
                if self.fits(new, room):
                    self.bodies.append(new)
                    i += 1

        # NOTE(review): the original comment says that Obstacle's and
        # MObstacle's can overlap (but not the other Body's), yet fits() is
        # called with its default noverlap=True for every type, so they can't.
        # Confirm whether noverlap=False was intended for obstacles.
        populate(nO, lambda i, pos, th: Obstacle('O' + str(i), pos, th))
        populate(
            nMO, lambda i, pos, th: MObstacle('MO' + str(i),
                                              pos,
                                              th,
                                              v=vN / 20,
                                              v_max=vN / 10,
                                              w_max=wN / 10))
        populate(nn, lambda i, pos, th: Nest('n' + str(i), pos, th))
        populate(nf, lambda i, pos, th: Food('f' + str(i), pos, th))
        populate(
            nm, lambda i, pos, th: Mobot(
                'm' + str(i), pos, th, v_max=vN / 4, w_max=wN))
        populate(
            nk, lambda i, pos, th: Killer(
                'k' + str(i), pos, th, v_max=vN / 2, w_max=wN / 4))
        populate(
            ns, lambda i, pos, th: Shepherd(
                's' + str(i), pos, th, v_max=vN / 3, w_max=wN / 2))
        # distances between centroids in a np matrix
        self.dist = np.zeros((len(self.bodies), len(self.bodies)))
        self.update_dist()
        self.update_conn()
        if visual:
            if loginfo:
                for b in self.bodies:
                    if b.time > self.time: print(repr(b))
            self.redraw()
            print(
                'Ready to start. You have a few secs to resize window, do it now or leave it.'
            )
            # Some pause is required by the GUI to show the effects of a draw
            # and to catch events; this first one is long to allow some time
            # to resize the window before starting.
            plt.pause(2)
            self.time = time()  # reset initial time of space
            self.t0 = self.time
            self.updates = 0
            for b in self.bodies:
                b.time = self.time  # reset initial time of every body

    ## Collision detection functions

    def fits(self, new, where=room, noverlap=True):
        """ Returns True if the new Body fits in where (list of vertices of a polygon)

        With noverlap, the new Body grown by its enclosing radius r_encl must
        also not intersect any Body already in self.bodies.
        """
        if noverlap:
            grown = Polygon(new.vertices).buffer(new.r_encl)
            if any(
                    grown.intersects(Polygon(old.vertices))
                    for old in self.bodies):
                return False
        return bool(Polygon(where).contains(Polygon(new.vertices)))

    def collisions(self):
        """ Detect collisions between self.bodies and return list of ko AniBody's

        Only AniBody's die, they must avoid collisions: a Mobot dies touching
        an Obstacle, Mobot or Killer; a Killer dies touching an Obstacle,
        Killer or Shepherd; a Shepherd dies touching an Obstacle or Shepherd.
        An index may appear more than once in the result.
        TBD: return list of collisions instead, of any kind of Body's, to be
        filtered out outside of this function, more flexible use.
        """

        def dies(victim, other):
            """True if victim is killed when it touches other."""
            return (isinstance(victim, Mobot)
                    and isinstance(other, (Obstacle, Mobot, Killer))
                    or isinstance(victim, Killer)
                    and isinstance(other, (Obstacle, Killer, Shepherd))
                    or isinstance(victim, Shepherd)
                    and isinstance(other, (Obstacle, Shepherd)))

        ko = []
        n = len(self.bodies)
        for i, bi in enumerate(self.bodies):
            if i in ko or not isinstance(bi, (Obstacle, AniBody)):
                continue
            for j in range(i + 1, n):
                bj = self.bodies[j]
                if (j not in ko and isinstance(bj, (Obstacle, AniBody))
                        and self.dist[i, j] < (bi.r_encl + bj.r_encl)
                        and Polygon(bi.vertices).intersects(
                            Polygon(bj.vertices))):
                    if dies(bi, bj):
                        ko.append(i)
                    if dies(bj, bi):
                        ko.append(j)
        return ko

    ## Perception functions: perception functions must be defined in the Body's Space, not in the Body itself

    def _line_of_sight(self, i, j):
        """ True if no Obstacle blocks the segment between the centroids of Body's i and j """
        ray = LineString([self.bodies[i].pos, self.bodies[j].pos])
        reach = self.dist[i, j]
        for k, bk in enumerate(self.bodies):
            if k in (i, j):
                continue
            # only Obstacles that could lie within the segment's reach
            if (self.dist[i, k] < reach + bk.r_encl
                    and isinstance(bk, Obstacle)
                    and Polygon(bk.vertices).intersects(ray)):
                return False
        return True

    def nearby(self, i, r, rng, type):
        """ Returns a list with the Body's of type "visible" from Body i

        The perception area is a polygon approximating the circular sector of
        radius r spanning [-rng, rng] around the heading th of Body i; a Body
        is visible when it intersects that area and no Obstacle blocks the
        segment between both centroids.
        """
        pos = self.bodies[i].pos
        th = self.bodies[i].th
        # vertices in perception area, relative to pos and th
        vpa = [(r * np.cos(a), r * np.sin(a))
               for a in np.linspace(-rng, rng, 60)]
        if rng < np.pi:  # when not full range, the body (its centroid) is another vertex
            vpa.append((0, 0))
        # rotate th (radians) and translate to pos the perception area
        pa = translate(rotate(Polygon(vpa), th, (0, 0), True), pos.x, pos.y)
        visible = []
        for j, bj in enumerate(self.bodies):
            if (isinstance(bj, type) and j != i
                    and self.dist[i, j] < (r + bj.r_encl)
                    and pa.intersects(Polygon(bj.vertices))
                    and self._line_of_sight(i, j)):
                visible.append(bj)
        return visible

    def nearest(self, i, r, rng, type):
        """ Returns the nearest (closer than r) to Body i of the nearby Body's of type, or None """
        pos = self.bodies[i].pos
        nearest = None
        mindist = r
        # scanned from the end of the list, as ties keep the first one found
        for b in reversed(self.nearby(i, r, rng, type)):
            dist = pos.distance(Polygon(b.vertices))
            if dist < mindist:
                nearest, mindist = b, dist
        return nearest

    def nearestpoint(self, i, b):
        """ Returns the nearest to Body i point of b """
        bi = Polygon(self.bodies[i].vertices)
        bj = Polygon(b.vertices)
        points = nearest_points(bi, bj)  # (point of bi, point of bj)
        return points[1]

    def incontact(self, i, type):
        """ Returns a list with the Body's of type in contact with Body i """
        bi = self.bodies[i]
        pbi = Polygon(bi.vertices)
        return [
            bj for j, bj in enumerate(self.bodies)
            if isinstance(bj, type) and j != i and self.dist[i, j] < (
                bi.r_encl + bj.r_encl) and pbi.intersects(Polygon(bj.vertices))
        ]
Esempio n. 12
0
def main():
    """Play the pursuit-evasion game by hand and save the recorded trajectories.

    The human is player 0, whose key presses (read through ``get()``) are
    passed to the environment as ``ev_action``; ``get()`` returning 100 ends
    the session. Observations, actions and rewards of all played games are
    saved with ``np.savez`` to ``args.log_file``. With ``args.save`` every
    game is also recorded to its own ``smallest_game<ep>.mp4``.
    """
    if args.save:
        metadata = dict(title='Game')
        writer = FFMpegWriter(fps=5, metadata=metadata)

    # NOTE(review): `seed` is never used; the draw is kept because it
    # advances NumPy's global random state.
    seed = np.random.randint(2**10)
    env = gym.make('gym_pursuitevasion_small:pursuitevasion_small-v0', ep=1)

    stop = False  # set when the playing session must end
    num_ep = 100  # number of games
    start_time = time.time()
    print(colored('You are player 0', 'red'))
    print(colored('Your input code:\n\tLeft arrow: Turn left\n\t'\
     +'Right arrow: Turn right\n\tUp arrow: Go forward\n\tDown arrow: Still'\
     +'\n\tAnything else: Stop Game','red'))
    actions = []
    observations = []
    rewards = []
    episode_returns = np.zeros((num_ep, ))
    episode_starts = []
    e_win = np.zeros((num_ep, ), dtype=bool)
    ng = num_ep
    for ep in range(1, num_ep + 1):
        print(colored('Game number {}'.format(ep), 'blue'))
        obs = env.reset(ep=ep)
        reward_sum = 0.0
        env.render(mode='human', highlight=True, ep=ep)
        if args.save:
            writer.setup(env.window.fig, "smallest_game{}.mp4".format(ep), 300)
        new_game = True
        while True:
            a = get()
            if a == 100:  # stop playing
                episode_returns[ep - 1] = reward_sum
                env.messages = [
                    'Sorry to see you go! You only played {} games!'.format(ep)
                ]
                env.render(mode='human', highlight=True, ep=ep)
                if args.save:
                    writer.grab_frame()
                print(
                    colored(
                        '\nSorry to see you go! You only played {} games!\n'.
                        format(ep), 'red'))
                stop = True
                ng = ep
                time.sleep(1.0)
                env.window.close()
            else:
                if new_game:
                    episode_starts.append(True)
                    new_game = False
                observations.append(obs)
                a = np.array([int(a)])
                obs, rew, done, ewi = env.step(ev_action=a)
                actions.append(a)
                rewards.append(rew)
                reward_sum += rew
                env.render(mode='human', highlight=True, ep=ep)
                if args.save:
                    writer.grab_frame()
                if done:
                    time.sleep(0.5)
                    env.window.close()
                    episode_returns[ep - 1] = reward_sum
                    e_win[ep - 1] = ewi
                    if ep == num_ep:
                        stop = True
                    break
                else:
                    episode_starts.append(done)
            if stop:
                break
        if args.save:
            # Every game gets its own setup()/finish() pair; finishing only
            # once after the loop would leave all videos but the last one
            # unfinalized.
            writer.finish()
        if stop:
            break
    if len(episode_starts) > np.shape(observations)[0]:
        print('len(obs): {}; len(episode_starts): {}'.format(
            np.shape(observations)[0], len(episode_starts)))
        print('Removing last entry of episode_starts!')
        episode_starts = episode_starts[:-1]
    end_time = time.time()
    print('Playing time: {:.2f}s = {:.2f}min'.format(
        end_time - start_time, (end_time - start_time) / 60))
    actions = np.asarray(actions, dtype=int)
    observations = np.asarray(observations, dtype=int)
    rewards = np.asarray(rewards)
    exp_dict = {
        'actions': actions,
        'obs': observations,
        'rewards': rewards,
        'episode_returns': episode_returns,
        'episode_starts': episode_starts,
        'evader_wins': e_win,
        'number_games': ng
    }  # type: Dict[str, np.ndarray]
    np.savez(args.log_file, **exp_dict)
    print('Expert evader won {}/{} games played!'.format(np.sum(e_win), ng))
Esempio n. 13
0
class FullStateQ():
    """Tabular Q-learning agent whose state is (current arm, run length).

    ``states[k, r]`` is being on arm ``k`` after ``r`` consecutive stays.
    From every state the agent can leave to run length 0 of any other arm or,
    unless ``r`` is the last run length, stay (always the last action index).
    Actions are chosen by softmax (``softmax_temperature``) or epsilon-greedy
    (``epsilon``) selection.
    """

    def __init__(
        self,
        K_arm=2,
        first_choice=None,
        max_run_length=10,
        discount_rate=0.99,
        learn_rate=0.1,
        softmax_temperature=None,
        epsilon=None,
        if_record_Q='',
    ):
        """
        Args:
            K_arm: number of arms.
            first_choice: arm of the first trial; random when None.
            max_run_length: number of run-length states per arm (rounded up).
            discount_rate: Q-learning discount factor.
            learn_rate: Q-learning step size.
            softmax_temperature: temperature for softmax action selection.
            epsilon: exploration rate for epsilon-greedy action selection
                (used only when softmax_temperature is None).
            if_record_Q: when truthy, plot_Q writes frames to a movie instead
                of waiting for a button press.

        Raises:
            ValueError: if both softmax_temperature and epsilon are None.
        """
        # Choose the action-selection rule first, so a bad configuration
        # fails before any state is built.
        if softmax_temperature is not None:
            self.if_softmax = True
        elif epsilon is not None:
            self.if_softmax = False
        else:
            raise ValueError('Both softmax_temp and epsilon are missing!')

        self.if_record_Q = if_record_Q
        self.learn_rate = learn_rate
        self.softmax_temperature = softmax_temperature
        self.discount_rate = discount_rate
        self.epsilon = epsilon

        self._init_states(max_run_length, K_arm)
        self.ax = []  # plot axes, created by the first plot_Q call
        self.last_reward = np.nan  # reward shown by the previous plot_Q call

        if first_choice is None:
            first_choice = np.random.choice(K_arm)
        # Start at run length 0 of the first choice
        self.current_state = self.states[first_choice, 0]
        # First trial is a STAY (last action index) at first_choice
        self.backup_SA = [self.current_state, -1]

    def _init_states(self, max_run_length, K_arm):
        """Build the K_arm x ceil(max_run_length) grid of states and their transitions."""
        max_run_length = int(np.ceil(max_run_length))

        self.states = np.zeros([K_arm, max_run_length], dtype=object)
        for k in range(K_arm):
            for r in range(max_run_length):
                self.states[k, r] = State(k, r)

        # Define possible transitions
        for k in range(K_arm):
            for r in range(max_run_length):
                for kk in range(K_arm):
                    # Leave: to run length 0 of any other arm
                    if k != kk:
                        self.states[k, r].add_next_states([self.states[kk, 0]])

                if r < max_run_length - 1:
                    # Stay is always the last index
                    self.states[k, r].add_next_states(
                        [self.states[k, r + 1]])

    def act(self):
        """Make a state transition and return the chosen arm (absolute choice)."""
        if self.if_softmax:
            next_state_idx = self.current_state.act_softmax(
                self.softmax_temperature)  # Next state index!!
        else:  # Epsilon-greedy
            next_state_idx = self.current_state.act_epsilon(
                self.epsilon)  # Next state index!!

        # Remembered for the one-step backup in Q-learning
        self.backup_SA = [self.current_state, next_state_idx]
        self.current_state = self.current_state.next_states[next_state_idx]
        # Return absolute choice! (LEFT/RIGHT)
        return self.current_state.which[0]

    def update_Q(self, reward):
        """Q-learning (off-policy TD(0)) update of the last state-action pair."""
        # Bootstrapping on the max makes it off-policy
        max_next_SAvalue_for_backup_state = np.max(self.current_state.Q)
        last_state, last_choice = self.backup_SA
        last_state.Q[last_choice] += self.learn_rate * (
            reward + self.discount_rate * max_next_SAvalue_for_backup_state -
            last_state.Q[last_choice])

    def plot_Q(self,
               time=np.nan,
               reward=np.nan,
               p_reward=np.nan,
               description=''):
        """Visualize the value functions Q(s,a) and the policy P(a|s).

        A 2x2 grid shows, for each arm (columns) and decision (rows: Leave,
        Stay), Q(s,a) as bars and P(a|s) as a line over the run length; the
        current and previous states are marked with their rewards. P(a|s) is
        only computed for softmax selection (NaN, i.e. not drawn, otherwise).

        Returns:
            True after grabbing a movie frame (if_record_Q); otherwise the
            result of ``plt.waitforbuttonpress()``.
        """
        # Initialization (first call only). len() works both for the initial
        # empty list and for the 2x2 ndarray of axes, whereas `== []` on that
        # ndarray raises on recent NumPy versions.
        if len(self.ax) == 0:
            # Prepare axes
            self.fig, self.ax = plt.subplots(2,
                                             2,
                                             sharey=True,
                                             figsize=[12, 8])
            plt.subplots_adjust(hspace=0.5, top=0.85)
            self.ax2 = self.ax.copy()
            self.annotate = plt.gcf().text(0.05, 0.9, '', fontsize=13)
            for c in [0, 1]:
                for d in [0, 1]:
                    self.ax2[c, d] = self.ax[c, d].twinx()

            # Prepare animation
            if self.if_record_Q:
                import os
                metadata = dict(title='FullStateQ', artist='Matplotlib')
                self.writer = FFMpegWriter(fps=25, metadata=metadata)
                self.writer.setup(
                    self.fig,
                    os.path.join('..', 'results', '%s.mp4' % description),
                    150)

        direction = ['LEFT', 'RIGHT']
        decision = ['Leave', 'Stay']
        # Run lengths 1..N-1; the last run length is ignored (must leave)
        X = np.r_[1:np.shape(self.states)[1] - 0.1]

        # -- Q values and policy --
        for d in [0, 1]:
            Qs = np.array([s.Q for s in self.states[d, :-1]])
            # Compute policy p(a|s)
            if self.if_softmax:
                ps = np.array(
                    [softmax(qq, self.softmax_temperature) for qq in Qs])
            else:
                # NOTE(review): the epsilon-greedy policy is not computed
                # here; NaNs draw nothing instead of raising a NameError.
                ps = np.full(np.shape(Qs), np.nan)

            for c in [0, 1]:
                self.ax[c, d].cla()
                self.ax2[c, d].cla()

                self.ax[c, d].set_xlim([0, max(X) + 1])
                self.ax[c, d].set_ylim([-0.05, max(plt.ylim())])

                bar_color = 'r' if c == 0 else 'g'

                self.ax[c, d].bar(X, Qs[:, c], color=bar_color, alpha=0.5)
                self.ax[c, d].set_title(direction[d] + ', ' + decision[c])
                self.ax[c, d].axhline(0, color='k', ls='--')
                if d == 0: self.ax[c, d].set_ylabel('Q(s,a)', color='k')
                self.ax[c, d].set_xticks(X)

                self.ax2[c, d].plot(X, ps[:, c], bar_color + '-o')
                if d == 1: self.ax2[c, d].set_ylabel('P(a|s)', color=bar_color)
                self.ax2[c, d].axhline(0, color=bar_color, ls='--')
                self.ax2[c, d].axhline(1, color=bar_color, ls='--')
                self.ax2[c, d].set_ylim([-0.05, 1.05])

        # -- This state --
        last_state = self.backup_SA[0].which
        current_state = self.current_state.which
        if time > 1:
            self.ax2[0, last_state[0]].plot(last_state[1] + 1,
                                            self.last_reward,
                                            'go',
                                            markersize=10,
                                            alpha=0.5)
        self.ax2[0, current_state[0]].plot(current_state[1] + 1,
                                           reward,
                                           'go',
                                           markersize=15)
        self.last_reward = reward

        self.annotate.set_text(
            '%s\nt = %g, p_reward = %s\n%s --> %s, reward = %g\n' %
            (description, time, p_reward, last_state, current_state, reward))
        if self.if_record_Q:
            print(time)
            self.writer.grab_frame()
            return True
        else:
            plt.gcf().canvas.draw()
            return plt.waitforbuttonpress()
Esempio n. 14
0
class PlotMovieWriter(object):
    """
    PlotLoop + MovieWriter: draw data with a plotting function and append
    every drawn state as one frame of a movie written by ``FFMpegWriter``.

    Example:
        import sflow.python.ploting as plt
        plot = plt.PlotMovieWriter(outputfile, plt.imshow_flat, dpi=100)
        for i in range(nstep):
            if i % 50 == 0:
                out = sess.run([styled, trainop])[0]
                plot(out)
                # plt.pause(0.0001)
                print(i)
            else:
                l = sess.run([loss_content, loss_style, trainop])[:-1]
                print(i, l)
        plot.finish()

        plt.plot_pause()

    """
    def __init__(self,
                 outfile,
                 showfun=None,
                 fig=None,
                 drawopt=None,
                 dpi=100,
                 **movieopt):
        """
        :param outfile: path of the movie file to write.
        :param showfun: drawing function used for the first frame; defaults
            to ``plt.imshow``. If its return value has ``set_data``, that is
            used to update later frames.
        :param fig: figure to draw into; a new ``plt.figure()`` if None.
        :param drawopt: extra keyword arguments merged into the first
            ``showfun`` call. The special key ``'onclose'`` is removed and
            connected as the figure's ``close_event`` handler. The caller's
            dict is not modified.
        :param dpi: resolution passed to the movie writer's ``setup``.
        :param movieopt: keyword arguments forwarded to ``FFMpegWriter``
            (e.g. fps, codec, bitrate, extra_args, metadata).
        """
        self.showfun = showfun or plt.imshow
        self.fig = fig if fig is not None else plt.figure()
        # Work on a copy so popping 'onclose' does not mutate the caller's dict.
        drawopt = dict(drawopt) if drawopt else {}
        self.drawopt = drawopt
        # Artist returned by showfun; reused for set_data on later frames.
        self.setdata = None
        self.onclose = drawopt.pop('onclose', self._onclose)
        # The figure always exists here, so hook the close handler now
        # (it was previously only connected in an unreachable branch).
        self.fig.canvas.mpl_connect('close_event', self.onclose)
        # for movie writing
        self.moviewriter = None
        self.movieopt = movieopt

        self.outfile = outfile
        self.dpi = dpi
        # True until setup_movie has run.
        self._first = True

    def setup_movie(self, fig):
        """Create the ``FFMpegWriter`` from ``movieopt`` and start writing
        ``fig`` to ``outfile`` at ``dpi``."""
        self.moviewriter = FFMpegWriter(**self.movieopt)
        self.moviewriter.setup(fig, self.outfile, self.dpi)

    def _onclose(self, evt):
        """Default close-event handler: does nothing."""
        pass

    def __call__(self, *args, **kwargs):
        """Draw ``args``/``kwargs`` and grab the result as one movie frame.

        While no artist is stored (first call, or ``showfun`` returned a
        falsy value), ``showfun`` is called with ``drawopt`` merged into
        ``kwargs``, and the movie writer is set up on the very first call.
        Afterwards the stored artist is updated through its ``set_data``;
        if it has none, ``showfun`` is called again (without ``drawopt``).
        """
        if not self.setdata:
            if self.fig is None:
                # Only reachable if ``fig`` was cleared after construction.
                self.fig = plt.figure()
                self.fig.canvas.mpl_connect('close_event', self.onclose)
            kwargs.update(self.drawopt)
            self.setdata = self.showfun(*args, **kwargs)
            if self._first:
                self.setup_movie(self.fig)
                self._first = False

            plt.show(block=False)
        else:
            # Look up set_data explicitly rather than catching AttributeError,
            # so an AttributeError raised *inside* set_data is not swallowed.
            set_data = getattr(self.setdata, 'set_data', None)
            if set_data is not None:
                set_data(*args, **kwargs)
            else:
                self.showfun(*args, **kwargs)

            self.fig.canvas.draw()

        self.grab()

    def grab(self):
        """Append the figure's current state as one movie frame."""
        self.moviewriter.grab_frame()

    def finish(self):
        """Finalize the movie file.

        :raises ValueError: if no frame was ever drawn (writer not set up).
        """
        if self._first:
            raise ValueError('setup not called')
        self.moviewriter.finish()
        logg.info('movie saved to [{}]'.format(self.outfile))
Esempio n. 15
0
class GUI:
    """Interactive matplotlib window that plays Pong between two policies.

    Each timer tick (see ``main_loop``) asks both policies for an action,
    advances ``sim`` by one step and blits paddles, ball, names and score.
    When ``capture`` is set, frames are also recorded with ``FFMpegWriter``
    to the path given by ``capture``.
    """

    def __init__(self,
                 sim,
                 l_pol,
                 r_pol,
                 max_episodes,
                 name_left=None,
                 name_right=None,
                 capture=None):
        """
        :param sim: simulation exposing ``get_state()``, ``step(l_a, r_a)``,
            ``new_episode()``, ``score`` (mapping with keys ``l``, ``draw``,
            ``r``), ``done`` and ``win``.
        :param l_pol: left policy with ``get_action(state, buttons)``,
            ``new_episode()`` and ``name``.
        :param r_pol: right policy with the same interface as ``l_pol``.
        :param max_episodes: window closes after this many finished episodes.
        :param name_left: label for the left player; ``l_pol.name`` if falsy.
        :param name_right: label for the right player; ``r_pol.name`` if falsy.
        :param capture: movie output path passed to the writer; falsy
            disables recording.
        """
        self.sim = sim
        self.l_pol = l_pol
        self.r_pol = r_pol
        self.max_episodes = max_episodes
        self.l_name = name_left if name_left else self.l_pol.name
        self.r_name = name_right if name_right else self.r_pol.name
        self.capture = capture
        # FFMpegWriter, created by start() when capture is set.
        self.writer = None

        # main_loop does nothing until time() exceeds this value.
        self.next_update = 0.
        self.episode_num = 0

        self.fig = plt.figure(figsize=c.FIGSIZE)
        self.canvas = self.fig.canvas
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6;
        # the figure manager's method works on old and new versions.
        manager = getattr(self.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('PONG')

        self.ax = self.fig.add_axes((0., 0., 1., 1.))
        self.ax.axis([
            c.LEFT - c.PADDLE_WIDTH - c.BALL_RADIUS,
            c.RIGHT + c.PADDLE_WIDTH + c.BALL_RADIUS,
            c.BOTTOM - c.BALL_RADIUS,
            c.TOP + c.BALL_RADIUS,
        ])
        # Booleans, not "off": a non-empty string is truthy and turns the
        # ticks back on in Matplotlib versions without on/off string parsing.
        self.ax.tick_params(top=False, bottom=False, left=False, right=False)
        self.ax.set_xticklabels([])
        self.ax.set_yticklabels([])
        self.ax.set_aspect(1)
        self.ax.set_facecolor(c.CLR_BLACK)

        # Empty-axes pixels restored before every blit in draw().
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)

        self.l = patches.Rectangle((0., 0.),
                                   c.PADDLE_WIDTH,
                                   2 * c.HPL,
                                   color=c.PADDLE_COLOR,
                                   animated=True)
        self.r = patches.Rectangle((0., 0.),
                                   c.PADDLE_WIDTH,
                                   2 * c.HPL,
                                   color=c.PADDLE_COLOR,
                                   animated=True)
        self.ball = patches.Circle((0., 0.),
                                   radius=c.BALL_RADIUS,
                                   color=c.BALL_COLOR,
                                   animated=True)

        # End-of-point indicators: left winner, right winner, draw.
        self.l_arrow = self._make_arrow(c.ARROW_START, 0., -c.ARROW_LENGTH, 0.)
        self.r_arrow = self._make_arrow(-c.ARROW_START, 0., c.ARROW_LENGTH, 0.)
        self.d_arrow = self._make_arrow(0., c.ARROW_START, 0., -c.ARROW_LENGTH)

        self.ax.add_patch(self.l)
        self.ax.add_patch(self.r)
        self.ax.add_patch(self.ball)

        self.l_action = 0
        self.r_action = 0
        # Keys currently held down; passed to the policies each step.
        self.buttons = set()

        font_dict = {
            "family": "monospace",
            "size": "large",
            "weight": "bold",
            "animated": True
        }

        self.l_text = self.ax.text(c.LEFT,
                                   c.BOTTOM,
                                   self.l_name,
                                   color=c.NAME_COLOR,
                                   ha="left",
                                   **font_dict)
        self.r_text = self.ax.text(c.RIGHT,
                                   c.BOTTOM,
                                   self.r_name,
                                   color=c.NAME_COLOR,
                                   ha="right",
                                   **font_dict)
        self.score = self.ax.text((c.LEFT + c.RIGHT) / 2,
                                  c.BOTTOM,
                                  "",
                                  color=c.SCORE_COLOR,
                                  ha="center",
                                  **font_dict)

    @staticmethod
    def _make_arrow(x, y, dx, dy):
        """Return an animated indicator arrow from (x, y) pointing by (dx, dy)."""
        # ``width`` by keyword: it is keyword-only in newer Matplotlib.
        return patches.FancyArrow(x,
                                  y,
                                  dx,
                                  dy,
                                  width=c.ARROW_WIDTH,
                                  color=c.ARROW_COLOR,
                                  animated=True)

    def draw(self):
        """Blit the current simulation state; record frames when capturing."""
        # An arrow is attached only on the frame a point ends; detach any
        # arrow left over from a previous frame.
        for arrow in (self.l_arrow, self.r_arrow, self.d_arrow):
            if arrow.is_figure_set():
                arrow.remove()

        self.score.set_text("{l}|{draw}|{r}".format(**self.sim.score))

        s = self.sim.get_state()

        self.ball.center = (s[S.BALL_X], s[S.BALL_Y])
        self.l.set_xy(
            (c.LEFT - c.PADDLE_WIDTH - c.BALL_RADIUS, s[S.L_Y] - c.HPL))
        self.r.set_xy((c.RIGHT + c.BALL_RADIUS, s[S.R_Y] - c.HPL))

        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.l_text)
        self.ax.draw_artist(self.r_text)
        self.ax.draw_artist(self.score)

        if self.sim.done:
            if self.sim.win == "l":
                self.ax.add_patch(self.l_arrow)
                self.ax.draw_artist(self.l_arrow)
            elif self.sim.win == "r":
                self.ax.add_patch(self.r_arrow)
                self.ax.draw_artist(self.r_arrow)
            else:
                self.ax.add_patch(self.d_arrow)
                self.ax.draw_artist(self.d_arrow)
        else:
            self.ax.draw_artist(self.ball)

        self.ax.draw_artist(self.l)
        self.ax.draw_artist(self.r)
        self.canvas.blit(self.ax.bbox)
        if self.capture:
            # Writer runs at CAPTURE_FPS, so repeating the final frame
            # CAPTURE_FPS // 2 times holds the point's result ~0.5 s.
            n = c.CAPTURE_FPS // 2 if self.sim.done else 1
            for _ in range(n):
                self.writer.grab_frame()

    def _close(self):
        """Stop the game timer (if started) and close this GUI's figure."""
        timer = getattr(self, 'timer', None)
        if timer is not None:
            timer.stop()
        # Close our own figure, not whichever figure happens to be current.
        plt.close(self.fig)

    def main_loop(self):
        """Timer callback: advance one step when due (every tick if capturing)."""
        if (time() > self.next_update) or self.capture:
            state = self.sim.get_state()
            l_a = self.l_pol.get_action(state, self.buttons)
            r_a = self.r_pol.get_action(state, self.buttons)

            self.sim.step(l_a, r_a)
            self.draw()
            self.last_update = time()

            if self.sim.done:
                self.episode_num += 1
                if self.episode_num >= self.max_episodes:
                    self._close()
                else:
                    self.sim.new_episode()
                    self.l_pol.new_episode()
                    self.r_pol.new_episode()
                    self.next_update = time() + c.POINT_DELAY
            else:
                self.next_update = time() + c.FRAME_DELAY

    def key_press(self, event):
        """Close on "q"; otherwise remember the key as held down."""
        if event.key == "q":
            self._close()
        else:
            self.buttons.add(event.key)

    def key_release(self, event):
        """Forget a released key (no error if it was never recorded)."""
        # discard: "q" and keys pressed before the handler was connected
        # are never added, and remove() would raise KeyError for them.
        self.buttons.discard(event.key)

    def handle_redraw(self, event):
        """Re-capture the clean axes background after a full canvas redraw."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)

    def first_draw(self, event):
        """On the first draw: install key handlers and start the game timer."""
        # Drop Matplotlib's default key bindings so game keys are not
        # interpreted as GUI shortcuts.
        if self.canvas.manager.key_press_handler_id is not None:
            self.canvas.mpl_disconnect(
                self.canvas.manager.key_press_handler_id)

        self.canvas.mpl_disconnect(self.cid)
        self.canvas.mpl_connect('draw_event', self.handle_redraw)
        self.canvas.mpl_connect("key_press_event", self.key_press)
        self.canvas.mpl_connect("key_release_event", self.key_release)
        self.canvas.restore_region(self.background)

        self.timer = self.canvas.new_timer(interval=1)
        self.timer.add_callback(self.main_loop)
        self.timer.start()

    def start(self):
        """Arm first_draw and, if capturing, open the movie writer."""
        self.cid = self.canvas.mpl_connect('draw_event', self.first_draw)
        if self.capture:
            self.writer = FFMpegWriter(fps=c.CAPTURE_FPS)
            self.writer.setup(self.fig, self.capture)

    def end(self):
        """Finalize the movie file when capturing."""
        if self.capture:
            self.writer.finish()
def generate_fig3():
    tree = RandomTree(3)
    for i in range(10):
        tree.simulate(8)
    
    while True:
        try:
            a = np.random.choice(3)
            tree.children[a].set()
            break
        except AttributeError:
            pass
    
    fig1 = Figure(figsize=(16/2, 9/2))
    canvas1 = FigureCanvas(fig1)
    ax1 = fig1.add_axes((0.01, 0.01, 0.98, 0.98))
    fig2 = Figure(figsize=(16/2, 9/2))
    canvas2 = FigureCanvas(fig2)
    ax2 = fig2.add_axes((0.01, 0.01, 0.98, 0.98))

    fig3 = Figure(figsize=(16, 9))
    canvas3 = FigureCanvas(fig3)
    ax3 = fig3.add_axes((0.01, 0.01, 0.98, 0.98))

    for ax in [ax1, ax2, ax3]:
        common.set_ax_params(ax)
        ax.axis([0., 16., 0., 9.])

    r = 0.4
    
    tree.xy = (1., 9. / 2.)
    tree.box_xy = (1. - r, 9. / 2. - r)
    tree.width = 2 * r
    tree.height = 2 * r
    tree.text = "$s$"
    tree.facecolor1 = "lightblue"
    tree.facecolor2 = "lightblue"
    tree.alpha = 0.2
    tree.connectors = [(1. + r + 0.1, 9. / 2. + j * r / 3) for j in [1, 0, -1]]

    X = np.linspace(3., 15., 8)
    L = [tree]
    for i in range(8):
        L2 = []
        for n in L:
            L2.extend(c for c in n.children if c)
        Y = np.linspace(9., 0., len(L2) + 2)[1:-1]
        cnt = 0
        for n in L:
            for j in range(3):
                if n.children[j] is not None:
                    c = n.children[j]
                    x, y = X[i], Y[cnt]
                    c.connectors = [(x + r/2 + 0.1, y + k * r / 6)
                        for k in [1, 0, -1]]
                    c.xy = (x, y)
                    c.box_xy = (x - r/2, y - r/2)
                    c.width = r
                    c.height = r
                    c.father_a_xy = n.connectors[j]
                    c.a_xy = (x - r/2 - 0.1, y)
                    c.text = "$a_{}$".format(j)
                    c.facecolor1 = "lightgreen"
                    if (i==0 and c.active):
                        c.facecolor2 = "lightblue"
                    else:
                        c.facecolor2 = "lightgreen"
                    c.alpha = 1. if c.active else 0.2
                    cnt += 1
        L = L2
    
    writer = FFMpegWriter()
    writer.setup(fig3, "figures/part{}/mcts_movie.mp4".format(PART_NUM))
    writer.grab_frame()
    reset = False
    
    for c in tree.visitorder:
        if reset:
            n = c
            L = []
            while n:
                L.append(n)
                n = n.parent
        else:
            L = [c]

        for n in L[::-1]:
            n.draw(ax3, "red", 1., "xx-large")
            writer.grab_frame()
            n.remove(ax3)
            n.draw(ax3, n.facecolor1, 1., "xx-large")
            writer.grab_frame()

        c.draw(ax1, c.facecolor1, 1.)
        c.draw(ax2, c.facecolor2, c.alpha)
        
        reset = not any(c.children)
    
    writer.finish()
    
    common.save_next_fig(PART_NUM, fig1)
    common.save_next_fig(PART_NUM, fig2)