Example #1
class Agent(object):
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, random_seed, config):
        """
        Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            random_seed (int): random seed
            config (dict) : dictionary of hyper-parameters
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(random_seed)
        self.config = config

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size,
                                 action_size,
                                 random_seed,
                                 fc_units=[512, 256]).to(device)
        self.actor_target = Actor(state_size,
                                  action_size,
                                  random_seed,
                                  fc_units=[512, 256]).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=self.config["LR_ACTOR"])

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size,
                                   action_size,
                                   random_seed,
                                   fcs_units=[512],
                                   fc_units=[256]).to(device)
        self.critic_target = Critic(state_size,
                                    action_size,
                                    random_seed,
                                    fcs_units=[512],
                                    fc_units=[256]).to(device)
        self.critic_optimizer = optim.Adam(
            self.critic_local.parameters(),
            lr=self.config["LR_CRITIC"],
            weight_decay=self.config["WEIGHT_DECAY"])

        self.hard_copy_weights(self.actor_target, self.actor_local)
        self.hard_copy_weights(self.critic_target, self.critic_local)

        # Noise process
        self.noise = OUNoise(action_size, random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, self.config["BUFFER_SIZE"],
                                   self.config["BATCH_SIZE"], random_seed)

    def hard_copy_weights(self, target, source):
        # Copy weights from source to target network (part of initialization)
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory, and use random sample from buffer to learn.
        self.memory.add(state, action, reward, next_state, done)

        # If enough samples are available in memory, train models
        if len(self.memory) > self.config["BATCH_SIZE"]:
            experiences = self.memory.sample()
            self.learn(experiences, self.config["GAMMA"])

    def act(self, state, add_noise=True):
        # Returns noisy actions for given state as per current policy.
        state = torch.from_numpy(state).float().to(device)

        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()

        if add_noise:
            action += self.noise.sample()

        return np.clip(action, -1, 1)

    def reset(self):
        # Reset Noise Generator
        self.noise.reset()

    def learn(self, experiences, gamma, t_step=1, update_every=1):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
            t_step (int) : total step count
            update_every (int) : update the target networks once every n steps
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        # Clip gradients of critic
        # torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 2)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        # Clip gradients of actor
        # torch.nn.utils.clip_grad_norm_(self.actor_local.parameters(), 1)
        self.actor_optimizer.step()

        if t_step % update_every == 0:
            # ----------------------- update target networks ----------------------- #
            self.soft_update(self.critic_local, self.critic_target,
                             self.config["TAU"])
            self.soft_update(self.actor_local, self.actor_target,
                             self.config["TAU"])

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def save(self, name="agent_1_"):
        # Save actor and critic models
        torch.save(self.actor_local.state_dict(),
                   f"{name}_actor_checkpoint_actor.pth")
        torch.save(self.critic_local.state_dict(),
                   f"{name}_critic_checkpoint_critic.pth")
Example #2
class Agent(AgentABC):
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """
        Initialize a DDPG Agent object.
            :param state_size (int): dimension of each state
            :param action_size (int): dimension of each action
            :param num_agents (int): number of agents in the environment that use DDPG
            :param random_seed (int): random seed
        """
        super().__init__(state_size, action_size, num_agents, random_seed)
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(random_seed)

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process for each agent
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

        # debug of the MSE critic loss
        self.mse_error_list = []

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        for agent in range(self.num_agents):
            self.memory.add(states[agent, :], actions[agent, :],
                            rewards[agent], next_states[agent, :],
                            dones[agent])

        # Learn, if enough samples are available in memory
        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences)
            self.debug_loss = np.mean(self.mse_error_list)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        acts = np.zeros((self.num_agents, self.action_size))
        self.actor_local.eval()
        with torch.no_grad():
            for agent in range(self.num_agents):
                acts[agent, :] = self.actor_local(
                    state[agent, :]).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            noise = self.noise.sample()
            acts += noise
        return np.clip(acts, -1, 1)

    def reset(self):
        """ see abstract class """
        super().reset()
        self.noise.reset()
        self.mse_error_list = []

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards.view(BATCH_SIZE,
                                 -1) + (GAMMA * Q_targets_next *
                                        (1 - dones).view(BATCH_SIZE, -1))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.mse_error_list.append(critic_loss.detach().cpu().numpy())
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def load_weights(self, directory_path):
        """ see abstract class """
        super().load_weights(directory_path)
        self.actor_target.load_state_dict(
            torch.load(os.path.join(directory_path, an_filename),
                       map_location=device))
        self.critic_target.load_state_dict(
            torch.load(os.path.join(directory_path, cn_filename),
                       map_location=device))
        self.actor_local.load_state_dict(
            torch.load(os.path.join(directory_path, an_filename),
                       map_location=device))
        self.critic_local.load_state_dict(
            torch.load(os.path.join(directory_path, cn_filename),
                       map_location=device))

    def save_weights(self, directory_path):
        """ see abstract class """
        super().save_weights(directory_path)
        torch.save(self.actor_local.state_dict(),
                   os.path.join(directory_path, an_filename))
        torch.save(self.critic_local.state_dict(),
                   os.path.join(directory_path, cn_filename))

    def save_mem(self, directory_path):
        """ see abstract class """
        super().save_mem(directory_path)
        self.memory.save(os.path.join(directory_path, "ddpg_memory"))

    def load_mem(self, directory_path):
        """ see abstract class """
        super().load_mem(directory_path)
        self.memory.load(os.path.join(directory_path, "ddpg_memory"))
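The ReplayBuffer these agents construct is not shown in the snippets. Below is a minimal sketch consistent with the add()/sample()/len() calls above; the namedtuple layout, device handling, and numpy-to-tensor conversion are assumptions (a variant that stores whole multi-agent transitions, as in the next example, would stack rather than vstack the fields).

import random
from collections import deque, namedtuple

import numpy as np
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer of experience tuples (sketch of the helper assumed above)."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience", ["state", "action", "reward", "next_state", "done"])
        random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        """Uniformly sample a batch of experiences and return them as device tensors."""
        experiences = random.sample(self.memory, k=self.batch_size)
        states = torch.from_numpy(np.vstack([e.state for e in experiences])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).float().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(device)
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)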
Example #3
class Agent(AgentABC):
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """Initialize an MADDPG Agent object.
        Params
        ======
            :param state_size: dimension of each state
            :param action_size: dimension of each action
            :param num_agents: number of inner agents
            :param random_seed: random seed
        """
        super().__init__(state_size, action_size, num_agents, random_seed)
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(random_seed)

        self.actors_local = []
        self.actors_target = []
        self.actor_optimizers = []
        self.critics_local = []
        self.critics_target = []
        self.critic_optimizers = []
        for i in range(num_agents):
            # Actor Network (w/ Target Network)
            self.actors_local.append(
                Actor(state_size, action_size, random_seed).to(device))
            self.actors_target.append(
                Actor(state_size, action_size, random_seed).to(device))
            self.actor_optimizers.append(
                optim.Adam(self.actors_local[i].parameters(), lr=LR_ACTOR))
            # Critic Network (w/ Target Network)
            self.critics_local.append(
                Critic(num_agents * state_size, num_agents * action_size,
                       random_seed).to(device))
            self.critics_target.append(
                Critic(num_agents * state_size, num_agents * action_size,
                       random_seed).to(device))
            self.critic_optimizers.append(
                optim.Adam(self.critics_local[i].parameters(),
                           lr=LR_CRITIC,
                           weight_decay=WEIGHT_DECAY))

        # Noise process for each agent
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

        # debugging variables
        self.step_count = 0
        self.mse_error_list = []

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        self.memory.add(states, actions, rewards, next_states, dones)

        # Learn, if enough samples are available in memory
        # in order to add some stability to the learning, we don't modify weights every turn.
        self.step_count += 1
        if (self.step_count %
                UPDATE_EVERY) == 0:  # learn every #UPDATE_EVERY steps
            for i in range(NUM_UPDATES):  # update #NUM_UPDATES times
                if len(self.memory) > 1000:
                    experiences = self.memory.sample()
                    self.learn(experiences)
                    self.debug_loss = np.mean(self.mse_error_list)
            self.update_target_networks()

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        acts = np.zeros((self.num_agents, self.action_size))
        for agent in range(self.num_agents):
            self.actors_local[agent].eval()
            with torch.no_grad():
                acts[agent, :] = self.actors_local[agent](
                    state[agent, :]).cpu().data.numpy()
            self.actors_local[agent].train()
        if add_noise:
            acts += self.noise.sample()
        return np.clip(acts, -1, 1)

    def reset(self):
        """ see abstract class """
        super().reset()
        self.noise.reset()
        self.mse_error_list = []

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_full_state, actors_target(next_partial_state) )
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
        """
        states_batched, actions_batched, rewards, next_states_batched, dones = experiences
        states_concated = states_batched.view(
            [BATCH_SIZE, self.num_agents * self.state_size])
        next_states_concated = next_states_batched.view(
            [BATCH_SIZE, self.num_agents * self.state_size])
        actions_concated = actions_batched.view(
            [BATCH_SIZE, self.num_agents * self.action_size])

        for agent in range(self.num_agents):
            actions_next_batched = [
                self.actors_target[i](next_states_batched[:, i, :])
                for i in range(self.num_agents)
            ]
            actions_next_whole = torch.cat(actions_next_batched, 1)
            # ---------------------------- update critic ---------------------------- #
            # Get predicted next-state actions and Q values from target models
            q_targets_next = self.critics_target[agent](next_states_concated,
                                                        actions_next_whole)
            # Compute Q targets for current states (y_i)
            q_targets = rewards[:, agent].view(
                BATCH_SIZE, -1) + (GAMMA * q_targets_next *
                                   (1 - dones[:, agent].view(BATCH_SIZE, -1)))
            # Compute critic loss
            q_expected = self.critics_local[agent](states_concated,
                                                   actions_concated)
            critic_loss = F.mse_loss(q_expected, q_targets)
            # Minimize the loss
            self.critic_optimizers[agent].zero_grad()
            critic_loss.backward()
            self.critic_optimizers[agent].step()
            # save the error for statistics
            self.mse_error_list.append(critic_loss.detach().cpu().numpy())

            # ---------------------------- update actor ---------------------------- #
            action_i = self.actors_local[agent](states_batched[:, agent, :])
            actions_pred = actions_batched.clone()
            actions_pred[:, agent, :] = action_i
            actions_pred_whole = actions_pred.view(BATCH_SIZE, -1)
            # Compute actor loss
            actor_loss = -self.critics_local[agent](states_concated,
                                                    actions_pred_whole).mean()
            # Minimize the loss
            self.actor_optimizers[agent].zero_grad()
            actor_loss.backward()
            self.actor_optimizers[agent].step()

    def update_target_networks(self):
        # ----------------------- update target networks ----------------------- #
        for agent in range(self.num_agents):
            self.soft_update(self.critics_local[agent],
                             self.critics_target[agent], TAU)
            self.soft_update(self.actors_local[agent],
                             self.actors_target[agent], TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def load_weights(self, directory_path):
        """ see abstract class """
        super().load_weights(directory_path)
        actor_weights = os.path.join(directory_path, an_filename)
        critic_weights = os.path.join(directory_path, cn_filename)
        for agent in range(self.num_agents):
            self.actors_target[agent].load_state_dict(
                torch.load(actor_weights + "_" + str(agent),
                           map_location=device))
            self.critics_target[agent].load_state_dict(
                torch.load(critic_weights + "_" + str(agent),
                           map_location=device))
            self.actors_local[agent].load_state_dict(
                torch.load(actor_weights + "_" + str(agent),
                           map_location=device))
            self.critics_local[agent].load_state_dict(
                torch.load(critic_weights + "_" + str(agent),
                           map_location=device))

    def save_weights(self, directory_path):
        """ see abstract class """
        super().save_weights(directory_path)
        actor_weights = os.path.join(directory_path, an_filename)
        critic_weights = os.path.join(directory_path, cn_filename)
        for agent in range(self.num_agents):
            torch.save(self.actors_local[agent].state_dict(),
                       actor_weights + "_" + str(agent))
            torch.save(self.critics_local[agent].state_dict(),
                       critic_weights + "_" + str(agent))

    def save_mem(self, directory_path):
        """ see abstract class """
        super().save_mem(directory_path)
        self.memory.save(os.path.join(directory_path, memory_filename))

    def load_mem(self, directory_path):
        """ see abstract class """
        super().load_mem(directory_path)
        self.memory.load(os.path.join(directory_path, memory_filename))
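The centralized critics in the MADDPG agent above take the concatenation of every agent's observation and action. The view() bookkeeping used in learn() can be checked in isolation; the sizes below (2 agents, 24-dimensional observations, 2-dimensional actions) are illustrative assumptions.

import torch

batch_size, num_agents, state_size, action_size = 128, 2, 24, 2

states_batched = torch.randn(batch_size, num_agents, state_size)
actions_batched = torch.randn(batch_size, num_agents, action_size)

# Flatten the per-agent dimension so each critic sees the full joint state/action.
states_concated = states_batched.view(batch_size, num_agents * state_size)
actions_concated = actions_batched.view(batch_size, num_agents * action_size)
print(states_concated.shape)   # torch.Size([128, 48])
print(actions_concated.shape)  # torch.Size([128, 4])

# Replacing only agent 0's action with the current policy's output, as the actor
# update does, keeps the other agents' actions fixed in the critic input.
new_action_0 = torch.randn(batch_size, action_size)
actions_pred = actions_batched.clone()
actions_pred[:, 0, :] = new_action_0
actions_pred_whole = actions_pred.view(batch_size, -1)  # (128, 4)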
Exemple #4
0
class Agent():
    """Interacts with and learns from the environment."""
    
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            num_agents (int): number of agents in the environment
            random_seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(random_seed)

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size, random_seed).to(device)
        self.actor_target = Actor(state_size, action_size, random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR)
        for target_param, local_param in zip(self.actor_target.parameters(), self.actor_local.parameters()):
            target_param.data.copy_(local_param.data)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size, random_seed).to(device)
        self.critic_target = Critic(state_size, action_size, random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY)
        for target_param, local_param in zip(self.critic_target.parameters(), self.critic_local.parameters()):
            target_param.data.copy_(local_param.data)

        # Noise process
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, random_seed)
        
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    
    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        for i in range(self.num_agents):
            self.memory.add(states[i], actions[i], rewards[i], next_states[i], dones[i])
        
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # Learn, if enough samples are available in memory
            if len(self.memory) > BATCH_SIZE:
                for _ in range(NUM_LEARN):
                    experiences = self.memory.sample()
                    self.learn(experiences, GAMMA)

    def act(self, states, eps=0, add_noise=True):
        """Returns actions for given state as per current policy."""
        states = torch.from_numpy(states).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            actions = self.actor_local(states).cpu().data.numpy()
        self.actor_local.train()
        
        if add_noise and random.random() < eps:
            actions += self.noise.sample()
        return np.clip(actions, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        # Gradient clipping prevents a single update from moving the parameters extremely far
        # when there are cliff-like structures in the highly non-linear function
        # that the neural network approximates. {Goodfellow, Deep Learning, Section 10.11.1}
        # {Mikolov, 2012; Pascanu et al., 2013}
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1.0)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)                     

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
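None of the snippets include the OUNoise helper they sample exploration noise from. Below is a minimal Ornstein-Uhlenbeck sketch with the reset()/sample() interface used above; the constructor signature and the default mu, theta, sigma values are assumptions.

import copy
import random

import numpy as np


class OUNoise:
    """Ornstein-Uhlenbeck process (sketch of the helper assumed by the agents above)."""

    def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.2):
        # `size` may be an int or a (num_agents, action_size) tuple, as used above.
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        random.seed(seed)
        np.random.seed(seed)
        self.reset()

    def reset(self):
        """Reset the internal state to the mean."""
        self.state = copy.copy(self.mu)

    def sample(self):
        """Update the internal state and return it as a noise sample."""
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * np.random.standard_normal(self.mu.shape)
        self.state = x + dx
        return self.state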
Example #5
    def __init__(self,
                 algo_type="MADDPG",
                 act_space=None,
                 obs_space=None,
                 rnn_policy=False,
                 rnn_critic=False,
                 hidden_dim=64,
                 lr=0.01,
                 env_obs_space=None,
                 env_act_space=None):
        """
        Inputs:
            act_space: single agent action space (single space or Dict)
            obs_space: single agent observation space (single space Dict)
        """
        self.algo_type = algo_type
        self.act_space = act_space
        self.obs_space = obs_space

        # continuous or discrete action (only look at `move` action, assume
        # move and comm space both discrete or continuous)
        move_space = act_space.spaces["move"] if isinstance(act_space, Dict) else act_space
        discrete_action = isinstance(move_space, Discrete)
        self.discrete_action = discrete_action

        # Exploration noise
        if not discrete_action:
            # `move`, `comm` share same continuous noise source
            self.exploration = OUNoise(self.get_shape(act_space))
        else:
            self.exploration = 0.3  # epsilon for eps-greedy

        # Policy (supports multiple outputs)
        self.rnn_policy = rnn_policy
        self.policy_hidden_states = None

        num_in_pol = obs_space.shape[0]
        if isinstance(act_space, Dict):
            # hard specify now, could generalize later
            num_out_pol = {
                "move": self.get_shape(act_space, "move"),
                "comm": self.get_shape(act_space, "comm")
            }
        else:
            num_out_pol = self.get_shape(act_space)

        self.policy = Policy(num_in_pol,
                             num_out_pol,
                             hidden_dim=hidden_dim,
                             constrain_out=True,
                             discrete_action=discrete_action,
                             rnn_policy=rnn_policy)
        self.target_policy = Policy(num_in_pol,
                                    num_out_pol,
                                    hidden_dim=hidden_dim,
                                    constrain_out=True,
                                    discrete_action=discrete_action,
                                    rnn_policy=rnn_policy)
        hard_update(self.target_policy, self.policy)

        # Critic
        self.rnn_critic = rnn_critic
        self.critic_hidden_states = None

        if algo_type == "MADDPG":
            num_in_critic = 0
            for oobsp in env_obs_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env_act_space:
                # feed all acts to centralized critic
                num_in_critic += self.get_shape(oacsp)
        else:  # only DDPG, local critic
            num_in_critic = obs_space.shape[0] + self.get_shape(act_space)

        critic_net_fn = RecurrentNetwork if rnn_critic else MLPNetwork
        self.critic = critic_net_fn(num_in_critic,
                                    1,
                                    hidden_dim=hidden_dim,
                                    constrain_out=False)
        self.target_critic = critic_net_fn(num_in_critic,
                                           1,
                                           hidden_dim=hidden_dim,
                                           constrain_out=False)
        hard_update(self.target_critic, self.critic)

        # Optimizers
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
Example #6
class Agent():
    def __init__(self, num_agents, state_size, action_size, opts):
        self.num_agents = num_agents
        self.state_size = state_size
        self.action_size = action_size
        self.opts = opts

        # Actor Network
        self.actor_local = ActorNet(state_size,
                                    action_size,
                                    fc1_units=opts.a_fc1,
                                    fc2_units=opts.a_fc2).to(opts.device)
        self.actor_target = ActorNet(state_size,
                                     action_size,
                                     fc1_units=opts.a_fc1,
                                     fc2_units=opts.a_fc2).to(opts.device)
        self.actor_optimizer = torch.optim.Adam(self.actor_local.parameters(),
                                                lr=opts.actor_lr)

        # Critic Network
        self.critic_local = CriticNet(state_size,
                                      action_size,
                                      fc1_units=opts.c_fc1,
                                      fc2_units=opts.c_fc2).to(opts.device)
        self.critic_target = CriticNet(state_size,
                                       action_size,
                                       fc1_units=opts.c_fc1,
                                       fc2_units=opts.c_fc2).to(opts.device)
        self.critic_optimizer = torch.optim.Adam(
            self.critic_local.parameters(),
            lr=opts.critic_lr,
            weight_decay=opts.critic_weight_decay)

        # Noise process
        self.noise = OUNoise((num_agents, action_size), opts.random_seed)
        self.step_idx = 0

        # Replay memory
        self.memory = ReplayBuffer(action_size, opts.buffer_size,
                                   opts.batch_size, opts.random_seed,
                                   opts.device)

    def step(self, state, action, reward, next_state, done):
        for i in range(self.num_agents):
            self.memory.add(state[i, :], action[i, :], reward[i],
                            next_state[i, :], done[i])

        self.step_idx += 1
        is_learn_iteration = (self.step_idx % self.opts.learn_every) == 0
        is_update_iteration = (self.step_idx % self.opts.update_every) == 0

        if len(self.memory) > self.opts.batch_size:
            if is_learn_iteration:
                experiences = self.memory.sample()
                self.learn(experiences, self.opts.gamma)

            if is_update_iteration:
                soft_update(self.critic_local, self.critic_target,
                            self.opts.tau)
                soft_update(self.actor_local, self.actor_target, self.opts.tau)

    def act(self, state):
        state = torch.from_numpy(state).float().to(self.opts.device)

        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()

        action += self.noise.sample()
        return np.clip(action, self.opts.minimum_action_value,
                       self.opts.maximum_action_value)

    def save(self):
        torch.save(self.critic_local.state_dict(),
                   self.opts.output_data_path + "critic_local.pth")
        torch.save(self.critic_target.state_dict(),
                   self.opts.output_data_path + "critic_target.pth")
        torch.save(self.actor_local.state_dict(),
                   self.opts.output_data_path + "actor_local.pth")
        torch.save(self.actor_target.state_dict(),
                   self.opts.output_data_path + "actor_target.pth")

    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones = experiences

        states = tensor(states, self.opts.device)
        actions = tensor(actions, self.opts.device)
        rewards = tensor(rewards, self.opts.device)
        next_states = tensor(next_states, self.opts.device)
        mask = tensor(1 - dones, self.opts.device)

        # Update critic
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + (gamma * Q_targets_next * mask)

        # Compute & minimize critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update actor
        actions_pred = self.actor_local(states)

        # Compute & minimize critic loss
        actor_loss = -self.critic_local(states, actions_pred).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
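Example #6 calls two module-level helpers, tensor() and soft_update(), that are not shown. A minimal sketch consistent with how they are called; their exact implementations here are assumptions.

import numpy as np
import torch


def tensor(x, device):
    """Convert a numpy array (or an existing tensor) to a float tensor on `device`."""
    if torch.is_tensor(x):
        return x.float().to(device)
    return torch.from_numpy(np.asarray(x)).float().to(device)


def soft_update(local_model, target_model, tau):
    """Polyak averaging: theta_target <- tau * theta_local + (1 - tau) * theta_target."""
    for target_param, local_param in zip(target_model.parameters(),
                                         local_model.parameters()):
        target_param.data.copy_(tau * local_param.data +
                                (1.0 - tau) * target_param.data)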
Example #7
class DDPGAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise)
    """
    def __init__(self,
                 algo_type="MADDPG",
                 act_space=None,
                 obs_space=None,
                 rnn_policy=False,
                 rnn_critic=False,
                 hidden_dim=64,
                 lr=0.01,
                 env_obs_space=None,
                 env_act_space=None):
        """
        Inputs:
            act_space: single agent action space (single space or Dict)
            obs_space: single agent observation space (single space Dict)
        """
        self.algo_type = algo_type
        self.act_space = act_space
        self.obs_space = obs_space

        # continuous or discrete action (only look at `move` action, assume
        # move and comm space both discrete or continuous)
        move_space = act_space.spaces["move"] if isinstance(act_space, Dict) else act_space
        discrete_action = isinstance(move_space, Discrete)
        self.discrete_action = discrete_action

        # Exploration noise
        if not discrete_action:
            # `move`, `comm` share same continuous noise source
            self.exploration = OUNoise(self.get_shape(act_space))
        else:
            self.exploration = 0.3  # epsilon for eps-greedy

        # Policy (supports multiple outputs)
        self.rnn_policy = rnn_policy
        self.policy_hidden_states = None

        num_in_pol = obs_space.shape[0]
        if isinstance(act_space, Dict):
            # hard specify now, could generalize later
            num_out_pol = {
                "move": self.get_shape(act_space, "move"),
                "comm": self.get_shape(act_space, "comm")
            }
        else:
            num_out_pol = self.get_shape(act_space)

        self.policy = Policy(num_in_pol,
                             num_out_pol,
                             hidden_dim=hidden_dim,
                             constrain_out=True,
                             discrete_action=discrete_action,
                             rnn_policy=rnn_policy)
        self.target_policy = Policy(num_in_pol,
                                    num_out_pol,
                                    hidden_dim=hidden_dim,
                                    constrain_out=True,
                                    discrete_action=discrete_action,
                                    rnn_policy=rnn_policy)
        hard_update(self.target_policy, self.policy)

        # Critic
        self.rnn_critic = rnn_critic
        self.critic_hidden_states = None

        if algo_type == "MADDPG":
            num_in_critic = 0
            for oobsp in env_obs_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env_act_space:
                # feed all acts to centralized critic
                num_in_critic += self.get_shape(oacsp)
        else:  # only DDPG, local critic
            num_in_critic = obs_space.shape[0] + self.get_shape(act_space)

        critic_net_fn = RecurrentNetwork if rnn_critic else MLPNetwork
        self.critic = critic_net_fn(num_in_critic,
                                    1,
                                    hidden_dim=hidden_dim,
                                    constrain_out=False)
        self.target_critic = critic_net_fn(num_in_critic,
                                           1,
                                           hidden_dim=hidden_dim,
                                           constrain_out=False)
        hard_update(self.target_critic, self.critic)

        # Optimizers
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)

    def get_shape(self, x, key=None):
        """ func to infer action output shape """
        if isinstance(x, Dict):
            if key is None:  # sum of action space dims
                return sum([
                    x[k].n if self.discrete_action else x[k].shape[0]
                    for k in x.spaces
                ])
            elif key in x.spaces:
                return x[key].n if self.discrete_action else x[key].shape[0]
            else:  # key not in action spaces
                return 0
        else:
            return x.n if self.discrete_action else x.shape[0]

    def reset_noise(self):
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def init_hidden(self, batch_size):
        # (1,H) -> (B,H)
        # policy.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)
        if self.rnn_policy:
            self.policy_hidden_states = self.policy.init_hidden().expand(
                batch_size, -1)
        if self.rnn_critic:
            self.critic_hidden_states = self.critic.init_hidden().expand(
                batch_size, -1)

    def compute_q_val(self, vf_in, h=None, target=False):
        """ training critic forward with specified policy 
        Arguments:
            agent_i: index to agent; critic: critic network to agent; vf_in: (B,T,K);
            bs: batch size; ts: length of episode; target: if use target network
        Returns:
            q: (B*T,1)
        """
        bs, ts, _ = vf_in.shape
        critic = self.target_critic if target else self.critic

        if self.rnn_critic:
            q = []  # (B,1)*T
            h_t = self.critic_hidden_states.clone()  # (B,H)
            for t in range(ts):
                q_t, h_t = critic(vf_in[:, t], h_t)
                q.append(q_t)
            q = torch.stack(q, 0).permute(1, 0, 2)  # (T,B,1) -> (B,T,1)
            q = q.reshape(bs * ts, -1)  # (B*T,1)
        else:
            # (B,T,D) -> (B*T,1)
            q = critic(vf_in.reshape(bs * ts, -1))
        return q

    def compute_action(self, obs, h=None, target=False, requires_grad=True):
        """ traininsg actor forward with specified policy 
        concat all actions to be fed in critics
        Arguments:
            agent_i: index to agent; pi: policy to agent; obs: (B,T,O);
            bs: batch size; ts: length of episode; target: if use target network
        Returns:
            act: dict of (B,T,A) 
        """
        def _soft_act(x):  # x: (B,A)
            if not self.discrete_action:
                return x
            if requires_grad:
                return gumbel_softmax(x, hard=True)
            else:
                return onehot_from_logits(x)

        bs, ts, _ = obs.shape
        pi = self.target_policy if target else self.policy

        if self.rnn_policy:
            act = []  # [(B,sum(A_k))]*T
            h_t = self.policy_hidden_states.clone()  # (B,H)
            for t in range(ts):
                act_t, h_t = pi(obs[:, t], h_t)  # act_t is dict!!
                cat_act = torch.concat(
                    [_soft_act(a) for k, a in act_t.items()],
                    -1)  # (B,sum(A_k))
                act.append(cat_act)
            act = torch.stack(act, 0).permute(1, 0, 2)  # (B,T,sum(A_k))
        else:
            stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
            act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
            act = torch.concat(
                [_soft_act(a).reshape(bs, ts, -1) for k, a in act.items()],
                -1)  # (B,T,sum(A_k))
        return act

    def step(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        equivalent to `act` or `compute_actions`
        Arguments:
            obs: (B,O)
            explore: Whether or not to add exploration noise
        Returns:
            action: dict of actions for this agent, (B,A)
        """
        with torch.no_grad():
            action, hidden_states = self.policy(obs, self.policy_hidden_states)
            self.policy_hidden_states = hidden_states  # if MLP, stays at its default of None

            if self.discrete_action:
                for k in action:
                    if explore:
                        action[k] = gumbel_softmax(action[k], hard=True)
                    else:
                        action[k] = onehot_from_logits(action[k])
            else:  # continuous action
                idx = 0
                noise = Variable(Tensor(self.exploration.noise()),
                                 requires_grad=False)
                for k in action:
                    if explore:
                        dim = action[k].shape[-1]
                        action[k] += noise[idx:idx + dim]
                        idx += dim
                    action[k] = action[k].clamp(-1, 1)
        return action

    def get_params(self):
        return {
            'policy': self.policy.state_dict(),
            'critic': self.critic.state_dict(),
            'target_policy': self.target_policy.state_dict(),
            'target_critic': self.target_critic.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict()
        }

    def load_params(self, params):
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
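The DDPGAgent above relies on hard_update, onehot_from_logits, and gumbel_softmax without defining them. A minimal sketch of these standard helpers, consistent with how they are called; the temperature and the straight-through formulation are assumptions.

import torch
import torch.nn.functional as F


def hard_update(target, source):
    """Copy parameters from the source network into the target network."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def onehot_from_logits(logits):
    """Return a hard one-hot vector at the argmax of the logits."""
    return (logits == logits.max(-1, keepdim=True)[0]).float()


def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1)."""
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)


def gumbel_softmax(logits, temperature=1.0, hard=False):
    """Differentiable sample from a categorical distribution (straight-through if hard)."""
    y = F.softmax((logits + sample_gumbel(logits.shape).to(logits.device)) / temperature, dim=-1)
    if hard:
        y_hard = onehot_from_logits(y)
        # Straight-through estimator: forward pass uses the hard sample,
        # gradients flow through the soft sample.
        y = (y_hard - y).detach() + y
    return y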
Example #8
if __name__ == '__main__':

    with tf.Session() as sess:

        env = gym.make('LunarLanderContinuous-v2')

        env.seed(0)
        np.random.seed(0)
        tf.set_random_seed(0)

        ep = 2000
        tau = 0.001
        gamma = 0.99
        min_batch = 65
        actor_lr = 0.00005
        critic_lr = 0.0005
        buffer_size = 1000000

        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        action_bound = env.action_space.high

        actor_noise = OUNoise(mu=np.zeros(action_dim))
        actor = ActorNetwork(sess, state_dim, action_dim, action_bound, actor_lr, tau, min_batch)
        critic = CriticNetwork(sess, state_dim, action_dim, critic_lr, tau, gamma, actor.get_num_trainable_vars())
        scores = train(sess, env, actor, critic, actor_noise, buffer_size, min_batch, ep)

        plt.plot([i + 1 for i in range(0, len(scores), 4)], scores[::4])
        plt.show()

Example #9
    def __init__(self, algo_type="MADDPG", act_space=None, obs_space=None, 
                rnn_policy=False, rnn_critic=False, hidden_dim=64, lr=0.01,
                norm_in=False, constrain_out=False, thought_dim=64,
                env_obs_space=None, env_act_space=None, **kwargs):
        """
        Inputs:
            act_space: single agent action space (single space or Dict)
            obs_space: single agent observation space (single space Dict)
        """
        self.algo_type = algo_type
        self.act_space = act_space 
        self.obs_space = obs_space

        # continuous or discrete action (only look at `move` action, assume
        # move and comm space both discrete or continuous)
        tmp = act_space.spaces["move"] if isinstance(act_space, Dict) else act_space
        self.discrete_action = False if isinstance(tmp, Box) else True 

        # Exploration noise 
        if not self.discrete_action:
            # `move`, `comm` share same continuous noise source
            self.exploration = OUNoise(self.get_shape(act_space))
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        
        # Policy (supports multiple outputs)
        self.rnn_policy = rnn_policy
        self.policy_hidden_states = None 

        num_in_pol = obs_space.shape[0]
        if isinstance(act_space, Dict):
            # hard specify now, could generalize later 
            num_out_pol = {
                "move": self.get_shape(act_space, "move"), 
                "comm": self.get_shape(act_space, "comm")
            }
        else:
            num_out_pol = self.get_shape(act_space)

        # atoc policy 
        policy_kwargs = dict(
            hidden_dim=hidden_dim,
            norm_in=norm_in,
            constrain_out=constrain_out,
            discrete_action=self.discrete_action,
            rnn_policy=rnn_policy,
            thought_dim=thought_dim
        )
        self.policy = ATOCPolicy(num_in_pol, num_out_pol, **policy_kwargs)
        self.target_policy = ATOCPolicy(num_in_pol, num_out_pol, **policy_kwargs)
        hard_update(self.target_policy, self.policy)
        
        # Critic 
        self.rnn_critic = rnn_critic
        self.critic_hidden_states = None 
        
        if algo_type == "MADDPG":
            num_in_critic = 0
            for oobsp in env_obs_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env_act_space:
                # feed all acts to centralized critic
                num_in_critic += self.get_shape(oacsp)
        else:   # only DDPG, local critic 
            num_in_critic = obs_space.shape[0] + self.get_shape(act_space)

        critic_net_fn = RecurrentNetwork if rnn_critic else MLPNetwork
        critic_kwargs = dict(
            hidden_dim=hidden_dim,
            norm_in=norm_in,
            constrain_out=constrain_out
        )
        self.critic = critic_net_fn(num_in_critic, 1, **critic_kwargs)
        self.target_critic = critic_net_fn(num_in_critic, 1, **critic_kwargs)
        hard_update(self.target_critic, self.critic)

        # NOTE: atoc modules 
        # attention unit, MLP (used here) or RNN, output comm probability
        self.thought_dim = thought_dim
        self.attention_unit = nn.Sequential(
            MLPNetwork(thought_dim, 1, hidden_dim=hidden_dim, 
                norm_in=norm_in, constrain_out=False), 
            nn.Sigmoid()
        )
        
        # communication channel, bi-LSTM (used here) or graph 
        self.comm_channel = nn.LSTM(thought_dim, thought_dim, 1, 
            batch_first=False, bidirectional=True)

        # Optimizers 
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        self.attention_unit_optimizer = Adam(self.attention_unit.parameters(), lr=lr)
        self.comm_channel_optimizer = Adam(self.comm_channel.parameters(), lr=lr)
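The ATOC-specific modules above are only constructed, never exercised in the snippet. A small, self-contained shape check of the communication pipeline they imply; a plain nn.Linear stands in for the MLPNetwork attention unit, which is an assumption.

import torch
import torch.nn as nn

thought_dim, num_agents, batch_size = 64, 3, 8

# Attention unit: thought vector -> probability of initiating communication.
attention_unit = nn.Sequential(nn.Linear(thought_dim, 1), nn.Sigmoid())
thoughts = torch.randn(batch_size, thought_dim)
comm_prob = attention_unit(thoughts)            # (8, 1), values in (0, 1)

# Communication channel: bi-LSTM over the agents in a communication group,
# sequence-first layout because batch_first=False above.
comm_channel = nn.LSTM(thought_dim, thought_dim, 1, batch_first=False, bidirectional=True)
group_thoughts = torch.randn(num_agents, batch_size, thought_dim)
integrated, _ = comm_channel(group_thoughts)    # (3, 8, 2 * thought_dim)
print(comm_prob.shape, integrated.shape)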
Example #10
class ATOCAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise)
    """
    def __init__(self, algo_type="MADDPG", act_space=None, obs_space=None, 
                rnn_policy=False, rnn_critic=False, hidden_dim=64, lr=0.01,
                norm_in=False, constrain_out=False, thought_dim=64,
                env_obs_space=None, env_act_space=None, **kwargs):
        """
        Inputs:
            act_space: single agent action space (single space or Dict)
            obs_space: single agent observation space (single space Dict)
        """
        self.algo_type = algo_type
        self.act_space = act_space 
        self.obs_space = obs_space

        # continuous or discrete action (only look at `move` action, assume
        # move and comm space both discrete or continuous)
        tmp = act_space.spaces["move"] if isinstance(act_space, Dict) else act_space
        self.discrete_action = False if isinstance(tmp, Box) else True 

        # Exploration noise 
        if not self.discrete_action:
            # `move`, `comm` share same continuous noise source
            self.exploration = OUNoise(self.get_shape(act_space))
        else:
            self.exploration = 0.3  # epsilon for eps-greedy
        
        # Policy (supports multiple outputs)
        self.rnn_policy = rnn_policy
        self.policy_hidden_states = None 

        num_in_pol = obs_space.shape[0]
        if isinstance(act_space, Dict):
            # hard specify now, could generalize later 
            num_out_pol = {
                "move": self.get_shape(act_space, "move"), 
                "comm": self.get_shape(act_space, "comm")
            }
        else:
            num_out_pol = self.get_shape(act_space)

        # atoc policy 
        policy_kwargs = dict(
            hidden_dim=hidden_dim,
            norm_in=norm_in,
            constrain_out=constrain_out,
            discrete_action=self.discrete_action,
            rnn_policy=rnn_policy,
            thought_dim=thought_dim
        )
        self.policy = ATOCPolicy(num_in_pol, num_out_pol, **policy_kwargs)
        self.target_policy = ATOCPolicy(num_in_pol, num_out_pol, **policy_kwargs)
        hard_update(self.target_policy, self.policy)
        
        # Critic 
        self.rnn_critic = rnn_critic
        self.critic_hidden_states = None 
        
        if algo_type == "MADDPG":
            num_in_critic = 0
            for oobsp in env_obs_space:
                num_in_critic += oobsp.shape[0]
            for oacsp in env_act_space:
                # feed all acts to centralized critic
                num_in_critic += self.get_shape(oacsp)
        else:   # only DDPG, local critic 
            num_in_critic = obs_space.shape[0] + self.get_shape(act_space)

        critic_net_fn = RecurrentNetwork if rnn_critic else MLPNetwork
        critic_kwargs = dict(
            hidden_dim=hidden_dim,
            norm_in=norm_in,
            constrain_out=constrain_out
        )
        self.critic = critic_net_fn(num_in_critic, 1, **critic_kwargs)
        self.target_critic = critic_net_fn(num_in_critic, 1, **critic_kwargs)
        hard_update(self.target_critic, self.critic)

        # NOTE: atoc modules 
        # attention unit, MLP (used here) or RNN, output comm probability
        self.thought_dim = thought_dim
        self.attention_unit = nn.Sequential(
            MLPNetwork(thought_dim, 1, hidden_dim=hidden_dim, 
                norm_in=norm_in, constrain_out=False), 
            nn.Sigmoid()
        )
        
        # communication channel, bi-LSTM (used here) or graph 
        self.comm_channel = nn.LSTM(thought_dim, thought_dim, 1, 
            batch_first=False, bidirectional=True)

        # Optimizers 
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr)
        self.attention_unit_optimizer = Adam(self.attention_unit.parameters(), lr=lr)
        self.comm_channel_optimizer = Adam(self.comm_channel.parameters(), lr=lr)


    def get_shape(self, x, key=None):
        """ func to infer action output shape """
        if isinstance(x, Dict):
            if key is None: # sum of action space dims
                return sum([
                    x.spaces[k].n if self.discrete_action else x.spaces[k].shape[0]
                    for k in x.spaces
                ])
            elif key in x.spaces:
                return x.spaces[key].n if self.discrete_action else x.spaces[key].shape[0]
            else:   # key not in action spaces
                return 0
        else:
            return x.n if self.discrete_action else x.shape[0]
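        # e.g. (assumed spaces): Dict(move=Discrete(5), comm=Discrete(10)) gives 15 in
        # total and 5 for key "move"; a continuous Box with shape (2,) gives 2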
    

    def reset_noise(self):
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale


    def init_hidden(self, batch_size):
        # (1,H) -> (B,H)
        # policy.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  
        if self.rnn_policy:
            self.policy_hidden_states = self.policy.init_hidden().expand(batch_size, -1)  
        if self.rnn_critic:
            self.critic_hidden_states = self.critic.init_hidden().expand(batch_size, -1) 


    def compute_value(self, vf_in, h_critic=None, target=False, truncate_steps=-1):
        """ training critic forward with specified policy 
        Arguments:
            vf_in: (B,T,K)
            target: if use target network
            truncate_steps: number of BPTT steps to truncate if used
        Returns:
            q: (B*T,1)
        """
        bs, ts, _ = vf_in.shape
        critic = self.target_critic if target else self.critic

        if self.rnn_critic:
            if h_critic is None:
                h_t = self.critic_hidden_states.clone() # (B,H)
            else:
                h_t = h_critic  #.clone()

            # rollout 
            q = rnn_forward_sequence(
                critic, vf_in, h_t, truncate_steps=truncate_steps)
            # q = []   # (B,1)*T
            # for t in range(ts):
            #     q_t, h_t = critic(vf_in[:,t], h_t)
            #     q.append(q_t)
            q = torch.stack(q, 0).permute(1,0,2)   # (T,B,1) -> (B,T,1)
            q = q.reshape(bs*ts, -1)  # (B*T,1)
        else:
            # (B,T,D) -> (B*T,1)
            q, _ = critic(vf_in.reshape(bs*ts, -1))
        return q 
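        # e.g. (assumed sizes): vf_in of shape (32, 25, num_in_critic) yields q of shape (800, 1)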


    def _soft_act(self, x, requires_grad=True):    
        """ soften action if discrete, x: (B,A) """
        if not self.discrete_action:
            return x 
        if requires_grad:
            return gumbel_softmax(x, hard=True)
        else:
            return onehot_from_logits(x)
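        # e.g.: with discrete actions, logits of shape (B,A) become a one-hot (B,A);
        # the straight-through Gumbel-Softmax keeps gradients when requires_grad=True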


    def compute_action(self, obs, h_actor=None, target=False, requires_grad=True, truncate_steps=-1):
        """ traininsg actor forward with specified policy 
        concat all actions to be fed in critics
        Arguments:
            obs: (B,T,O)
            target: if use target network
            requires_grad: if use _soft_act to differentiate discrete action
        Returns:
            act: dict of (B,T,A) 
        """
        bs, ts, _ = obs.shape
        pi = self.target_policy if target else self.policy

        if self.rnn_policy:
            if h_actor is None:
                h_t = self.policy_hidden_states.clone() # (B,H)
            else:
                h_t = h_actor   #.clone()

            # rollout 
            seq_logits = rnn_forward_sequence(
                pi, obs, h_t, truncate_steps=truncate_steps)
            # seq_logits = []  
            # for t in range(ts):
            #     act_t, h_t = pi(obs[:,t], h_t)  # act_t is dict (B,A)
            #     seq_logits.append(act_t)

            # soften deterministic output for backprop 
            act = defaultdict(list)
            for act_t in seq_logits:
                for k, a in act_t.items():
                    act[k].append(self._soft_act(a, requires_grad))
            act = {
                k: torch.stack(ac, 0).permute(1,0,2) 
                for k, ac in act.items()
            }   # dict [(B,A)]*T -> dict (B,T,A)
        else:
            stacked_obs = obs.reshape(bs*ts, -1)  # (B*T,O)
            act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
            act = {
                k: self._soft_act(ac, requires_grad).reshape(bs, ts, -1)  
                for k, ac in act.items()
            }   # dict of (B,T,A)
        return act
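        # e.g. (assumed sizes): obs of shape (32, 25, O) yields act["move"] of shape
        # (32, 25, A_move) and act["comm"] of shape (32, 25, A_comm)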

    def step_thought(self, obs):
        """ run the first part of policy to generate thought  
        Arguments:
            obs: (B,O)
        Returns:
            thought: (B,H)
        """
        with torch.no_grad():
            thought, hidden_states = self.policy.get_thought(obs, self.policy_hidden_states)
            self.policy_hidden_states = hidden_states   # for an MLP policy this stays None by default
        return thought 


    def step_action(self, h_i, h_j=None, explore=False):
        """
        Take a step forward in the environment for a minibatch of observations;
        equivalent to `act` or `compute_actions`
        Arguments:
            h_i, h_j: (B,H)
            explore: Whether or not to add exploration noise
        Returns:
            action: dict of actions for this agent, (B,A)
        """
        with torch.no_grad():
            action = self.policy.get_action(h_i, h_j)

            if self.discrete_action:
                for k in action:
                    if explore:
                        action[k] = gumbel_softmax(action[k], hard=True)
                    else:
                        action[k] = onehot_from_logits(action[k])
            else:  # continuous action
                idx = 0 
                noise = Variable(Tensor(self.exploration.noise()),
                                    requires_grad=False)
                for k in action:
                    if explore:
                        dim = action[k].shape[-1]
                        action[k] += noise[idx : idx+dim]
                        idx += dim 
                    action[k] = action[k].clamp(-1, 1)
        return action

    
    def initiate_comm(self, thought_vec, binary=False):
        """ output probability if to initiate communication
            if binary, return 0 or 1s (used in sampling)
        Arguments: 
            - thought_vec: (B,H) 
        Returns:
            - is_comm: (B,1) 
        """
        is_comm = self.attention_unit(thought_vec)
        if binary:
            is_comm = (is_comm > 0.5).long()
        return is_comm


    def integrate_thoughts(self, thought_vecs):
        """ integrate the thoughts of agents in a communication group
            (assumed completion: pass the stacked thoughts through the bi-LSTM channel)
        Arguments:
            - thought_vecs: (group_size, B, H)
        Returns:
            - integrated thoughts: (group_size, B, 2H)
        """
        integrated, _ = self.comm_channel(thought_vecs)
        return integrated
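
# Minimal usage sketch, assuming this agent class is instantiated as `agent`, that
# `obs` is a (B,O) float tensor, and that the thoughts of all agents in a group are
# stacked into a (group_size, B, H) tensor; the shape of `h_j` expected by
# `step_action` is likewise an assumption.
agent.init_hidden(batch_size=obs.shape[0])       # reset recurrent states if rnn_policy/rnn_critic
h_i = agent.step_thought(obs)                    # (B,H) thought vector
is_comm = agent.initiate_comm(h_i, binary=True)  # (B,1) 0/1 decision from the attention unit
if is_comm.sum() > 0:
    group_thoughts = h_i.unsqueeze(0)            # (1,B,H); a real group would stack several agents
    h_j = agent.integrate_thoughts(group_thoughts)[-1]    # (B,2H) integrated thought
else:
    h_j = None
action = agent.step_action(h_i, h_j, explore=True)        # dict of (B,A) actions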
Exemple #11
0
    writer = SummaryWriter()
else:
    writer = None

env = gym.make(args.env_name)
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]
state_rms = RunningMeanStd(state_dim)
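# running per-dimension mean/variance of observations, used below to standardize
# states as (state - mean) / sqrt(var + 1e-8)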

if args.algo == 'ppo':
    agent = PPO(writer, device, state_dim, action_dim, agent_args)
elif args.algo == 'sac':
    agent = SAC(writer, device, state_dim, action_dim, agent_args)
elif args.algo == 'ddpg':
    from utils.noise import OUNoise
    noise = OUNoise(action_dim, 0)
    agent = DDPG(writer, device, state_dim, action_dim, agent_args, noise)

if torch.cuda.is_available() and args.use_cuda:
    agent = agent.cuda()

if args.load != 'no':
    agent.load_state_dict(torch.load("./model_weights/" + args.load))

score_lst = []
state_lst = []

if agent_args.on_policy == True:
    score = 0.0
    state_ = env.reset()
    state = np.clip((state_ - state_rms.mean) / (state_rms.var**0.5 + 1e-8),