class ExplorationOrExploitationAgent(DQNAgent):
    def __init__(self, env, agent_params):
        super(ExplorationOrExploitationAgent, self).__init__(env, agent_params)

        self.replay_buffer = MemoryOptimizedReplayBuffer(100000,
                                                         1,
                                                         float_obs=True)
        self.num_exploration_steps = agent_params['num_exploration_steps']
        self.offline_exploitation = agent_params['offline_exploitation']

        self.exploitation_critic = CQLCritic(agent_params, self.optimizer_spec)
        self.exploration_critic = DQNCritic(agent_params, self.optimizer_spec)

        self.exploration_model = RNDModel(agent_params, self.optimizer_spec)
        self.explore_weight_schedule = agent_params['explore_weight_schedule']
        self.exploit_weight_schedule = agent_params['exploit_weight_schedule']

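        # behavior policy acts greedily w.r.t. the exploration critic while exploring;
        # the evaluation policy acts greedily w.r.t. the exploitation (CQL) critic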
        self.actor = ArgMaxPolicy(self.exploration_critic)
        self.eval_policy = ArgMaxPolicy(self.exploitation_critic)
        self.exploit_rew_shift = agent_params['exploit_rew_shift']
        self.exploit_rew_scale = agent_params['exploit_rew_scale']
        self.eps = agent_params['eps']
        self.l2_info = agent_params['l2_info']

    def dist2(self, x, c):
        """
        dist2  Calculates squared distance between two sets of points.

        Description
        D = DIST2(X, C) takes two matrices of vectors and calculates the
        squared Euclidean distance between them.  Both matrices must be of
        the same column dimension.  If X has M rows and N columns, and C has
        L rows and N columns, then the result has M rows and L columns.  The
        I, Jth entry is the  squared distance from the Ith row of X to the
        Jth row of C.

        Adapted from code by Christopher M Bishop and Ian T Nabney.
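
        Example (illustrative, values computed by hand):
            x = np.array([[0., 0.], [1., 1.]])
            c = np.array([[1., 0.]])
            self.dist2(x, c)  # -> array([[1.], [1.]]): each row of x is at
                              #    squared distance 1 from the single row of c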
        """
        ndata, dimx = x.shape
        ncenters, dimc = c.shape
        assert dimx == dimc, 'x and c must have the same number of columns'
        return (np.ones((ncenters, 1)) * np.sum((x**2).T, axis=0)).T + \
               np.ones((ndata, 1)) * np.sum((c**2).T, axis=0) - \
               2 * np.inner(x, c)

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        log = {}

        if self.t > self.num_exploration_steps:
            # After exploration is over, set the actor to optimize the extrinsic critic
            # HINT: Look at method ArgMaxPolicy.set_critic
            self.actor.set_critic(self.exploitation_critic)

        if (self.t > self.learning_starts and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)):

            # Get Reward Weights
            # Get the current explore reward weight and exploit reward weight
            #       using the schedules passed in (see __init__)
            # COMMENT: Until part 3, explore_weight = 1, and exploit_weight = 0
            explore_weight = self.explore_weight_schedule.value(self.t)
            exploit_weight = self.exploit_weight_schedule.value(self.t)

            # Run Exploration Model #
            # Evaluate the exploration model on s' to get the exploration bonus
            # HINT: Normalize the exploration bonus, as RND values vary highly in magnitude
            if self.l2_info:
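                # L2 bonus: for each next state, sum of squared distances to
                # every state in the sampled batch (dist2 gives the full matrix)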
                dist = self.dist2(next_ob_no, ob_no)
                expl_bonus = np.sum(dist, axis=1)
            else:
                expl_bonus = self.exploration_model.forward_np(next_ob_no)
            expl_bonus = (expl_bonus - np.mean(expl_bonus)) / np.std(expl_bonus)

            # Reward Calculations #
            # Calculate mixed rewards, which will be passed into the exploration critic
            # HINT: See doc for definition of mixed_reward
            mixed_reward = explore_weight * expl_bonus + exploit_weight * re_n

            # Calculate the environment reward
            # HINT: For part 1, env_reward is just 're_n'
            #       After this, env_reward is 're_n' shifted by self.exploit_rew_shift,
            #       and scaled by self.exploit_rew_scale
            env_reward = (re_n + self.exploit_rew_shift) * self.exploit_rew_scale

            # Update Critics And Exploration Model #

            # 1) Update the exploration model (based off s')
            # 2) Update the exploration critic (based off mixed_reward)
            # 3) Update the exploitation critic (based off env_reward)
            expl_model_loss = self.exploration_model.update(
                ptu.from_numpy(next_ob_no))
            exploration_critic_loss = self.exploration_critic.update(
                ob_no, ac_na, next_ob_no, mixed_reward, terminal_n)
            exploitation_critic_loss = self.exploitation_critic.update(
                ob_no, ac_na, next_ob_no, env_reward, terminal_n)

            # Target Networks #
            if self.num_param_updates % self.target_update_freq == 0:
                # Update the exploitation and exploration target networks
                self.exploration_critic.update_target_network()
                self.exploitation_critic.update_target_network()

            # Logging #
            log['Exploration Critic Loss'] = exploration_critic_loss[
                'Training Loss']
            log['Exploitation Critic Loss'] = exploitation_critic_loss[
                'Training Loss']
            log['Exploration Model Loss'] = expl_model_loss

            # CQL-specific logging (these keys come from the CQLCritic in cql_critic.py)
            log['Exploitation Data q-values'] = exploitation_critic_loss[
                'Data q-values']
            log['Exploitation OOD q-values'] = exploitation_critic_loss[
                'OOD q-values']
            log['Exploitation CQL Loss'] = exploitation_critic_loss['CQL Loss']

            self.num_param_updates += 1

        self.t += 1
        return log

    def step_env(self):
        """
            Step the env and store the transition
            At the end of this block of code, the simulator should have been
            advanced one step, and the replay buffer should contain one more transition.
            Note that self.last_obs must always point to the new latest observation.
        """
        if (not self.offline_exploitation) or (self.t <=
                                               self.num_exploration_steps):
            self.replay_buffer_idx = self.replay_buffer.store_frame(
                self.last_obs)

        perform_random_action = (np.random.random() < self.eps
                                 or self.t < self.learning_starts)

        if perform_random_action:
            action = self.env.action_space.sample()
        else:
            processed = self.replay_buffer.encode_recent_observation()
            action = self.actor.get_action(processed)

        next_obs, reward, done, info = self.env.step(action)
        self.last_obs = next_obs.copy()

        if (not self.offline_exploitation) or (self.t <=
                                               self.num_exploration_steps):
            self.replay_buffer.store_effect(self.replay_buffer_idx, action,
                                            reward, done)

        if done:
            self.last_obs = self.env.reset()
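
# --- Illustrative sketch (not part of the original files) --------------------
# The agent above only relies on RNDModel exposing forward_np(next_ob_no) and
# update(obs), with update() returning a scalar loss.  A minimal RND-style
# model consistent with that interface is sketched below under those
# assumptions; the network sizes, the Adam optimizer, and the output dimension
# are illustrative choices, not the original implementation.
import torch
from torch import nn, optim


def _small_mlp(in_dim, out_dim, hidden=64):
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                         nn.Linear(hidden, out_dim))


class SketchRNDModel(nn.Module):
    """Random Network Distillation: the exploration bonus is the error of a
    trained predictor network imitating a fixed, randomly initialized target,
    so rarely visited states (poorly fit by the predictor) get larger bonuses."""

    def __init__(self, ob_dim, out_dim=5, learning_rate=1e-3):
        super().__init__()
        self.target = _small_mlp(ob_dim, out_dim)     # frozen random features
        self.predictor = _small_mlp(ob_dim, out_dim)  # trained to match target
        for p in self.target.parameters():
            p.requires_grad = False
        self.optimizer = optim.Adam(self.predictor.parameters(), lr=learning_rate)

    def forward(self, ob_no):
        # per-sample squared prediction error, used as the exploration bonus
        return ((self.predictor(ob_no) - self.target(ob_no)) ** 2).mean(dim=1)

    def forward_np(self, ob_no):
        ob_no = torch.as_tensor(ob_no, dtype=torch.float32)
        with torch.no_grad():
            return self(ob_no).numpy()

    def update(self, ob_no):
        ob_no = torch.as_tensor(ob_no, dtype=torch.float32)
        loss = self(ob_no).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
# -----------------------------------------------------------------------------

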
class ExplorationOrExploitationAgent(DQNAgent):
    exploration_model: BaseExplorationModel

    def __init__(self, env, agent_params):
        super(ExplorationOrExploitationAgent, self).__init__(env, agent_params)

        self.replay_buffer = MemoryOptimizedReplayBuffer(100000,
                                                         1,
                                                         float_obs=True)
        self.num_exploration_steps = agent_params['num_exploration_steps']
        self.offline_exploitation = agent_params['offline_exploitation']

        self.exploitation_critic = CQLCritic(agent_params, self.optimizer_spec)
        self.exploration_critic = DQNCritic(agent_params, self.optimizer_spec)

        if agent_params['use_cbe']:
            self.exploration_model = CountBasedModel(
                agent_params['cbe_coefficient'], env)
        else:
            self.exploration_model = RNDModel(agent_params,
                                              self.optimizer_spec)

        self.explore_weight_schedule: Schedule = agent_params[
            'explore_weight_schedule']
        self.exploit_weight_schedule: Schedule = agent_params[
            'exploit_weight_schedule']

        self.actor = ArgMaxPolicy(self.exploration_critic)
        self.eval_policy = ArgMaxPolicy(self.exploitation_critic)
        self.exploit_rew_shift = agent_params['exploit_rew_shift']
        self.exploit_rew_scale = agent_params['exploit_rew_scale']
        self.eps = agent_params['eps']

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        log = {}

        if self.t > self.num_exploration_steps:
            # After exploration is over, set the actor to optimize the extrinsic critic
            # HINT: Look at method ArgMaxPolicy.set_critic
            self.actor.set_critic(self.exploitation_critic)

        if (self.t > self.learning_starts and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)):
            # Get Reward Weights
            # Get the current explore reward weight and exploit reward weight
            #       using the schedules passed in (see __init__)
            # COMMENT: Until part 3, explore_weight = 1, and exploit_weight = 0
            explore_weight = self.explore_weight_schedule.value(self.t)
            exploit_weight = self.exploit_weight_schedule.value(self.t)

            # Run Exploration Model #
            # Evaluate the exploration model on s' to get the exploration bonus
            # HINT: Normalize the exploration bonus, as RND values vary highly in magnitude
            expl_bonus = self.exploration_model.forward_np(next_ob_no)
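            # normalize() is assumed to standardize to zero mean and unit std,
            # mirroring the explicit (x - mean) / std computation used in the
            # RND-only variant above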
            expl_bonus = normalize(
                expl_bonus,
                expl_bonus.mean(),
                expl_bonus.std(),
            )

            # Reward Calculations #
            # Calculate mixed rewards, which will be passed into the exploration critic
            # HINT: See doc for definition of mixed_reward
            mixed_reward = (explore_weight * expl_bonus +
                            exploit_weight * re_n)

            # Calculate the environment reward
            # HINT: For part 1, env_reward is just 're_n'
            #       After this, env_reward is 're_n' shifted by self.exploit_rew_shift,
            #       and scaled by self.exploit_rew_scale
            env_reward = (re_n +
                          self.exploit_rew_shift) * self.exploit_rew_scale

            # Update Critics And Exploration Model #

            # 1): Update the exploration model (based off s')
            expl_model_loss = self.exploration_model.update(next_ob_no)
            # 2): Update the exploration critic (based off mixed_reward)
            exploration_critic_loss = self.exploration_critic.update(
                ob_no, ac_na, next_ob_no, mixed_reward, terminal_n)
            # 3): Update the exploitation critic (based off env_reward)
            exploitation_critic_loss = self.exploitation_critic.update(
                ob_no, ac_na, next_ob_no, env_reward, terminal_n)

            # Target Networks #
            if self.num_param_updates % self.target_update_freq == 0:
                self.exploration_critic.update_target_network()
                self.exploitation_critic.update_target_network()

            # Logging #
            log['Exploration Critic Loss'] = exploration_critic_loss[
                'Training Loss']
            log['Exploitation Critic Loss'] = exploitation_critic_loss[
                'Training Loss']
            log['Exploration Model Loss'] = expl_model_loss

            # CQL-specific logging (these keys come from the CQLCritic in cql_critic.py)
            log['Exploitation Data q-values'] = exploitation_critic_loss[
                'Data q-values']
            log['Exploitation OOD q-values'] = exploitation_critic_loss[
                'OOD q-values']
            log['Exploitation CQL Loss'] = exploitation_critic_loss['CQL Loss']

            self.num_param_updates += 1

        self.t += 1
        return log

    def step_env(self):
        """
            Step the env and store the transition
            At the end of this block of code, the simulator should have been
            advanced one step, and the replay buffer should contain one more transition.
            Note that self.last_obs must always point to the new latest observation.
        """
        if (not self.offline_exploitation) or (self.t <=
                                               self.num_exploration_steps):
            self.replay_buffer_idx = self.replay_buffer.store_frame(
                self.last_obs)

        perform_random_action = (np.random.random() < self.eps
                                 or self.t < self.learning_starts)

        if perform_random_action:
            action = self.env.action_space.sample()
        else:
            processed = self.replay_buffer.encode_recent_observation()
            action = self.actor.get_action(processed)

        next_obs, reward, done, info = self.env.step(int(action))
        self.last_obs = next_obs.copy()

        if (not self.offline_exploitation) or (self.t <=
                                               self.num_exploration_steps):
            self.replay_buffer.store_effect(self.replay_buffer_idx, action,
                                            reward, done)

        if done:
            self.last_obs = self.env.reset()
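
# --- Illustrative sketch (not part of the original files) --------------------
# The variant above only requires CountBasedModel to be constructed from a
# bonus coefficient and the env, and to expose forward_np(next_ob_no) and
# update(next_ob_no).  The sketch below satisfies that interface; the
# coefficient / sqrt(N(s) + 1) bonus and the rounding-based discretization of
# observations are assumptions, not the original implementation.
from collections import defaultdict

import numpy as np


class SketchCountBasedModel:
    def __init__(self, coefficient, env, decimals=1):
        self.coefficient = coefficient
        self.env = env                  # kept for interface parity; unused here
        self.decimals = decimals        # discretization granularity (assumed)
        self.counts = defaultdict(int)  # visit counts N(s) per discretized state

    def _keys(self, ob_no):
        return [tuple(np.round(ob, self.decimals)) for ob in np.atleast_2d(ob_no)]

    def forward_np(self, ob_no):
        # bonus is larger for rarely visited (low-count) states
        return np.array([self.coefficient / np.sqrt(self.counts[k] + 1)
                         for k in self._keys(ob_no)])

    def update(self, ob_no):
        # counting involves no gradient step, so report a zero "loss" for logging
        for k in self._keys(ob_no):
            self.counts[k] += 1
        return 0.0
# -----------------------------------------------------------------------------
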
Example #3
class ExplorationOrExploitationAgent(DQNAgent):
    def __init__(self, env, agent_params):
        super(ExplorationOrExploitationAgent, self).__init__(env, agent_params)
        
        self.replay_buffer = MemoryOptimizedReplayBuffer(100000, 1, float_obs=True)
        self.num_exploration_steps = agent_params['num_exploration_steps']
        self.offline_exploitation = agent_params['offline_exploitation']

        self.exploitation_critic = CQLCritic(agent_params, self.optimizer_spec)
        self.exploration_critic = DQNCritic(agent_params, self.optimizer_spec)
        
        self.exploration_model = RNDModel(agent_params, self.optimizer_spec)
        self.explore_weight_schedule = agent_params['explore_weight_schedule']
        self.exploit_weight_schedule = agent_params['exploit_weight_schedule']
        
        self.actor = ArgMaxPolicy(self.exploration_critic)
        self.eval_policy = ArgMaxPolicy(self.exploitation_critic)
        self.exploit_rew_shift = agent_params['exploit_rew_shift']
        self.exploit_rew_scale = agent_params['exploit_rew_scale']
        self.eps = agent_params['eps']

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        log = {}

        if self.t > self.num_exploration_steps:
            self.actor.set_critic(self.exploitation_critic)

        if (self.t > self.learning_starts
                and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)
        ):

            # Get Reward Weights
            # COMMENT: Until part 3, explore_weight = 1, and exploit_weight = 0
            # explore_weight = 1
            # exploit_weight = 0
            explore_weight = self.explore_weight_schedule.value(self.t)
            exploit_weight = self.exploit_weight_schedule.value(self.t)

            # Run Exploration Model #
            expl_bonus = self.exploration_model.forward_np(next_ob_no)
            expl_bonus = normalize(expl_bonus, np.mean(expl_bonus), np.std(expl_bonus))

            # Reward Calculations #
            mixed_reward = explore_weight * expl_bonus + exploit_weight * re_n
            env_reward = (re_n + self.exploit_rew_shift) * self.exploit_rew_scale

            # Update Critics And Exploration Model #
            expl_model_loss = self.exploration_model.update(next_ob_no)
            exploration_critic_loss = self.exploration_critic.update(ob_no, ac_na, next_ob_no,
                                                                     mixed_reward, terminal_n)
            exploitation_critic_loss = self.exploitation_critic.update(ob_no, ac_na, next_ob_no,
                                                                       env_reward, terminal_n)

            # Target Networks #
            if self.num_param_updates % self.target_update_freq == 0:
                self.exploitation_critic.update_target_network()
                self.exploration_critic.update_target_network()

            # Logging #
            log['Exploration Critic Loss'] = exploration_critic_loss['Training Loss']
            log['Exploitation Critic Loss'] = exploitation_critic_loss['Training Loss']
            log['Exploration Model Loss'] = expl_model_loss
            log['Exploitation Data q-values'] = exploitation_critic_loss['Data q-values']
            log['Exploitation OOD q-values'] = exploitation_critic_loss['OOD q-values']
            log['Exploitation CQL Loss'] = exploitation_critic_loss['CQL Loss']

            self.num_param_updates += 1

        self.t += 1
        return log


    def step_env(self):
        """
            Step the env and store the transition
            At the end of this block of code, the simulator should have been
            advanced one step, and the replay buffer should contain one more transition.
            Note that self.last_obs must always point to the new latest observation.
        """
        if (not self.offline_exploitation) or (self.t <= self.num_exploration_steps):
            self.replay_buffer_idx = self.replay_buffer.store_frame(self.last_obs)

        perform_random_action = np.random.random() < self.eps or self.t < self.learning_starts

        if perform_random_action:
            action = self.env.action_space.sample()
        else:
            processed = self.replay_buffer.encode_recent_observation()
            action = self.actor.get_action(processed)

        next_obs, reward, done, info = self.env.step(action)
        self.last_obs = next_obs.copy()

        if (not self.offline_exploitation) or (self.t <= self.num_exploration_steps):
            self.replay_buffer.store_effect(self.replay_buffer_idx, action, reward, done)

        if done:
            self.last_obs = self.env.reset()
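
# --- Illustrative sketch (not part of the original files) --------------------
# The keys below are exactly the ones these agents read on top of the base
# DQNAgent configuration; the schedule objects only need a .value(t) method.
# The ConstantSchedule helper and the numeric values are illustrative
# assumptions, not the original launcher configuration.
class ConstantSchedule:
    def __init__(self, value):
        self._value = value

    def value(self, t):
        return self._value


example_agent_params = {
    'num_exploration_steps': 10000,     # actor switches to the CQL critic afterwards
    'offline_exploitation': False,      # True: stop adding data after exploration ends
    'explore_weight_schedule': ConstantSchedule(1.0),  # explore_weight = 1 until part 3
    'exploit_weight_schedule': ConstantSchedule(0.0),  # exploit_weight = 0 until part 3
    'exploit_rew_shift': 0.0,
    'exploit_rew_scale': 1.0,
    'eps': 0.2,                         # epsilon-greedy action probability
    # first variant only:  'l2_info': False
    # second variant only: 'use_cbe': False, 'cbe_coefficient': 1.0
    # ...plus the usual DQNAgent / critic parameters consumed by the base class
}
# -----------------------------------------------------------------------------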