def build_visualizer(self, record_path):
  """Builds a minimal AgentVisualizer that renders only the game surface."""
  atari_params = {'environment': self._environment}
  atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)
  return agent_visualizer.AgentVisualizer(
      record_path=record_path,
      plotters=[atari_plot],
      screen_width=atari_plot.parameters['width'],
      screen_height=atari_plot.parameters['height'])
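
# A usage sketch for build_visualizer() above: record a short greedy rollout
# as video. `record_short_clip` and `num_steps` are hypothetical names added
# for illustration; only build_visualizer(), the agent/environment attributes,
# and the AgentVisualizer.visualize()/generate_video() calls come from the
# code in this file.
def record_short_clip(self, record_path, num_steps=100):
  visualizer = self.build_visualizer(record_path)
  self._agent.eval_mode = True
  observation = self._environment.reset()
  action = self._agent.begin_episode(observation)
  for _ in range(num_steps):
    observation, reward, is_terminal, _ = self._environment.step(action)
    visualizer.visualize()  # Capture one composited frame.
    if is_terminal:
      break
    action = self._agent.step(reward, observation)
  visualizer.generate_video()  # Stitch the captured frames into a video file.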
def visualize(self, record_path, num_global_steps=500):
  if not tf.io.gfile.exists(record_path):
    tf.io.gfile.makedirs(record_path)
  self._agent.eval_mode = True

  # Set up the game playback rendering.
  atari_params = {'environment': self._environment}
  atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)
  # Plot the rewards received next to it.
  reward_params = {
      'x': atari_plot.parameters['width'],
      'xlabel': 'Timestep',
      'ylabel': 'Reward',
      'title': 'Rewards',
      'get_line_data_fn': self._agent.get_rewards
  }
  reward_plot = line_plotter.LinePlotter(parameter_dict=reward_params)
  action_names = [
      'Action {}'.format(x) for x in range(self._agent.num_actions)
  ]
  # Plot Q-values (DQN) or Q-value distributions (Rainbow).
  q_params = {
      'x': atari_plot.parameters['width'] // 2,
      'y': atari_plot.parameters['height'],
      'legend': action_names
  }
  if 'DQN' in self._agent.__class__.__name__:
    q_params['xlabel'] = 'Timestep'
    q_params['ylabel'] = 'Q-Value'
    q_params['title'] = 'Q-Values'
    q_params['get_line_data_fn'] = self._agent.get_q_values
    q_plot = line_plotter.LinePlotter(parameter_dict=q_params)
  elif 'Implicit' in self._agent.__class__.__name__:
    q_params['xlabel'] = 'Timestep'
    q_params['ylabel'] = 'Quantile Value'
    q_params['title'] = 'Quantile Values'
    q_params['get_line_data_fn'] = self._agent.get_q_values
    q_plot = line_plotter.LinePlotter(parameter_dict=q_params)
  else:
    q_params['xlabel'] = 'Return'
    q_params['ylabel'] = 'Return probability'
    q_params['title'] = 'Return distribution'
    q_params['get_bar_data_fn'] = self._agent.get_probabilities
    q_plot = bar_plotter.BarPlotter(parameter_dict=q_params)
  screen_width = (
      atari_plot.parameters['width'] + reward_plot.parameters['width'])
  screen_height = (
      atari_plot.parameters['height'] + q_plot.parameters['height'])
  # Dimensions need to be divisible by 2:
  if screen_width % 2 > 0:
    screen_width += 1
  if screen_height % 2 > 0:
    screen_height += 1
  visualizer = agent_visualizer.AgentVisualizer(
      record_path=record_path,
      plotters=[atari_plot, reward_plot, q_plot],
      screen_width=screen_width,
      screen_height=screen_height)
  global_step = 0
  while global_step < num_global_steps:
    initial_observation = self._environment.reset()
    action = self._agent.begin_episode(initial_observation)
    while True:
      observation, reward, is_terminal, _ = self._environment.step(action)
      global_step += 1
      visualizer.visualize()
      if self._environment.game_over or global_step >= num_global_steps:
        break
      elif is_terminal:
        self._agent.end_episode(reward)
        action = self._agent.begin_episode(observation)
      else:
        action = self._agent.step(reward, observation)
    self._end_episode(reward)
  visualizer.generate_video()
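
# The plotters above pull their data through the 'get_line_data_fn' /
# 'get_bar_data_fn' callbacks (self._agent.get_rewards, get_q_values,
# get_probabilities). A minimal sketch of the agent-side hooks this assumes:
# each returns one sequence per line (or bar group) to draw. The bookkeeping
# attributes (self.rewards appended to on every step, self.q_values as one
# list per action) are assumptions about the agent subclass, not visible in
# the code above.
import numpy as np

def get_rewards(self):
  # One line for the plotter: cumulative reward over the episode so far.
  return [np.cumsum(self.rewards)]

def get_q_values(self):
  # One line per action: the Q-value recorded at each timestep.
  return self.q_values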
def visualize(self, record_path, num_global_steps=500):
  '''Customized visualization for bubble; derived from MyRunner.visualize().'''
  print('RUN> visualize(%s, %d)' % (record_path, num_global_steps))
  if not tf.gfile.Exists(record_path):
    tf.gfile.MakeDirs(record_path)
  self._agent.eval_mode = True

  # Set up the game playback rendering.
  atari_params = {
      'environment': self._environment,
      'width': 240,
      'height': 224
  }
  atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)
  # Plot the rewards received next to it.
  reward_params = {
      'x': atari_plot.parameters['width'],
      'xlabel': 'Timestep',
      'ylabel': 'Reward',
      'title': 'Rewards',
      'get_line_data_fn': self._agent.get_rewards
  }
  # reward_plot = line_plotter.LinePlotter(parameter_dict=reward_params)
  reward_plot = MyLinePlotter(parameter_dict=reward_params)
  action_names = [
      'Action {}'.format(x) for x in range(self._agent.num_actions)
  ]
  # Plot the observation at the bottom-left.
  obsrv_params = {
      'x': atari_plot.parameters['x'],
      'y': atari_plot.parameters['height'] - 10,
      'width': atari_plot.parameters['width'],
      'height': atari_plot.parameters['height'],
  }
  obsrv_plot = MyObservationPlotter(parameter_dict=obsrv_params)
  # Plot Q-values (DQN) or Q-value distributions (Rainbow).
  q_params = {
      'x': atari_plot.parameters['width'],
      'y': atari_plot.parameters['height'],
      'legend': action_names
  }
  if 'DQN' in self._agent.__class__.__name__:
    q_params['xlabel'] = 'Timestep'
    q_params['ylabel'] = 'Q-Value'
    q_params['title'] = 'Q-Values'
    q_params['get_line_data_fn'] = self._agent.get_q_values
    q_plot = MyLinePlotter(parameter_dict=q_params)
  else:
    q_params['xlabel'] = 'Return'
    q_params['ylabel'] = 'Return probability'
    q_params['title'] = 'Return distribution'
    q_params['get_bar_data_fn'] = self._agent.get_probabilities
    q_plot = MyBarPlotter(parameter_dict=q_params)
  # Screen size; both dimensions must be divisible by 2 for video encoding.
  screen_width = (
      atari_plot.parameters['width'] + reward_plot.parameters['width'])
  screen_height = (
      atari_plot.parameters['height'] + q_plot.parameters['height'])
  screen_width += 1 if screen_width % 2 > 0 else 0
  screen_height += 1 if screen_height % 2 > 0 else 0
  # Build the visualizer.
  visualizer = agent_visualizer.AgentVisualizer(
      record_path=record_path,
      plotters=[atari_plot, reward_plot, obsrv_plot, q_plot],
      screen_width=screen_width,
      screen_height=screen_height)
  # Run the play loop until num_global_steps is reached.
  global_step = 0
  while global_step < num_global_steps:
    initial_observation = self._environment.reset()
    action = self._agent.begin_episode(initial_observation)
    while True:
      observation, reward, is_terminal, info = self._environment.step(action)
      global_step += 1
      obsrv_plot.setObservation(observation)
      visualizer.visualize()
      if self._environment.game_over or global_step >= num_global_steps:
        break
      elif is_terminal:
        self._agent.end_episode(reward)
        action = self._agent.begin_episode(observation)
      else:
        action = self._agent.step(reward, observation, info)
    self._end_episode(reward)
  visualizer.generate_video()
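
# A plausible sketch of the MyObservationPlotter used above, assuming
# Dopamine's plotter.Plotter base class (a `parameters` dict plus a draw()
# method whose returned pygame.Surface the AgentVisualizer blits at
# parameters['x'], parameters['y']). The observation-shape handling below
# (grayscale or stacked frames, e.g. Dopamine's 84x84 Atari input) is an
# assumption, not confirmed by the code above.
import numpy as np
import pygame
from dopamine.utils import plotter

class MyObservationPlotter(plotter.Plotter):
  _defaults = {'x': 0, 'y': 0, 'width': 84, 'height': 84}

  def __init__(self, parameter_dict=None):
    super(MyObservationPlotter, self).__init__(parameter_dict)
    self._observation = None

  def setObservation(self, observation):
    # Called once per environment step from the visualize() loop.
    self._observation = observation

  def draw(self):
    if self._observation is None:
      # Nothing received yet: return a blank surface of the right size.
      return pygame.Surface(
          (self.parameters['width'], self.parameters['height']))
    obs = np.asarray(self._observation).astype(np.uint8)
    if obs.ndim == 3 and obs.shape[-1] not in (1, 3):
      obs = obs[..., -1]  # Stacked frames: show only the most recent one.
    obs = np.squeeze(obs)
    if obs.ndim == 2:
      # Grayscale observation: replicate to three channels for display.
      obs = np.stack([obs] * 3, axis=-1)
    # pygame.surfarray expects (width, height, channels), so swap the axes.
    frame = pygame.surfarray.make_surface(obs.transpose(1, 0, 2))
    return pygame.transform.scale(
        frame, (self.parameters['width'], self.parameters['height']))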