Example #1
    def build_visualizer(self, record_path):
        """Builds an AgentVisualizer that records only the Atari game screen."""
        # Render the game screen itself.
        atari_params = {'environment': self._environment}
        atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)

        return agent_visualizer.AgentVisualizer(
            record_path=record_path,
            plotters=[atari_plot],
            screen_width=atari_plot.parameters['width'],
            screen_height=atari_plot.parameters['height'])
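
A minimal usage sketch (hypothetical: it assumes `runner` is an instance of the class defining build_visualizer() above, with self._environment already constructed). The visualize() and generate_video() calls are the same AgentVisualizer methods used in Example #2:

# Hypothetical usage; `runner` and the record path are assumptions.
visualizer = runner.build_visualizer(record_path='/tmp/agent_viz')
visualizer.visualize()       # capture one frame of the current game screen
visualizer.generate_video()  # assemble the captured frames into a video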
Example #2
    def visualize(self, record_path, num_global_steps=500):
        if not tf.io.gfile.exists(record_path):
            tf.io.gfile.makedirs(record_path)
        self._agent.eval_mode = True

        # Set up the game playback rendering.
        atari_params = {'environment': self._environment}
        atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)
        # Plot the rewards received next to it.
        reward_params = {
            'x': atari_plot.parameters['width'],
            'xlabel': 'Timestep',
            'ylabel': 'Reward',
            'title': 'Rewards',
            'get_line_data_fn': self._agent.get_rewards
        }
        reward_plot = line_plotter.LinePlotter(parameter_dict=reward_params)
        action_names = [
            'Action {}'.format(x) for x in range(self._agent.num_actions)
        ]
        # Plot Q-values (DQN), quantile values (Implicit Quantile), or
        # return distributions (Rainbow).
        q_params = {
            'x': atari_plot.parameters['width'] // 2,
            'y': atari_plot.parameters['height'],
            'legend': action_names
        }
        if 'DQN' in self._agent.__class__.__name__:
            q_params['xlabel'] = 'Timestep'
            q_params['ylabel'] = 'Q-Value'
            q_params['title'] = 'Q-Values'
            q_params['get_line_data_fn'] = self._agent.get_q_values
            q_plot = line_plotter.LinePlotter(parameter_dict=q_params)
        elif 'Implicit' in self._agent.__class__.__name__:
            q_params['xlabel'] = 'Timestep'
            q_params['ylabel'] = 'Quantile Value'
            q_params['title'] = 'Quantile Values'
            q_params['get_line_data_fn'] = self._agent.get_q_values
            q_plot = line_plotter.LinePlotter(parameter_dict=q_params)
        else:
            q_params['xlabel'] = 'Return'
            q_params['ylabel'] = 'Return probability'
            q_params['title'] = 'Return distribution'
            q_params['get_bar_data_fn'] = self._agent.get_probabilities
            q_plot = bar_plotter.BarPlotter(parameter_dict=q_params)
        screen_width = (atari_plot.parameters['width'] +
                        reward_plot.parameters['width'])
        screen_height = (atari_plot.parameters['height'] +
                         q_plot.parameters['height'])
        # Dimensions need to be even (most video codecs require this):
        if screen_width % 2 > 0:
            screen_width += 1
        if screen_height % 2 > 0:
            screen_height += 1
        visualizer = agent_visualizer.AgentVisualizer(
            record_path=record_path,
            plotters=[atari_plot, reward_plot, q_plot],
            screen_width=screen_width,
            screen_height=screen_height)
        global_step = 0
        # Play episodes until num_global_steps frames have been recorded.
        while global_step < num_global_steps:
            initial_observation = self._environment.reset()
            action = self._agent.begin_episode(initial_observation)
            while True:
                observation, reward, is_terminal, _ = self._environment.step(
                    action)
                global_step += 1
                # Capture one frame from every plotter.
                visualizer.visualize()
                if self._environment.game_over or global_step >= num_global_steps:
                    break
                elif is_terminal:
                    # Life lost but the game continues: start a new agent episode.
                    self._agent.end_episode(reward)
                    action = self._agent.begin_episode(observation)
                else:
                    action = self._agent.step(reward, observation)
            self._end_episode(reward)
        # Assemble the recorded frames into a video file under record_path.
        visualizer.generate_video()
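
A hypothetical driver for this method. It assumes MyRunner is the class defining visualize() above and that its constructor mirrors dopamine's run_experiment.Runner arguments (base_dir, create_agent_fn); these names are assumptions, not a confirmed API:

# Hypothetical driver; constructor arguments are assumed to mirror
# dopamine's run_experiment.Runner (base_dir, create_agent_fn).
runner = MyRunner('/tmp/dqn_viz', create_agent_fn)
runner.visualize('/tmp/dqn_viz/images', num_global_steps=1000)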
Example #3
    def visualize(self, record_path, num_global_steps=500):
        """Customized visualization for the Bubble environment.

        Adapted from MyRunner.visualize() above.
        """
        print('RUN> visualize(%s, %d)' % (record_path, num_global_steps))
        if not tf.io.gfile.exists(record_path):
            tf.io.gfile.makedirs(record_path)
        self._agent.eval_mode = True

        # Set up the game playback rendering.
        atari_params = {
            'environment': self._environment,
            'width': 240,
            'height': 224
        }

        atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params)
        # Plot the rewards received next to it.
        reward_params = {
            'x': atari_plot.parameters['width'],
            'xlabel': 'Timestep',
            'ylabel': 'Reward',
            'title': 'Rewards',
            'get_line_data_fn': self._agent.get_rewards
        }
        # Use the custom MyLinePlotter in place of line_plotter.LinePlotter.
        reward_plot = MyLinePlotter(parameter_dict=reward_params)
        action_names = [
            'Action {}'.format(x) for x in range(self._agent.num_actions)
        ]
        # Plot the agent's observation at the bottom left.
        obsrv_params = {
            'x': atari_plot.parameters['x'],
            'y': atari_plot.parameters['height'] - 10,
            'width': atari_plot.parameters['width'],
            'height': atari_plot.parameters['height'],
        }
        obsrv_plot = MyObservationPlotter(parameter_dict=obsrv_params)
        # Plot Q-values (DQN) or return distributions (Rainbow).
        q_params = {
            'x': atari_plot.parameters['width'],
            'y': atari_plot.parameters['height'],
            'legend': action_names
        }
        if 'DQN' in self._agent.__class__.__name__:
            q_params['xlabel'] = 'Timestep'
            q_params['ylabel'] = 'Q-Value'
            q_params['title'] = 'Q-Values'
            q_params['get_line_data_fn'] = self._agent.get_q_values
            q_plot = MyLinePlotter(parameter_dict=q_params)
        else:
            q_params['xlabel'] = 'Return'
            q_params['ylabel'] = 'Return probability'
            q_params['title'] = 'Return distribution'
            q_params['get_bar_data_fn'] = self._agent.get_probabilities
            q_plot = MyBarPlotter(parameter_dict=q_params)
        # Screen Size
        screen_width = (atari_plot.parameters['width'] +
                        reward_plot.parameters['width'])
        screen_height = (atari_plot.parameters['height'] +
                         q_plot.parameters['height'])
        # Dimensions need to be even (most video codecs require this):
        if screen_width % 2 > 0:
            screen_width += 1
        if screen_height % 2 > 0:
            screen_height += 1
        # Build the visualizer with all four plotters.
        visualizer = agent_visualizer.AgentVisualizer(
            record_path=record_path,
            plotters=[atari_plot, reward_plot, obsrv_plot, q_plot],
            screen_width=screen_width,
            screen_height=screen_height)
        # Run the evaluation loop until num_global_steps frames are recorded.
        global_step = 0
        while global_step < num_global_steps:
            initial_observation = self._environment.reset()
            action = self._agent.begin_episode(initial_observation)
            while True:
                observation, reward, is_terminal, info = self._environment.step(
                    action)
                global_step += 1
                # Push the latest observation into the observation plotter.
                obsrv_plot.setObservation(observation)
                visualizer.visualize()
                if self._environment.game_over or global_step >= num_global_steps:
                    break
                elif is_terminal:
                    self._agent.end_episode(reward)
                    action = self._agent.begin_episode(observation)
                else:
                    # This custom agent's step() also takes the env info dict.
                    action = self._agent.step(reward, observation, info)
            self._end_episode(reward)
        visualizer.generate_video()
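
Example #3 relies on MyObservationPlotter, which is not shown. Below is a minimal sketch of what such a plotter could look like. It assumes dopamine's plotter.Plotter base class and pygame rendering in the style of the stock AtariPlotter, and that AgentVisualizer blits whatever draw() returns at the plotter's (x, y); the defaults and the grayscale-observation handling are assumptions:

import numpy as np
import pygame

from dopamine.utils import plotter


class MyObservationPlotter(plotter.Plotter):
    """Hypothetical sketch: draws the latest raw observation."""

    _defaults = {'x': 0, 'y': 0, 'width': 84, 'height': 84}

    def __init__(self, parameter_dict=None):
        super(MyObservationPlotter, self).__init__(parameter_dict)
        self._obs = None

    def setObservation(self, observation):
        # Called from the run loop with the newest environment observation.
        self._obs = observation

    def draw(self):
        target_size = (self.parameters['width'], self.parameters['height'])
        if self._obs is None:
            return pygame.Surface(target_size)
        obs = np.asarray(self._obs)
        if obs.ndim == 3:
            obs = obs[..., -1]  # stacked frames: show the most recent one
        # Grayscale -> RGB, transposed to pygame's (x, y) axis order.
        frame = np.repeat(obs[..., None], 3, axis=-1).astype(np.uint8)
        surface = pygame.surfarray.make_surface(np.transpose(frame, (1, 0, 2)))
        return pygame.transform.scale(surface, target_size)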