示例#1
0
    def _rollout(self, state, action):
        '''
        Simulate up to ``self.rollout_depth`` steps in ``self.env_model``
        starting from ``state``, then credit every visited (state, action)
        pair with the discounted rewards collected from that step onward.

        Args:
            state (GymState): Start state of the rollout. Its ``data`` is
                written back into ``self.env_model`` when the rollout ends.
            action: Currently unused; every rollout action, including the
                first one, is chosen by ``self._next_action``.

        Returns:
            (list): ``self.gamma**i * reward_i`` for each simulated step i
            (discounted relative to the start of the rollout).
        '''
        cur_state = state
        trajectory = []
        total_discounted_reward = []

        for i in range(self.rollout_depth):
            # Choose the rollout action from the state we are currently in,
            # not from the start state, so state-dependent policies behave.
            next_action = self._next_action(cur_state)

            # Simulate next.
            next_obs, next_reward, is_terminal, _ = self.env_model._step(next_action)
            cur_state = GymState(next_obs)

            # Track rewards and states.
            total_discounted_reward.append(self.gamma**i * next_reward)
            trajectory.append((cur_state, next_action))

            if is_terminal:
                break

        # Put the model environment's observation back to the start state.
        self.env_model._set_current_observation(state.data)

        # Update all visited nodes. NOTE(review): assumes visitation_counts
        # and value_total are defaultdict-like (missing keys start at 0).
        for i, (s, a) in enumerate(trajectory):
            self.visitation_counts[(s, a)] += 1
            self.value_total[(s, a)] += sum(total_discounted_reward[i:])

        return total_discounted_reward
    def _reward_func(self, state, action):
        '''
        Step the wrapped gym environment with ``action`` and return the reward.

        Side effects: stores the step's info in ``self.next_info`` and the
        processed next state in ``self.next_state``; renders the environment
        when ``self.render`` is truthy.

        Args:
            state (GymState): Not read; the environment tracks its own state.
            action: Action forwarded to ``self.env.step``.

        Returns
            (float): The step reward. If the environment returns a numpy
            array, its first element is returned.
        '''
        obs, reward, is_terminal, self.next_info = self.env.step(action)

        obs_f = self.process_state(obs)

        # TODO: Hack to make MontezumaRevenge terminates with 1 life.
        # if 'ale.lives' in self.next_info and self.next_info['ale.lives'] == 5:
        #     is_terminal = True

        if self.render:
            self.env.render()

        self.next_state = GymState(obs_f, is_terminal=is_terminal)

        # Collapse array rewards to a scalar. ``flat[0]`` also works for 0-d
        # arrays, where ``reward[0]`` would raise IndexError.
        if isinstance(reward, np.ndarray):
            return reward.flat[0]
        return reward
    def __init__(self,
                 env_name='CartPole-v0',
                 grayscale=False,
                 downscale=False,
                 flatten=True,
                 render=False):
        '''
        Create the gym environment and initialize the MDP from its first
        observation.

        Args:
            env_name (str): Gym environment id passed to ``gym.make``.
            grayscale (bool): Stored as ``self.grayscale`` (presumably read
                by ``process_state``; not visible here).
            downscale (bool): Stored as ``self.downscale`` (presumably read
                by ``process_state``).
            flatten (bool): Stored as ``self.flatten`` (presumably read by
                ``process_state``).
            render (bool): Stored as ``self.render``.
        '''
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.grayscale = grayscale
        self.downscale = downscale
        self.flatten = flatten
        self.render = render

        # if self.grayscale:
        #     # TODO: it is only applicable if the environment is Atari games.
        #     self.env.env._get_image = self.env.env.ale.getScreenGrayscale

        # NOTE: despite its name, ``atari`` is True for ANY environment with a
        # Discrete action space (e.g. the default CartPole-v0), not only Atari.
        self.atari = isinstance(self.env.action_space, gym.spaces.Discrete)

        if self.atari:
            # Discrete actions: enumerate them as 0..n-1 and run the first
            # observation through ``process_state``.
            actions = range(self.env.action_space.n)
            init_state = GymState(self.process_state(self.env.reset()))
        else:
            # Non-Discrete action spaces (originally the MuJoCo experiments):
            # pass the space through and keep the raw first observation.
            actions = self.env.action_space
            init_state = GymState(self.env.reset())

        MDP.__init__(self,
                     actions,
                     self._transition_func,
                     self._reward_func,
                     init_state=init_state)
