def _init_value_vis(self):
    """Create the "Value Function" figure and its per-action arrow layers.

    The axes and image come from ``self._init_vis_common``.  For every
    entry of ``self.ARROW_NAMES`` an arrow artist is built with
    ``self._init_arrow`` on a cell grid shifted by the matching row of
    ``self.ACTIONS * self.SHIFT``, and appended to ``self.arrow_figs``.
    """
    figure = plt.figure("Value Function")
    self.vf_fig = figure
    self.vf_ax, self.vf_img = self._init_vis_common(figure)

    # One quiver per action; each is offset inside its cell so the arrows
    # of different actions do not overlap.
    offsets = self.ACTIONS * self.SHIFT
    col_idx = np.arange(self.cols)
    row_idx = np.arange(self.rows)
    for arrow_name, offset in zip(self.ARROW_NAMES, offsets):
        grid_x, grid_y = np.meshgrid(col_idx + offset[1], row_idx + offset[0])
        arrow = self._init_arrow(arrow_name, grid_x, grid_y, self.vf_ax)
        self.arrow_figs.append(arrow)
def _init_domain_vis(self, s, legend=False, noticks=False):
    """Open the domain figure, draw the map and create the agent marker.

    Args:
        s: State handed to ``self._agent_fig`` to place the agent artist.
        legend: Forwarded to ``self._show_map``.
        noticks: When true, clear both axis tick sets; otherwise ticks are
            configured by ``self._set_ticks``.
    """
    # The figure name is "<ClassName>: <mapname>", with "(Evaluation)"
    # appended (no separating space) in performance mode.
    suffix = "(Evaluation)" if self.performance else ""
    self.domain_fig = plt.figure(f"{self.__class__.__name__}: {self.mapname}" + suffix)
    self.domain_ax = self.domain_fig.add_subplot(111)
    self._show_map(legend=legend)

    if not noticks:
        self._set_ticks(self.domain_ax)
    else:
        self.domain_ax.set_xticks([])
        self.domain_ax.set_yticks([])

    self.agent_fig = self._agent_fig(s)
def _init_heatmap_vis(
    self, name, cmap, nrows, ncols, index, legend, ticks, cmap_vmin, cmap_vmax
):
    """Create heatmap subplot ``index`` inside the figure called ``name``.

    The figure is created and shown the first time ``name`` is seen; later
    calls reuse it.  The axes and image returned by
    ``self._init_vis_common`` are stored in ``self.heatmap_ax`` and
    ``self.heatmap_img`` under the key ``(name, index)``.
    """
    if name not in self.heatmap_fig:
        self.heatmap_fig[name] = plt.figure(name)
        self.heatmap_fig[name].show()

    slot = (name, index)
    axes, image = self._init_vis_common(
        self.heatmap_fig[name],
        cmap=cmap,
        axarg=(nrows, ncols, index),
        legend=legend,
        ticks=ticks,
        cmap_vmin=cmap_vmin,
        cmap_vmax=cmap_vmax,
    )
    self.heatmap_ax[slot] = axes
    self.heatmap_img[slot] = image
