Example #1
    def __init__(self, env, agent_params):

        # environment handle and the observation we will act on next
        self.env = env
        self.agent_params = agent_params
        self.batch_size = agent_params['batch_size']
        self.last_obs = self.env.reset()

        # training-schedule hyperparameters
        self.num_actions = agent_params['ac_dim']
        self.learning_starts = agent_params['learning_starts']
        self.learning_freq = agent_params['learning_freq']
        self.target_update_freq = agent_params['target_update_freq']

        # index of the most recently stored frame in the replay buffer
        self.replay_buffer_idx = None
        # exploration (epsilon) schedule and optimizer configuration
        self.exploration = agent_params['exploration_schedule']
        self.optimizer_spec = agent_params['optimizer_spec']

        # Q-function critic and the greedy (argmax) actor built on top of it
        self.critic = DQNCritic(agent_params, self.optimizer_spec)
        self.actor = ArgMaxPolicy(self.critic)

        # replay buffer; the memory-optimized variant stacks frame_history_len frames
        lander = agent_params['env_name'].startswith('LunarLander')
        self.replay_buffer = MemoryOptimizedReplayBuffer(
            agent_params['replay_buffer_size'],
            agent_params['frame_history_len'],
            lander=lander)

        # environment-step and parameter-update counters
        self.t = 0
        self.num_param_updates = 0
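
For reference, a minimal sketch of the agent_params dictionary this constructor reads. The keys are exactly the ones looked up above; the concrete values, the env object, and the exploration_schedule / optimizer_spec objects are illustrative placeholders, not part of the original example.

# Hypothetical configuration for the constructor above; keys mirror the
# lookups in __init__, values and helper objects are placeholders only.
agent_params = {
    'env_name': 'LunarLander-v2',      # used only via startswith('LunarLander')
    'ac_dim': 4,                       # number of discrete actions
    'batch_size': 32,
    'learning_starts': 2000,           # env steps of pure exploration before training
    'learning_freq': 1,                # train every learning_freq env steps
    'target_update_freq': 3000,        # steps between target-network syncs
    'replay_buffer_size': 1000000,
    'frame_history_len': 1,            # >1 only for frame-stacked image observations
    'exploration_schedule': exploration_schedule,  # any object exposing .value(t) -> epsilon
    'optimizer_spec': optimizer_spec,              # optimizer / learning-rate schedule config
}
agent = DQNAgent(env, agent_params)
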
Example #2
class DQNAgent(BaseAgent):
    def __init__(self, env, agent_params):
        super(DQNAgent, self).__init__()

        # init vars
        self.env = env
        self.agent_params = agent_params
        self.gamma = self.agent_params['gamma']
        self.target_update_freq = self.agent_params['target_update_freq']

        # actor and critic
        self.critic = DQNCritic(self.agent_params)
        self.actor = ArgMaxPolicy(self.critic, self.agent_params['dtype'])
        self.num_param_updates = 0

        # replay buffer
        self.replay_buffer = ReplayBuffer(1000000)

    def add_to_replay_buffer(self, paths):
        self.replay_buffer.add_rollouts(paths)

    def sample(self, batch_size):
        return self.replay_buffer.sample_random_data(batch_size)

    def train(self, ob_batch, ac_batch, re_batch, next_ob_batch,
              terminal_batch):
        # periodically copy the online Q-network's weights into the target network
        if self.num_param_updates % self.target_update_freq == 0:
            self.critic.update_target()
        self.num_param_updates += 1
        loss = self.critic.update(ob_batch, ac_batch, re_batch, next_ob_batch,
                                  terminal_batch)
        return loss

    def save_model(self, path):
        self.critic.save_model(path)
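
As a usage note, here is a minimal sketch of a driver loop for this agent. The env, agent_params, and collect_rollouts names are placeholders, and the unpacking assumes sample_random_data returns (obs, acs, rews, next_obs, terminals); only the agent methods come from the class above.

# Hypothetical driver loop for the DQNAgent above.
agent = DQNAgent(env, agent_params)

for itr in range(1000):
    # gather fresh trajectories with the current greedy actor and store them
    paths = collect_rollouts(env, agent.actor)
    agent.add_to_replay_buffer(paths)

    # sample a random batch and take one critic update step
    obs, acs, rews, next_obs, terminals = agent.sample(batch_size=32)
    loss = agent.train(obs, acs, rews, next_obs, terminals)

agent.save_model('dqn_agent.pt')
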
Example #3
    def __init__(self, env, agent_params):
        super(DQNAgent, self).__init__()

        # init vars
        self.env = env
        self.agent_params = agent_params
        self.gamma = self.agent_params['gamma']
        self.target_update_freq = self.agent_params['target_update_freq']

        # actor and critic
        self.critic = DQNCritic(self.agent_params)
        self.actor = ArgMaxPolicy(self.critic, self.agent_params['dtype'])
        self.num_param_updates = 0

        # replay buffer
        self.replay_buffer = ReplayBuffer(1000000)
Example #4
import numpy as np


class DQNAgent(object):
    def __init__(self, env, agent_params):

        self.env = env
        self.agent_params = agent_params
        self.batch_size = agent_params['batch_size']
        self.last_obs = self.env.reset()

        self.num_actions = agent_params['ac_dim']
        self.learning_starts = agent_params['learning_starts']
        self.learning_freq = agent_params['learning_freq']
        self.target_update_freq = agent_params['target_update_freq']

        self.replay_buffer_idx = None
        self.exploration = agent_params['exploration_schedule']
        self.optimizer_spec = agent_params['optimizer_spec']

        self.critic = DQNCritic(agent_params, self.optimizer_spec)
        self.actor = ArgMaxPolicy(self.critic)

        lander = agent_params['env_name'].startswith('LunarLander')
        self.replay_buffer = MemoryOptimizedReplayBuffer(
            agent_params['replay_buffer_size'],
            agent_params['frame_history_len'],
            lander=lander)
        self.t = 0
        self.num_param_updates = 0

    def add_to_replay_buffer(self, paths):
        pass

    def step_env(self):
        """
            Step the env and store the transition
            At the end of this block of code, the simulator should have been
            advanced one step, and the replay buffer should contain one more transition.
            Note that self.last_obs must always point to the new latest observation.
        """
        self.replay_buffer_idx = self.replay_buffer.store_frame(self.last_obs)
        eps = self.exploration.value(self.t)

        if self.t < self.learning_starts:
            # initially take random actions to get diverse behavior in the buffer
            perform_random_action = True
        else:
            """
            TODO: epsilon greedy takes a random action with probability eps.
            Set the perform_random_action variable appropriately.
            """
            if np.random.uniform(0, 1) < eps:
                perform_random_action = True
            else:
                perform_random_action = False
            """
            END CODE
            """
        if perform_random_action:
            # sample uniformly from the environment's action space
            action = self.env.action_space.sample()
        else:
            # to deal with partial observability, we take in multiple previous
            # observations and feed them to the actor
            processed_obs = self.replay_buffer.encode_recent_observation()
            # act greedily according to the argmax policy
            action = self.actor.get_action(processed_obs)

        # takes a step in the environment using the action from the policy
        next_obs, reward, done, info = self.env.step(action)
        self.last_obs = next_obs.copy()

        # stores the result of taking this action into the replay buffer
        self.replay_buffer.store_effect(self.replay_buffer_idx, action, reward,
                                        done)

        # If taking this step resulted in the episode terminating, reset the env (and the latest observation)
        if done:
            self.last_obs = self.env.reset()

    def sample(self, batch_size):
        if self.replay_buffer.can_sample(self.batch_size):
            return self.replay_buffer.sample(batch_size)
        else:
            return [], [], [], [], []

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        log = {}
        # train only after learning_starts steps, at most once every learning_freq
        # steps, and only when the buffer has enough transitions to sample
        if (self.t > self.learning_starts and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)):

            log = self.critic.update(
                ob_no,
                ac_na,
                next_ob_no,
                re_n,
                terminal_n,
            )

            # update the target network periodically
            if self.num_param_updates % self.target_update_freq == 0:
                self.critic.update_target_network()

            self.num_param_updates += 1

        self.t += 1
        return log
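
To tie the pieces together, below is a minimal sketch of a linear epsilon schedule exposing the .value(t) interface that step_env expects, plus an outer loop that alternates step_env, sample, and train. The LinearSchedule class, the env and agent_params objects, and all constants are illustrative assumptions; only the DQNAgent methods come from the example above.

# Illustrative only: a simple epsilon schedule with the .value(t) interface used
# by step_env(), and an outer loop wiring step_env / sample / train together.
class LinearSchedule:
    def __init__(self, total_steps, initial_eps=1.0, final_eps=0.02):
        self.total_steps = total_steps
        self.initial_eps = initial_eps
        self.final_eps = final_eps

    def value(self, t):
        # linearly anneal epsilon from initial_eps to final_eps over total_steps
        frac = min(float(t) / self.total_steps, 1.0)
        return self.initial_eps + frac * (self.final_eps - self.initial_eps)


# env and agent_params are assumed to be built elsewhere (see the sketch after
# Example #1); an instance like LinearSchedule(100000) would be passed in as
# agent_params['exploration_schedule'].
agent = DQNAgent(env, agent_params)

for _ in range(500000):
    # advance the simulator one step and append the transition to the buffer
    agent.step_env()
    # sample() returns empty lists until the buffer can_sample; train() itself
    # enforces learning_starts and learning_freq before updating the critic
    batch = agent.sample(agent.batch_size)
    agent.train(*batch)
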