示例#4
0
 def __init__(self, env_name='CartPole-v0', render=False):
     '''
     Build the gym environment and hand it to the MDP base class.

     Args:
         env_name (str): Gym environment id passed to ``gym.make``.
         render (bool): Stored as ``self.render``.
     '''
     self.env_name = env_name
     self.env = gym.make(self.env_name)
     self.render = render

     env = self.env
     # Actions are the integers 0..n-1 of the environment's action space.
     actions = range(env.action_space.n)
     first_state = GymState(env.reset())
     MDP.__init__(self, actions, self._transition_func, self._reward_func,
                  init_state=first_state)
示例#5
0
    def _reward_func(self, state, action):
        '''
        Step the wrapped gym environment and return the step reward.

        Side effect: the resulting state is stored in ``self.next_state``.

        Args:
            state: Not read; the environment tracks its own state.
            action: Action forwarded to ``self.env.step``.

        Returns
            (float): Reward reported by the environment for this step.
        '''
        # NOTE(review): the ``_transition_func`` defined next to this method
        # also calls ``self.env.step``; if the MDP invokes both per agent step
        # the environment advances twice. Confirm against the MDP base class.
        obs, reward, is_terminal, info = self.env.step(action)

        # Treat a period of 0 as "render every episode" (same guard as
        # ``_transition_func``); without it the modulo raises
        # ZeroDivisionError for the default render_every_n_episodes=0.
        if self.render and (self.render_every_n_episodes == 0 or
                            self.episode % self.render_every_n_episodes == 0):
            self.env.render()

        self.next_state = GymState(obs, is_terminal=is_terminal)

        return reward
    def _transition_func(self, state, action):
        '''
        Advance ``self.env`` by one step and return the resulting state.

        The step reward is cached in ``self.prev_reward`` and the new state in
        ``self.next_state``. Rendering happens only when ``self.render`` is
        set and either the render period is 0 or the current episode index is
        a multiple of it.

        Args:
            state: Not read; the environment tracks its own state.
            action: Action forwarded to ``self.env.step``.

        Returns
            (State): The freshly observed ``GymState``.
        '''
        observation, step_reward, done, _info = self.env.step(action)

        if self.render:
            period = self.render_every_n_episodes
            if period == 0 or self.episode % period == 0:
                self.env.render()

        self.prev_reward = step_reward
        resulting_state = GymState(observation, is_terminal=done)
        self.next_state = resulting_state
        return resulting_state
示例#7
0
    def _reward_func(self, state, action):
        '''
        Advance ``self.env`` by one step and report the reward.

        The observed next state, tagged with ``self.interaction_features``,
        is cached on ``self.next_state``.

        Args:
            state: Not read by this method.
            action: Forwarded to ``self.env.step``.

        Returns
            (float): Reward from the environment step.
        '''
        observation, step_reward, done, _ = self.env.step(action)

        if self.render:
            self.env.render()

        features = self.interaction_features
        self.next_state = GymState(observation,
                                   is_terminal=done,
                                   interaction_features=features)
        return step_reward
 def __init__(self,
              env_name='CartPole-v0',
              render=False,
              render_every_n_episodes=0):
     '''
     Build the gym environment and initialize the MDP from its first
     observation.

     Args:
         env_name (str): Gym environment id passed to ``gym.make``.
         render (bool): If True, renders the screen every time step.
         render_every_n_episodes (int): Only meaningful when @render is
             True; then renders the screen every n episodes (default 0).
     '''
     self.env_name = env_name
     self.render = render
     self.render_every_n_episodes = render_every_n_episodes
     # Episode counter read by the periodic render checks.
     self.episode = 0
     self.env = gym.make(env_name)

     first_obs = self.env.reset()
     MDP.__init__(self,
                  range(self.env.action_space.n),
                  self._transition_func,
                  self._reward_func,
                  init_state=GymState(first_obs))
 def reset(self):
     '''
     Reset the wrapped environment and rebuild the starting state.

     ``init_state`` becomes a deep copy of a GymState built from the
     ``process_state``-processed reset observation (the positional ``False``
     is presumably ``is_terminal`` — confirm against GymState). The cached
     ``next_state`` / ``next_info`` are cleared and ``cur_state`` is set to
     an independent deep copy of ``init_state``.
     '''
     self.init_state = copy.deepcopy(
         GymState(self.process_state(self.env.reset()), False))
     self.next_state = None
     self.next_info = None
     # Separate copy so later changes to cur_state cannot alias init_state.
     self.cur_state = copy.deepcopy(self.init_state)