class Agent(AgentABC):
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """Initialize an MADDPG Agent object.
        Params
        ======
            :param state_size: dimension of each state
            :param action_size: dimension of each action
            :param num_agents: number of inner agents
            :param random_seed: random seed
        """
        super().__init__(state_size, action_size, num_agents, random_seed)
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        random.seed(random_seed)
        self.seed = random_seed

        self.actors_local = []
        self.actors_target = []
        self.actor_optimizers = []
        self.critics_local = []
        self.critics_target = []
        self.critic_optimizers = []
        for i in range(num_agents):
            # Actor Network (w/ Target Network)
            self.actors_local.append(
                Actor(state_size, action_size, random_seed).to(device))
            self.actors_target.append(
                Actor(state_size, action_size, random_seed).to(device))
            self.actor_optimizers.append(
                optim.Adam(self.actors_local[i].parameters(), lr=LR_ACTOR))
            # Critic Network (w/ Target Network)
            self.critics_local.append(
                Critic(num_agents * state_size, num_agents * action_size,
                       random_seed).to(device))
            self.critics_target.append(
                Critic(num_agents * state_size, num_agents * action_size,
                       random_seed).to(device))
            self.critic_optimizers.append(
                optim.Adam(self.critics_local[i].parameters(),
                           lr=LR_CRITIC,
                           weight_decay=WEIGHT_DECAY))

        # Noise process for each agent
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

        # debugging variables
        self.step_count = 0
        self.mse_error_list = []

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        self.memory.add(states, actions, rewards, next_states, dones)

        # Learn, if enough samples are available in memory.
        # To stabilize learning, weights are not updated on every step: every
        # UPDATE_EVERY steps we perform NUM_UPDATES learning passes.
        self.step_count += 1
        if (self.step_count % UPDATE_EVERY) == 0:
            for _ in range(NUM_UPDATES):
                if len(self.memory) > 1000:  # warm-up threshold before learning starts
                    experiences = self.memory.sample()
                    self.learn(experiences)
                    self.debug_loss = np.mean(self.mse_error_list)
            self.update_target_networks()

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        acts = np.zeros((self.num_agents, self.action_size))
        for agent in range(self.num_agents):
            self.actors_local[agent].eval()
            with torch.no_grad():
                acts[agent, :] = self.actors_local[agent](
                    state[agent, :]).cpu().data.numpy()
            self.actors_local[agent].train()
        if add_noise:
            acts += self.noise.sample()
        return np.clip(acts, -1, 1)

    def reset(self):
        """ see abstract class """
        super().reset()
        self.noise.reset()
        self.mse_error_list = []

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_full_state, actors_target(next_partial_state) )
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states_batched, actions_batched, rewards, next_states_batched, dones = experiences
        states_concated = states_batched.view(
            [BATCH_SIZE, self.num_agents * self.state_size])
        next_states_concated = next_states_batched.view(
            [BATCH_SIZE, self.num_agents * self.state_size])
        actions_concated = actions_batched.view(
            [BATCH_SIZE, self.num_agents * self.action_size])

        for agent in range(self.num_agents):
            actions_next_batched = [
                self.actors_target[i](next_states_batched[:, i, :])
                for i in range(self.num_agents)
            ]
            actions_next_whole = torch.cat(actions_next_batched, 1)
            # ---------------------------- update critic ---------------------------- #
            # Get predicted next-state actions and Q values from target models
            q_targets_next = self.critics_target[agent](next_states_concated,
                                                        actions_next_whole)
            # Compute Q targets for current states (y_i)
            q_targets = rewards[:, agent].view(
                BATCH_SIZE, -1) + (GAMMA * q_targets_next *
                                   (1 - dones[:, agent].view(BATCH_SIZE, -1)))
            # Compute critic loss
            q_expected = self.critics_local[agent](states_concated,
                                                   actions_concated)
            critic_loss = F.mse_loss(q_expected, q_targets)
            # Minimize the loss
            self.critic_optimizers[agent].zero_grad()
            critic_loss.backward()
            self.critic_optimizers[agent].step()
            # save the error for statistics
            self.mse_error_list.append(critic_loss.detach().cpu().numpy())

            # ---------------------------- update actor ---------------------------- #
            action_i = self.actors_local[agent](states_batched[:, agent, :])
            actions_pred = actions_batched.clone()
            actions_pred[:, agent, :] = action_i
            actions_pred_whole = actions_pred.view(BATCH_SIZE, -1)
            # Compute actor loss
            actor_loss = -self.critics_local[agent](states_concated,
                                                    actions_pred_whole).mean()
            # Minimize the loss
            self.actor_optimizers[agent].zero_grad()
            actor_loss.backward()
            self.actor_optimizers[agent].step()

    def update_target_networks(self):
        # ----------------------- update target networks ----------------------- #
        for agent in range(self.num_agents):
            self.soft_update(self.critics_local[agent],
                             self.critics_target[agent], TAU)
            self.soft_update(self.actors_local[agent],
                             self.actors_target[agent], TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def load_weights(self, directory_path):
        """ see abstract class """
        super().load_weights(directory_path)
        actor_weights = os.path.join(directory_path, an_filename)
        critic_weights = os.path.join(directory_path, cn_filename)
        for agent in range(self.num_agents):
            self.actors_target[agent].load_state_dict(
                torch.load(actor_weights + "_" + str(agent),
                           map_location=device))
            self.critics_target[agent].load_state_dict(
                torch.load(critic_weights + "_" + str(agent),
                           map_location=device))
            self.actors_local[agent].load_state_dict(
                torch.load(actor_weights + "_" + str(agent),
                           map_location=device))
            self.critics_local[agent].load_state_dict(
                torch.load(critic_weights + "_" + str(agent),
                           map_location=device))

    def save_weights(self, directory_path):
        """ see abstract class """
        super().save_weights(directory_path)
        actor_weights = os.path.join(directory_path, an_filename)
        critic_weights = os.path.join(directory_path, cn_filename)
        for agent in range(self.num_agents):
            torch.save(self.actors_local[agent].state_dict(),
                       actor_weights + "_" + str(agent))
            torch.save(self.critics_local[agent].state_dict(),
                       critic_weights + "_" + str(agent))

    def save_mem(self, directory_path):
        """ see abstract class """
        super().save_mem(directory_path)
        self.memory.save(os.path.join(directory_path, memory_filename))

    def load_mem(self, directory_path):
        """ see abstract class """
        super().load_mem(directory_path)
        self.memory.load(os.path.join(directory_path, memory_filename))
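# Both Agent examples in this listing reference module-level hyperparameters, checkpoint
# file names and a torch device that are defined elsewhere in the original project. The
# values below are only an illustrative sketch with typical DDPG/MADDPG settings, not the
# original configuration; adjust them to the target environment.
import torch

BUFFER_SIZE = int(1e6)     # replay buffer size (assumed)
BATCH_SIZE = 128           # minibatch size (assumed)
GAMMA = 0.99               # discount factor (assumed)
TAU = 1e-3                 # soft-update interpolation factor (assumed)
LR_ACTOR = 1e-4            # actor learning rate (assumed)
LR_CRITIC = 1e-3           # critic learning rate (assumed)
WEIGHT_DECAY = 0.0         # L2 weight decay for the critic optimizer (assumed)
UPDATE_EVERY = 20          # learn every UPDATE_EVERY environment steps (assumed)
NUM_UPDATES = 10           # learning passes per update round (assumed)

an_filename = "actor_network.pth"    # actor checkpoint file name (assumed)
cn_filename = "critic_network.pth"   # critic checkpoint file name (assumed)
memory_filename = "maddpg_memory"    # replay-buffer file name (assumed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")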
class Agent(AgentABC):
    """Interacts with and learns from the environment; all agents share one DDPG actor-critic."""
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """
        Initialize an DDPG Agent object.
            :param state_size (int): dimension of each state
            :param action_size (int): dimension of each action
            :param num_agents (int): number of agents in environment ot use ddpg
            :param random_seed (int): random seed
        """
        super().__init__(state_size, action_size, num_agents, random_seed)
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        random.seed(random_seed)
        self.seed = random_seed

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process for each agent
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

        # debugging: history of the critic's MSE loss
        self.mse_error_list = []

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        for agent in range(self.num_agents):
            self.memory.add(states[agent, :], actions[agent, :],
                            rewards[agent], next_states[agent, :],
                            dones[agent])

        # Learn, if enough samples are available in memory
        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences)
            self.debug_loss = np.mean(self.mse_error_list)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        acts = np.zeros((self.num_agents, self.action_size))
        self.actor_local.eval()
        with torch.no_grad():
            for agent in range(self.num_agents):
                acts[agent, :] = self.actor_local(
                    state[agent, :]).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            noise = self.noise.sample()
            acts += noise
        return np.clip(acts, -1, 1)

    def reset(self):
        """ see abstract class """
        super().reset()
        self.noise.reset()
        self.mse_error_list = []

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards.view(BATCH_SIZE,
                                 -1) + (GAMMA * Q_targets_next *
                                        (1 - dones).view(BATCH_SIZE, -1))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.mse_error_list.append(critic_loss.detach().cpu().numpy())
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def load_weights(self, directory_path):
        """ see abstract class """
        super().load_weights(directory_path)
        self.actor_target.load_state_dict(
            torch.load(os.path.join(directory_path, an_filename),
                       map_location=device))
        self.critic_target.load_state_dict(
            torch.load(os.path.join(directory_path, cn_filename),
                       map_location=device))
        self.actor_local.load_state_dict(
            torch.load(os.path.join(directory_path, an_filename),
                       map_location=device))
        self.critic_local.load_state_dict(
            torch.load(os.path.join(directory_path, cn_filename),
                       map_location=device))

    def save_weights(self, directory_path):
        """ see abstract class """
        super().save_weights(directory_path)
        torch.save(self.actor_local.state_dict(),
                   os.path.join(directory_path, an_filename))
        torch.save(self.critic_local.state_dict(),
                   os.path.join(directory_path, cn_filename))

    def save_mem(self, directory_path):
        """ see abstract class """
        super().save_mem(directory_path)
        self.memory.save(os.path.join(directory_path, "ddpg_memory"))

    def load_mem(self, directory_path):
        """ see abstract class """
        super().load_mem(directory_path)
        self.memory.load(os.path.join(directory_path, "ddpg_memory"))
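# A minimal sketch of how either Agent class above could be driven for one episode. The
# environment object `env` is hypothetical: it is assumed to return per-agent observations
# as a NumPy array of shape (num_agents, state_size) and per-agent rewards/dones from
# env.step(); adapt the reset/step calls to the actual environment wrapper.
import numpy as np

def run_episode(env, agent, max_steps=1000):
    states = env.reset()                        # assumed shape: (num_agents, state_size)
    agent.reset()                               # reset OU noise and debug statistics
    scores = np.zeros(agent.num_agents)
    for _ in range(max_steps):
        actions = agent.act(states)             # Agent.act clips actions to [-1, 1]
        next_states, rewards, dones = env.step(actions)   # assumed return signature
        rewards = np.asarray(rewards, dtype=np.float32)
        dones = np.asarray(dones)
        agent.step(states, actions, rewards, next_states, dones)
        scores += rewards
        states = next_states
        if np.any(dones):
            break
    return scores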
Example #3
    def train(self, exp_schedule, lr_schedule):
        """
        Performs training of Q

        Args:
            exp_schedule: Exploration instance s.t.
                exp_schedule.get_action(best_action) returns an action
            lr_schedule: Schedule for learning rate
        """

        # initialize replay buffer and variables
        if not self.config.batch:
            replay_buffer = ReplayBuffer(
                self.config.buffer_size, self.config.state_history
            )
        else:
            self.logger.info(
                'Loading replay buffer from {}'.format(self.config.buffer_path)
            )
            replay_buffer = ReplayBuffer.load(self.config.buffer_path)
            self.logger.info(
                'Loaded buffer with {} observations ({} currently in buffer)'.format(
                    len(replay_buffer.obs), replay_buffer.num_in_buffer
                )
            )

        rewards = deque(maxlen=self.config.num_episodes_test)
        max_q_values = deque(maxlen=1000)
        q_values = deque(maxlen=1000)
        episode_lengths = deque(maxlen=1000)
        max_episode_length = 0
        self.init_averages()

        t = last_eval = last_record = 0  # step counters for evaluation/recording cadence
        scores_eval = []  # evaluation scores collected during training
        scores_eval.append(self.evaluate())

        prog = Progbar(target=self.config.nsteps_train)

        # interact with environment
        while t < self.config.nsteps_train:
            total_reward = 0

            if not self.config.batch:
                state = self.env.reset()

            episode_step = 0
            avg_episode_length = (
                np.nan if len(episode_lengths) == 0 else np.mean(episode_lengths)
            )

            while True:
                t += 1
                episode_step += 1
                last_eval += 1
                last_record += 1
                if self.config.render_train:
                    self.env.render()

                if not self.config.batch:
                    get_action = functools.partial(
                        exp_schedule.get_action,
                        episode_num=len(episode_lengths),
                        episode_step=episode_step,
                        avg_episode_length=avg_episode_length
                    )
                    state, reward, done, _q_values = self.interact(
                        replay_buffer, state, get_action
                    )
                else:
                    reward = 0
                    done = True
                    _q_values = [0]

                # store q values
                max_q_values.append(max(_q_values))
                q_values.extend(list(_q_values))

                # perform a training step
                loss_eval, grad_eval = self.train_step(
                    t, replay_buffer, lr_schedule.epsilon
                )

                # logging
                learning = (t > self.config.learning_start)
                learning_and_logging = (
                    learning and
                    (t % self.config.log_freq == 0) and
                    (t % self.config.learning_freq == 0)
                )
                if learning_and_logging:
                    self.update_averages(
                        rewards, max_q_values, q_values,
                        scores_eval, episode_lengths, max_episode_length
                    )
                    exp_schedule.update(t)
                    lr_schedule.update(t)
                    if len(rewards) > 0:
                        if self.config.batch:
                            exact = [
                                ("Loss", loss_eval),
                                ("Grads", grad_eval),
                                ("lr", lr_schedule.epsilon),
                            ]
                        else:
                            exact = [
                                ("Loss", loss_eval),
                                ("Avg_R", self.avg_reward),
                                ("Max_R", np.max(rewards)),
                                ("eps", exp_schedule.epsilon),
                                ("Grads", grad_eval),
                                ("Max_Q", self.max_q),
                                ("lr", lr_schedule.epsilon),
                                ("avg_ep_len", avg_episode_length)
                            ]

                        prog.update(t + 1, exact=exact)

                elif not learning and (t % self.config.log_freq == 0):
                    sys.stdout.write(
                        "\rPopulating the memory {}/{}...".format(
                            t, self.config.learning_start
                        )
                    )
                    sys.stdout.flush()

                # count reward
                total_reward += reward
                if done or t >= self.config.nsteps_train:
                    episode_lengths.append(episode_step)
                    if episode_step > max_episode_length:
                        max_episode_length = episode_step

                        # retrain the clusters every time the max episode
                        # length changes
                        if hasattr(self, 'reset_counts'):
                            self.reset_counts(
                                n_clusters=max_episode_length,
                                states=replay_buffer.get_encoded_states(),
                                actions=replay_buffer.get_actions()
                            )

                    break

            # updates to perform at the end of an episode
            rewards.append(total_reward)

            should_evaluate = (
                (t > self.config.learning_start) and
                (last_eval > self.config.eval_freq)
            )
            if should_evaluate:
                # evaluate our policy
                last_eval = 0
                print("")
                scores_eval.append(self.evaluate())

            should_record = (
                (t > self.config.learning_start) and
                self.config.record and
                (last_record > self.config.record_freq)
            )
            if should_record:
                self.logger.info("Recording...")
                last_record = 0
                self.record()

        # training finished: save the model, run a final evaluation and plot the scores
        self.logger.info("- Training done.")
        self.save()
        scores_eval.append(self.evaluate())
        export_plot(scores_eval, "Scores", self.config.plot_output)

        if not self.config.batch:
            # save replay buffer
            self.logger.info(
                'Saving buffer to {}'.format(self.config.buffer_path)
            )
            replay_buffer.save(self.config.buffer_path)
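# The train() loop above only assumes a small interface from its two schedule arguments:
# exp_schedule needs an `epsilon` attribute, an `update(t)` method and a
# `get_action(best_action, ...)` method that tolerates extra keyword arguments
# (episode_num, episode_step, avg_episode_length), while lr_schedule needs `epsilon` and
# `update(t)`. The classes below are a hedged sketch of one way to satisfy that interface
# for a Gym-style env; they are not the original project's implementation.
import numpy as np

class LinearSchedule(object):
    """Linearly anneal epsilon from eps_begin to eps_end over nsteps."""
    def __init__(self, eps_begin, eps_end, nsteps):
        self.epsilon = eps_begin
        self.eps_begin = eps_begin
        self.eps_end = eps_end
        self.nsteps = nsteps

    def update(self, t):
        # interpolate linearly, then stay at eps_end once t exceeds nsteps
        frac = min(float(t) / self.nsteps, 1.0)
        self.epsilon = self.eps_begin + frac * (self.eps_end - self.eps_begin)


class LinearExploration(LinearSchedule):
    """Epsilon-greedy action selection on top of the linear schedule."""
    def __init__(self, env, eps_begin, eps_end, nsteps):
        super().__init__(eps_begin, eps_end, nsteps)
        self.env = env

    def get_action(self, best_action, **kwargs):
        # extra keyword arguments (episode_num, episode_step, ...) are accepted
        # but ignored in this sketch
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return best_action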