Example #1
File: agent.py  Project: xj361685640/coach
    def prepare_initial_state(self):
        """
        Create an initial state when starting a new episode
        :return: None
        """
        observation = self.preprocess_observation(
            self.env.state['observation'])
        self.curr_stack = deque([observation] *
                                self.tp.env.observation_stack_size,
                                maxlen=self.tp.env.observation_stack_size)
        observation = LazyStack(self.curr_stack, -1)

        self.curr_state = {'observation': observation}
        if self.tp.agent.use_measurements:
            self.curr_state['measurements'] = self.env.measurements
            if self.tp.agent.use_accumulated_reward_as_measurement:
                self.curr_state['measurements'] = np.append(
                    self.curr_state['measurements'], 0)
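
The example above fills a bounded deque with copies of the first preprocessed observation and wraps it in a LazyStack so the frames are only combined into one array when the state is actually consumed. A minimal sketch of that stacking idea follows; the LazyStack stand-in, its to_array method, and the observation shape are assumptions for illustration, not Coach's actual implementation.

from collections import deque
import numpy as np

class LazyStack:
    """Hypothetical stand-in: stacks the frames held in a deque along an
    axis only when the array is actually requested."""
    def __init__(self, frames, axis=-1):
        self.frames = frames
        self.axis = axis

    def to_array(self):
        # Concatenation is deferred until here, so appending new frames
        # to the shared deque stays cheap during the episode.
        return np.stack(list(self.frames), axis=self.axis)

# Start of an episode: the stack is filled with copies of the first observation.
stack_size = 4
first_obs = np.zeros((84, 84))                       # assumed preprocessed frame
curr_stack = deque([first_obs] * stack_size, maxlen=stack_size)
curr_state = {'observation': LazyStack(curr_stack, -1)}

print(curr_state['observation'].to_array().shape)    # (84, 84, 4)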
Example #2
File: agent.py  Project: xj361685640/coach
    def act(self, phase=RunPhase.TRAIN):
        """
        Take one step in the environment according to the network prediction and store the transition in memory
        :param phase: The current run phase (TRAIN, TEST or HEATUP), which determines whether greedy actions should be used and whether transitions should be stored
        :return: A boolean value that signals an episode termination
        """

        if phase != RunPhase.TEST:
            self.total_steps_counter += 1
        self.current_episode_steps_counter += 1

        # get new action
        action_info = {
            "action_probability": 1.0 / self.env.action_space_size,
            "action_value": 0,
            "max_action_value": 0
        }

        if phase == RunPhase.HEATUP and not self.tp.heatup_using_network_decisions:
            action = self.env.get_random_action()
        else:
            action, action_info = self.choose_action(self.curr_state,
                                                     phase=phase)

        # perform action
        if isinstance(action, np.ndarray):
            action = action.squeeze()
        result = self.env.step(action)

        shaped_reward = self.preprocess_reward(result['reward'])
        if 'action_intrinsic_reward' in action_info.keys():
            shaped_reward += action_info['action_intrinsic_reward']
        # TODO: should total_reward_in_current_episode include shaped_reward?
        self.total_reward_in_current_episode += result['reward']
        next_state = result['state']
        next_state['observation'] = self.preprocess_observation(
            next_state['observation'])

        # plot action values online
        if self.tp.visualization.plot_action_values_online and phase != RunPhase.HEATUP:
            self.plot_action_values_online()

        # initialize the next state
        # TODO: provide option to stack more than just the observation
        self.curr_stack.append(next_state['observation'])
        observation = LazyStack(self.curr_stack, -1)

        next_state['observation'] = observation
        if self.tp.agent.use_measurements and 'measurements' in result.keys():
            next_state['measurements'] = result['state']['measurements']
            if self.tp.agent.use_accumulated_reward_as_measurement:
                next_state['measurements'] = np.append(
                    next_state['measurements'],
                    self.total_reward_in_current_episode)

        # store the transition only if we are training
        if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
            transition = Transition(self.curr_state, result['action'],
                                    shaped_reward, next_state, result['done'])
            for key in action_info.keys():
                transition.info[key] = action_info[key]
            if self.tp.agent.add_a_normalized_timestep_to_the_observation:
                transition.info['timestep'] = float(
                    self.current_episode_steps_counter
                ) / self.env.timestep_limit
            self.memory.store(transition)
        elif phase == RunPhase.TEST and self.tp.visualization.dump_gifs:
            # store the rendered frames only for saving gifs
            self.last_episode_images.append(self.env.get_rendered_image())

        # update the current state for the next step
        self.curr_state = next_state

        # deal with episode termination
        if result['done']:
            if self.tp.visualization.dump_csv:
                self.update_log(phase=phase)
            self.log_to_screen(phase=phase)

            if phase == RunPhase.TRAIN or phase == RunPhase.HEATUP:
                self.reset_game()

            self.current_episode += 1
            self.tp.current_episode = self.current_episode

        # return whether the episode has really ended
        return result['done']
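
The storage branch in act() reduces to a simple per-step pattern: build a Transition from (previous state, action, shaped reward, next state, done), copy the action metadata into its info dict, and append it to the replay memory only during training-time phases. The sketch below shows that pattern in isolation; the Transition class, the list-based memory, and the store_step helper are simplified stand-ins assumed here, not Coach's real classes.

import numpy as np

class Transition:
    """Hypothetical stand-in for Coach's Transition record."""
    def __init__(self, state, action, reward, next_state, done):
        self.state = state
        self.action = action
        self.reward = reward
        self.next_state = next_state
        self.done = done
        self.info = {}

memory = []                                   # stand-in for the replay memory

def store_step(curr_state, result, shaped_reward, action_info, training):
    # Only steps taken during TRAIN / HEATUP are written to memory.
    if not training:
        return
    transition = Transition(curr_state, result['action'],
                            shaped_reward, result['state'], result['done'])
    # Attach the action metadata (probability, action values, ...) for later use.
    transition.info.update(action_info)
    memory.append(transition)

# One fabricated step, just to show the call shape.
result = {'action': 1, 'state': {'observation': np.zeros(4)}, 'done': False}
store_step({'observation': np.zeros(4)}, result, shaped_reward=0.5,
           action_info={'action_probability': 0.25}, training=True)
print(len(memory))                            # 1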