def __init__(
    self,
    xy_discr,
    nrows=1,
    ncols=1,
    name="Pinball Value Function",
    cmap="ValueFunction-New",
    vmin=-10,
    vmax=10,
):
    """Build a figure holding ``nrows * ncols`` blank heatmap panels.

    Each panel is an ``xy_discr`` x ``xy_discr`` image, initialised to
    zeros and drawn with colormap ``cmap`` clipped to ``[vmin, vmax]``.
    Every panel gets a colorbar and no axis ticks.  The images are kept in
    ``self.imgs`` (in subplot order) and ``self.close()`` closes the figure.
    """
    from rlpy.tools.plotting import plt

    self.fig = plt.figure(name)
    colormap = plt.get_cmap(cmap)
    self.data_shape = (xy_discr, xy_discr)
    blank = np.zeros(self.data_shape)

    self.imgs = []
    for panel in range(1, nrows * ncols + 1):
        axes = self.fig.add_subplot(nrows, ncols, panel)
        image = axes.imshow(
            blank,
            cmap=colormap,
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        bar = axes.figure.colorbar(image, ax=axes)
        bar.ax.set_ylabel("", rotation=-90, va="bottom")
        axes.set_xticks([])
        axes.set_yticks([])
        self.imgs.append(image)

    self.fig.tight_layout()
    self.fig.canvas.draw()

    def close():
        """Close the figure created by this instance."""
        plt.close(self.fig)

    self.close = close
def show_policy(
    self,
    policy,
    value=None,
    nrows=1,
    ncols=1,
    index=1,
    ticks=True,
    scale=1.0,
    title=None,
    colorbar=False,
    notext=False,
    cmap="ValueFunction-New",
    cmap_vmin=MIN_RETURN,
    cmap_vmax=MAX_RETURN,
    arrow_resize=True,
    figure_title="Policy",
):
    """Draw ``policy`` as per-cell action arrows, optionally over ``value``.

    The figure ``figure_title`` and its subplot ``index`` (in an
    ``nrows`` x ``ncols`` layout) are created on the first call for a given
    ``(figure_title, index)`` pair and reused afterwards, so this method can
    be called repeatedly to update a plot in place.

    Args:
        policy: Array-like that can be reshaped to
            ``(self.rows, self.cols, -1)``.  Entry ``[r, c, a]`` scores
            action ``a`` in cell ``(r, c)``; the arg-max action of each cell
            gets a distinct arrow color.
        value: Optional array-like reshapeable to ``(self.rows, self.cols)``,
            drawn as the background image multiplied by ``self._map_mask()``.
            When ``None`` an all-zero image is drawn.
        nrows, ncols, index: Subplot layout of the figure.
        ticks: Forwarded to ``self._init_vis_common``.
        scale: Scales the figure size (when the figure is created) and the
            arrows (when the subplot is created).
        title: Axes title; applied on every call when not ``None``.
        colorbar: Attach a colorbar when the subplot is first created.  Also
            suppresses the per-cell value text.
        notext: Suppress the per-cell value text.
        cmap, cmap_vmin, cmap_vmax: Colormap settings, used when the subplot
            is first created.
        arrow_resize: Scale each arrow by its ``policy`` entry; otherwise
            every arrow has unit size.
        figure_title: Name of the figure.

    Returns:
        The ``(figure_title, index)`` key of the subplot that was drawn.

    Raises:
        ValueError: If ``policy`` or ``value`` cannot be reshaped to the grid.
    """
    if figure_title not in self.policy_fig:
        # scale_x * scale_y == 1, so only the aspect ratio follows the
        # subplot layout (assuming with_scaled_figure multiplies the default
        # figure size -- TODO confirm).
        scale_x = np.sqrt(ncols / nrows)
        scale_y = np.sqrt(nrows / ncols)
        with with_scaled_figure(scale_x * scale, scale_y * scale):
            self.policy_fig[figure_title] = plt.figure(figure_title)
        self.policy_fig[figure_title].show()
    fig = self.policy_fig[figure_title]
    key = figure_title, index

    if key not in self.policy_ax:
        self.policy_ax[key], self.policy_img[key] = self._init_vis_common(
            fig,
            axarg=(nrows, ncols, index),
            legend=False,
            ticks=ticks,
            cmap=cmap,
            cmap_vmin=cmap_vmin,
            cmap_vmax=cmap_vmax,
        )
        # One quiver per action, offset inside each cell by ACTIONS * SHIFT.
        shift = self.ACTIONS * self.SHIFT
        x, y = np.arange(self.cols), np.arange(self.rows)
        for name, s in zip(self.ARROW_NAMES, shift):
            grid = np.meshgrid(x + s[1], y + s[0])
            self.policy_arrows[key].append(
                self._init_arrow(
                    name,
                    *grid,
                    self.policy_ax[key],
                    arrow_scale=scale,
                )
            )
        if colorbar:
            # Created only together with the axes: adding it on every call
            # would stack a new colorbar each time the plot is refreshed.
            cbar = self.policy_ax[key].figure.colorbar(
                self.policy_img[key], ax=self.policy_ax[key]
            )
            cbar.ax.set_ylabel("", rotation=-90, va="bottom")

    ax = self.policy_ax[key]
    img = self.policy_img[key]
    if title is not None:
        ax.set_title(title)

    # Arrows start fully masked; only the possible actions of START/EMPTY
    # cells are unmasked below.  Builtin ``bool`` is used because the
    # ``np.bool`` alias was removed in NumPy 1.24.
    arrow_mask = np.ones((self.rows, self.cols, self.num_actions), dtype=bool)
    arrow_size = np.ones(arrow_mask.shape, dtype=np.float32)
    arrow_color = np.zeros(arrow_mask.shape, dtype=np.uint8)

    policy = np.asarray(policy)
    try:
        policy = policy.reshape(self.rows, self.cols, -1)
    except ValueError as err:
        raise ValueError(f"Invalid policy shape: {policy.shape}") from err

    for r, c in itertools.product(range(self.rows), range(self.cols)):
        if self.map[r, c] not in (self.START, self.EMPTY):
            continue
        actions = self.possible_actions(np.array([r, c]))
        best_act = policy[r, c].argmax()
        arrow_mask[r, c, actions] = False
        arrow_color[r, c, best_act] = 1
        if arrow_resize:
            arrow_size[r, c] = policy[r, c]

    # Push the per-action arrow data into the quivers.  The vertical
    # component is negated, presumably because image rows grow downward.
    arrows = self.policy_arrows[key]
    for i, _ in enumerate(self.ARROW_NAMES):
        dy, dx = self.ACTIONS[i]
        size, mask = arrow_size[:, :, i], arrow_mask[:, :, i]
        u = np.ma.masked_array(dx * size, mask=mask)
        v = np.ma.masked_array(dy * size * -1, mask=mask)
        colors = np.ma.masked_array(arrow_color[:, :, i], mask=mask)
        arrows[i].set_UVC(u, v, colors)

    if value is None:
        img.set_data(self.map * 0.0)
    else:
        value = np.asarray(value)
        try:
            value = value.reshape(self.rows, self.cols)
        except ValueError as err:
            raise ValueError(f"Invalid value shape: {value.shape}") from err
        if not colorbar and not notext:
            texts = self.policy_texts[key]
            self._reset_texts(texts)
            for r, c, ext_v in self._normalize_value(value):
                self._text_on_cell(c, r, ext_v, texts, ax)
        img.set_data(value * self._map_mask())

    fig.canvas.draw()
    fig.canvas.flush_events()
    return key