import numpy as np
import pytest

# Import paths below follow rlberry's package layout; adjust if these
# modules live elsewhere in your tree.
from rlberry.envs import GridWorld
from rlberry.envs.classic_control import MountainCar
from rlberry.exploration_tools.discrete_counter import DiscreteCounter


# Parametrization assumed (lost in flattening); the test body only
# distinguishes rate_power = 1 and rate_power = 0.5.
@pytest.mark.parametrize("rate_power", [0.5, 1])
def test_discrete_env(rate_power):
    env = GridWorld()
    counter = DiscreteCounter(
        env.observation_space, env.action_space, rate_power=rate_power
    )

    for N in range(10, 20):
        assert counter.get_n_visited_states() == 0
        assert counter.get_entropy() == 0.0

        for ss in range(env.observation_space.n):
            for aa in range(env.action_space.n):
                for _ in range(N):
                    ns, rr, _, _ = env.sample(ss, aa)
                    counter.update(ss, aa, ns, rr)
                assert counter.N_sa[ss, aa] == N
                assert counter.count(ss, aa) == N
                if rate_power == pytest.approx(1):
                    assert np.allclose(counter.measure(ss, aa), 1.0 / N)
                elif rate_power == pytest.approx(0.5):
                    assert np.allclose(counter.measure(ss, aa), np.sqrt(1.0 / N))

        assert counter.get_n_visited_states() == env.observation_space.n
        assert np.allclose(counter.get_entropy(), np.log2(env.observation_space.n))

        counter.reset()
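# The assertions above rely on the bonus decaying as a power of the visit
# count. A minimal sketch of that relation, assuming measure(s, a) equals
# N(s, a) ** (-rate_power); this is inferred from the two branches tested
# above, not a definition taken from the library.
def _expected_measure(n_visits, rate_power):
    """Hypothetical model of DiscreteCounter.measure, for illustration only."""
    return float(n_visits) ** (-rate_power)


def test_expected_measure_examples():
    # rate_power = 1 gives 1/N; rate_power = 0.5 gives 1/sqrt(N)
    assert _expected_measure(16, 1.0) == 1.0 / 16
    assert _expected_measure(16, 0.5) == 0.25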
@pytest.mark.parametrize("rate_power", [0.5, 1])
def test_continuous_state_env(rate_power):
    env = MountainCar()
    counter = DiscreteCounter(
        env.observation_space, env.action_space, rate_power=rate_power
    )

    for N in [10, 20]:
        for _ in range(50):
            ss = env.observation_space.sample()
            aa = env.action_space.sample()
            for _ in range(N):
                ns, rr, _, _ = env.sample(ss, aa)
                counter.update(ss, aa, ns, rr)

            dss = counter.state_discretizer.discretize(ss)
            assert counter.N_sa[dss, aa] == N
            assert counter.count(ss, aa) == N
            if rate_power == pytest.approx(1):
                assert np.allclose(counter.measure(ss, aa), 1.0 / N)
            elif rate_power == pytest.approx(0.5):
                assert np.allclose(counter.measure(ss, aa), np.sqrt(1.0 / N))
            # reset inside the loop so that two random states falling in the
            # same discretization bin do not accumulate counts across iterations
            counter.reset()
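# ---------------------------------------------------------------------------
# The Vis2dWrapper below uses a Transition record and a TrajectoryMemory that
# are not part of this excerpt. The sketch that follows is a minimal, assumed
# reconstruction: field names are inferred from the access sites in the
# wrapper (traj[ii].state, .raw_state, .action, .n_episode_visits, ...) and
# may differ from the real definitions.
# ---------------------------------------------------------------------------
from collections import namedtuple

Transition = namedtuple(
    "Transition",
    ["raw_state", "state", "action", "reward", "n_total_visits", "n_episode_visits"],
)


class TrajectoryMemory:
    """Bounded trajectory store: keeps the `max_size` most recent trajectories."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.trajectories = []

    @property
    def n_trajectories(self):
        return len(self.trajectories)

    def end_trajectory(self):
        # start a fresh trajectory; evict the oldest one if the buffer is full
        self.trajectories.append([])
        if len(self.trajectories) > self.max_size:
            self.trajectories.pop(0)

    def append(self, transition):
        # append to the current (most recent) trajectory
        if not self.trajectories:
            self.trajectories.append([])
        self.trajectories[-1].append(transition)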
import logging

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from gym import Wrapper, spaces

# Assumed location of the video helper (rlberry-style layout).
from rlberry.rendering.utils import video_write

logger = logging.getLogger(__name__)


def identity(state, env, **kwargs):
    """Default state preprocessing: return the state unchanged."""
    return state


class Vis2dWrapper(Wrapper):
    """
    Stores and visualizes the trajectories of environments with 2d box
    observation spaces and discrete action spaces.

    Parameters
    ----------
    env : gym.Env
    n_bins_obs : int, default = 10
        Number of intervals to discretize each dimension of the observation
        space. Used to count the number of visits.
    memory_size : int, default = 100
        Maximum number of trajectories to keep in memory. The most recent
        ones are kept.
    state_preprocess_fn : callable(state, env, **kwargs) -> np.ndarray, default: None
        Function that converts the state to a 2d array.
    state_preprocess_kwargs : dict, default: None
        kwargs for state_preprocess_fn.
    """

    def __init__(
        self,
        env,
        n_bins_obs=10,
        memory_size=100,
        state_preprocess_fn=None,
        state_preprocess_kwargs=None,
    ):
        Wrapper.__init__(self, env)

        if state_preprocess_fn is None:
            assert isinstance(env.observation_space, spaces.Box)
            assert isinstance(env.action_space, spaces.Discrete)

        self.state_preprocess_fn = state_preprocess_fn or identity
        self.state_preprocess_kwargs = state_preprocess_kwargs or {}
        self.memory = TrajectoryMemory(memory_size)
        self.total_visit_counter = DiscreteCounter(
            self.env.observation_space, self.env.action_space, n_bins_obs=n_bins_obs
        )
        self.episode_visit_counter = DiscreteCounter(
            self.env.observation_space, self.env.action_space, n_bins_obs=n_bins_obs
        )
        self.current_state = None
        self.current_step = 0  # fixed typo: was `self.curret_step`

    def reset(self):
        self.current_step = 0
        self.current_state = self.env.reset()
        return self.current_state

    def step(self, action):
        observation, reward, done, info = self.env.step(action)

        # initialize new trajectory
        if self.current_step == 0:
            self.memory.end_trajectory()
            self.episode_visit_counter.reset()
        self.current_step += 1

        # update counters
        ss, aa = self.current_state, action
        ns = observation
        self.total_visit_counter.update(ss, aa, ns, reward)
        self.episode_visit_counter.update(ss, aa, ns, reward)

        # store transition
        transition = Transition(
            ss,
            self.state_preprocess_fn(ss, self.env, **self.state_preprocess_kwargs),
            aa,
            reward,
            self.total_visit_counter.count(ss, aa),
            self.episode_visit_counter.count(ss, aa),
        )
        self.memory.append(transition)

        # update current state
        self.current_state = observation

        return observation, reward, done, info

    def plot_trajectories(
        self,
        fignum=None,
        figsize=(6, 6),
        hide_axis=True,
        show=True,
        video_filename=None,
        colormap_name="cool",
        framerate=15,
        n_skip=1,
        dot_scale_factor=2.5,
        alpha=0.25,
        xlim=None,
        ylim=None,
        dot_size_means="episode_visits",
    ):
        """
        Plot the history of trajectories in a scatter plot. Colors
        distinguish recent from old trajectories; the size of the dots
        represents the number of visits to a state.

        If video_filename is given, a video file is saved. Otherwise, only
        the final frame is plotted.

        Parameters
        ----------
        fignum : str
            Figure name.
        figsize : (float, float)
            (width, height) of the image in inches.
        hide_axis : bool
            If True, axes are hidden.
        show : bool
            If True, calls plt.show().
        video_filename : str or None
            If not None, save a video with the given filename.
        colormap_name : str, default = 'cool'
            Colormap name.
            See https://matplotlib.org/tutorials/colors/colormaps.html
        framerate : int, default: 15
            Video framerate.
        n_skip : int, default: 1
            Skip period: one trajectory is plotted every n_skip trajectories.
        dot_scale_factor : double
            Scale factor for the scatter plot points.
        alpha : float, default: 0.25
            The alpha blending value, between 0 (transparent) and 1 (opaque).
        xlim : list, default: None
            x plot limits, set to [0, 1] if None.
        ylim : list, default: None
            y plot limits, set to [0, 1] if None.
        dot_size_means : str, {'episode_visits', 'total_visits'}, default: 'episode_visits'
            Whether to scale the dot size with the number of visits in an
            episode or with the total number of visits during the whole
            interaction.
        """
        logger.info("Plotting...")

        fignum = fignum or str(self)
        colormap_fn = plt.get_cmap(colormap_name)

        # discretizer: use the bin width as a reference length for dot sizes
        try:
            discretizer = self.episode_visit_counter.state_discretizer
            epsilon = min(
                discretizer._bins[0][1] - discretizer._bins[0][0],
                discretizer._bins[1][1] - discretizer._bins[1][0],
            )
        except Exception:
            epsilon = 0.01

        # figure setup
        xlim = xlim or [0.0, 1.0]
        ylim = ylim or [0.0, 1.0]

        fig = plt.figure(fignum, figsize=figsize)
        fig.clf()
        canvas = FigureCanvas(fig)
        images = []
        ax = fig.gca()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if hide_axis:
            ax.set_axis_off()

        # scatter plot
        indices = np.arange(self.memory.n_trajectories)[::n_skip]
        for idx in indices:
            traj = self.memory.trajectories[idx]
            color_time_intensity = (idx + 1) / self.memory.n_trajectories
            color = colormap_fn(color_time_intensity)

            states = np.array([traj[ii].state for ii in range(len(traj))])

            if dot_size_means == "episode_visits":
                sizes = np.array(
                    [traj[ii].n_episode_visits for ii in range(len(traj))]
                )
            elif dot_size_means == "total_visits":
                raw_states = [traj[ii].raw_state for ii in range(len(traj))]
                sizes = np.array(
                    [
                        np.sum(
                            [
                                self.total_visit_counter.count(ss, aa)
                                for aa in range(self.env.action_space.n)
                            ]
                        )
                        for ss in raw_states
                    ]
                )
            else:
                raise ValueError(
                    "dot_size_means must be 'episode_visits' or 'total_visits'"
                )

            sizes = 1 + sizes
            sizes = (dot_scale_factor**2) * 100 * epsilon * sizes / sizes.max()

            ax.scatter(
                x=states[:, 0], y=states[:, 1], color=color, s=sizes, alpha=alpha
            )

            plt.tight_layout()

            if video_filename is not None:
                canvas.draw()
                image_from_plot = np.frombuffer(
                    fig.canvas.tostring_rgb(), dtype=np.uint8
                )
                image_from_plot = image_from_plot.reshape(
                    fig.canvas.get_width_height()[::-1] + (3,)
                )
                images.append(image_from_plot)

        if video_filename is not None:
            logger.info("... writing video ...")
            video_write(video_filename, images, framerate=framerate)
            logger.info("... done!")

        if show:
            plt.show()

    def plot_trajectory_actions(
        self,
        fignum=None,
        figsize=(8, 6),
        n_traj_to_show=10,
        hide_axis=True,
        show=True,
        video_filename=None,
        colormap_name="Paired",
        framerate=15,
        n_skip=1,
        dot_scale_factor=2.5,
        alpha=1.0,
        action_description=None,
        xlim=None,
        ylim=None,
    ):
        """
        Plot the actions (one action = one color) chosen in recent
        trajectories.

        If video_filename is given, a video file is saved showing the
        evolution of the actions taken in past trajectories.

        Parameters
        ----------
        fignum : str
            Figure name.
        figsize : (float, float)
            (width, height) of the image in inches.
        n_traj_to_show : int
            Number of trajectories to visualize in each frame.
        hide_axis : bool
            If True, axes are hidden.
        show : bool
            If True, calls plt.show().
        video_filename : str or None
            If not None, save a video with the given filename.
        colormap_name : str, default = 'Paired'
            Colormap name.
            See https://matplotlib.org/tutorials/colors/colormaps.html
        framerate : int, default: 15
            Video framerate.
        n_skip : int, default: 1
            Skip period: one trajectory is plotted every n_skip trajectories.
        dot_scale_factor : double
            Scale factor for the scatter plot points.
        alpha : float, default: 1.0
            The alpha blending value, between 0 (transparent) and 1 (opaque).
        action_description : list or None (optional)
            List of strings describing each action,
            e.g. ['left', 'right', 'up', 'down'].
        xlim : list, default: None
            x plot limits, set to [0, 1] if None.
        ylim : list, default: None
            y plot limits, set to [0, 1] if None.
        """
        logger.info("Plotting...")

        fignum = fignum or (str(self) + "-actions")
        colormap_fn = plt.get_cmap(colormap_name)
        action_description = action_description or list(
            range(self.env.action_space.n)
        )

        # discretizer: use the bin width as a reference length for dot sizes
        try:
            discretizer = self.episode_visit_counter.state_discretizer
            epsilon = min(
                discretizer._bins[0][1] - discretizer._bins[0][0],
                discretizer._bins[1][1] - discretizer._bins[1][0],
            )
        except Exception:
            epsilon = 0.01

        # figure setup
        xlim = xlim or [0.0, 1.0]
        ylim = ylim or [0.0, 1.0]

        # indices to visualize: only the last frame if no video is requested
        if video_filename is None:
            indices = [self.memory.n_trajectories - 1]
        else:
            indices = np.arange(self.memory.n_trajectories)[::n_skip]

        # images for video
        images = []

        for init_idx in indices:
            idx_set = range(max(0, init_idx - n_traj_to_show + 1), init_idx + 1)

            # clear before showing new trajectories
            fig = plt.figure(fignum, figsize=figsize)
            fig.clf()
            canvas = FigureCanvas(fig)
            ax = fig.gca()
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            if hide_axis:
                ax.set_axis_off()

            for idx in idx_set:
                traj = self.memory.trajectories[idx]
                states = np.array([traj[ii].state for ii in range(len(traj))])
                actions = np.array([traj[ii].action for ii in range(len(traj))])
                sizes = (dot_scale_factor**2) * 750 * epsilon

                for aa in range(self.env.action_space.n):
                    states_aa = states[actions == aa]
                    color = colormap_fn(aa / self.env.action_space.n)
                    ax.scatter(
                        x=states_aa[:, 0],
                        y=states_aa[:, 1],
                        color=color,
                        s=sizes,
                        alpha=alpha,
                        label=f"action = {action_description[aa]}",
                    )

            # keep unique legend entries only,
            # source: https://stackoverflow.com/a/57600060
            handles, labels = ax.get_legend_handles_labels()
            unique = dict(zip(labels, handles))
            plt.legend(
                unique.values(),
                unique.keys(),
                loc="upper left",
                bbox_to_anchor=(1.00, 1.00),
            )
            plt.tight_layout()

            if video_filename is not None:
                canvas.draw()
                image_from_plot = np.frombuffer(
                    fig.canvas.tostring_rgb(), dtype=np.uint8
                )
                image_from_plot = image_from_plot.reshape(
                    fig.canvas.get_width_height()[::-1] + (3,)
                )
                images.append(image_from_plot)

        if video_filename is not None:
            logger.info("... writing video ...")
            video_write(video_filename, images, framerate=framerate)
            logger.info("... done!")

        if show:
            plt.show()
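# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). Assumes an environment with
# a 2d Box observation space, here the rlberry-style MountainCar imported at
# the top of the file; episode length is capped by hand since the wrapper
# itself never truncates episodes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = Vis2dWrapper(MountainCar(), n_bins_obs=20, memory_size=100)

    # collect a few trajectories with a random policy
    for _ in range(10):
        env.reset()
        for _ in range(200):
            action = env.action_space.sample()
            _, _, done, _ = env.step(action)
            if done:
                break

    # plot limits taken from the observation space rather than the
    # default [0, 1] box
    xlim = [env.observation_space.low[0], env.observation_space.high[0]]
    ylim = [env.observation_space.low[1], env.observation_space.high[1]]
    env.plot_trajectories(xlim=xlim, ylim=ylim, dot_size_means="episode_visits")
    env.plot_trajectory_actions(xlim=xlim, ylim=ylim, n_traj_to_show=5)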