Example #1
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size: int, action_size: int, seed: int,
                 n_agent: int):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            n_agent (int): number of agents acting in parallel
        """
        self.state_size = state_size
        self.action_size = action_size
        self.n_agent = n_agent
        self.seed = random.seed(seed)
        self.global_step = 0
        self.update_step = 0

        # Initialize actor and critic local and target networks
        self.actor = Actor(state_size,
                           action_size,
                           seed,
                           ACTOR_NETWORK_LINEAR_SIZES,
                           batch_normalization=ACTOR_BATCH_NORM).to(device)
        self.actor_target = Actor(
            state_size,
            action_size,
            seed,
            ACTOR_NETWORK_LINEAR_SIZES,
            batch_normalization=ACTOR_BATCH_NORM).to(device)
        self.critic = Critic(state_size,
                             action_size,
                             seed,
                             CRITIC_NETWORK_LINEAR_SIZES,
                             batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_second = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_SECOND_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_second_target = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_SECOND_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_target = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=ACTOR_LEARNING_RATE)
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=CRITIC_LEARNING_RATE)
        self.critic_second_optimizer = optim.Adam(
            self.critic_second.parameters(), lr=CRITIC_LEARNING_RATE)
        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = [0] * n_agent
        self.noise = OUNoise(action_size, seed, decay_period=50)

        # Copy parameters from local network to target network
        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.critic_target.parameters(),
                                       self.critic.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.critic_second_target.parameters(),
                                       self.critic_second.parameters()):
            target_param.data.copy_(param.data)

    def step(self, state: np.ndarray, action, reward, next_state, done, i_agent):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step[i_agent] = (self.t_step[i_agent] + 1) % UPDATE_EVERY
        # Learn once enough samples are in memory and every agent has hit an UPDATE_EVERY boundary
        if len(self.memory) > BATCH_SIZE and (not any(self.t_step)):
            for _ in range(LEARN_TIMES):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def noise_reset(self):
        self.noise.reset()

    def save_model(self, checkpoint_path: str = "./checkpoints/"):
        torch.save(self.actor.state_dict(), f"{checkpoint_path}/actor.pt")
        torch.save(self.critic.state_dict(), f"{checkpoint_path}/critic.pt")

    def load_model(self, checkpoint_path: str = "./checkpoints/"):
        self.actor.load_state_dict(torch.load(f"{checkpoint_path}/actor.pt"))
        self.critic.load_state_dict(torch.load(f"{checkpoint_path}/critic.pt"))

    def act(self, states: np.ndarray, step: int):
        """Returns actions for the given states as per the current policy.

        Params
        ======
            states (array_like): current states, one per agent
            step (int): current time step, used for noise decay
        """
        self.global_step += 1
        if self.global_step < WARM_UP_STEPS:
            action_values = np.random.rand(self.n_agent, self.action_size)
            return action_values
        states = torch.from_numpy(states).float().to(device)
        self.actor.eval()
        with torch.no_grad():
            action_values = self.actor(states).cpu().data.numpy()
        self.actor.train()
        action_values = [
            self.noise.get_action(action, t=step) for action in action_values
        ]
        return action_values

    def learn(self, experiences: tuple, gamma=GAMMA):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): batch of (s, a, r, s', done) tensors
            gamma (float): discount factor
        """
        self.update_step += 1
        states, actions, rewards, next_states, dones = experiences
        # Critic loss
        mask = torch.tensor(1 - dones).detach().to(device)
        Q_values = self.critic(states, actions)
        Q_values_second = self.critic_second(states, actions)
        next_actions = self.actor_target(next_states)
        next_Q = torch.min(
            self.critic_target(next_states, next_actions.detach()),
            self.critic_second_target(next_states, next_actions.detach()))
        Q_prime = rewards + gamma * next_Q * mask

        if self.update_step % CRITIC_UPDATE_EVERY == 0:
            critic_loss = F.mse_loss(Q_values, Q_prime.detach())
            critic_second_loss = F.mse_loss(Q_values_second, Q_prime.detach())
            # Update first critic network
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if CRITIC_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                               CRITIC_GRADIENT_CLIPPING_VALUE)
            self.critic_optimizer.step()

            # Update second critic network
            self.critic_second_optimizer.zero_grad()
            critic_second_loss.backward()
            if CRITIC_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.critic_second.parameters(),
                                               CRITIC_GRADIENT_CLIPPING_VALUE)
            self.critic_second_optimizer.step()

        # Actor loss

        if self.update_step % POLICY_UPDATE_EVERY == 0:
            policy_loss = -self.critic(states, self.actor(states)).mean()

            # Update actor network
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if ACTOR_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(),
                                               ACTOR_GRADIENT_CLIPPING_VALUE)
            self.actor_optimizer.step()

        self.actor_soft_update()
        self.critic_soft_update()
        self.critic_second_soft_update()

    def actor_soft_update(self, tau: float = TAU):
        """Soft update for actor target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)

    def critic_soft_update(self, tau: float = TAU):
        """Soft update for critic target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.critic_target.parameters(),
                                       self.critic.parameters()):
            target_param.detach_()
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)

    def critic_second_soft_update(self, tau: float = TAU):
        """Soft update for critic target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.critic_second_target.parameters(),
                                       self.critic_second.parameters()):
            target_param.detach_()
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)
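
# --- Usage sketch (not part of the original example) ---
# A minimal illustration of how the Agent above might be driven, assuming the same
# module-level names the example already relies on (Actor, Critic, OUNoise,
# ReplayBuffer, device and the hyperparameter constants). The stand-in environment,
# sizes and episode counts below are made up for illustration only.
import os
import numpy as np

n_agent = 2


class _RandomEnv:
    """Stand-in environment so the sketch has something to interact with."""

    def reset(self):
        return np.random.randn(n_agent, 24).astype(np.float32)

    def step(self, actions):
        next_states = np.random.randn(n_agent, 24).astype(np.float32)
        rewards = np.zeros(n_agent, dtype=np.float32)
        dones = np.zeros(n_agent, dtype=bool)
        return next_states, rewards, dones


env = _RandomEnv()
agent = Agent(state_size=24, action_size=2, seed=0, n_agent=n_agent)

for episode in range(10):
    states = env.reset()
    agent.noise_reset()
    for t in range(100):
        actions = np.clip(agent.act(states, step=t), -1.0, 1.0)
        next_states, rewards, dones = env.step(actions)
        for i in range(n_agent):
            agent.step(states[i], actions[i], rewards[i],
                       next_states[i], dones[i], i_agent=i)
        states = next_states
        if np.any(dones):
            break

os.makedirs("./checkpoints", exist_ok=True)
agent.save_model()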
Example #2
class BiCNet():
    def __init__(self, s_dim, a_dim, n_agents, **kwargs):
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.config = kwargs['config']
        self.n_agents = n_agents
        self.device = 'cuda' if self.config.use_cuda else 'cpu'
        # Networks
        self.policy = Actor(s_dim, a_dim, n_agents)
        self.policy_target = Actor(s_dim, a_dim, n_agents)
        self.critic = Critic(s_dim, a_dim, n_agents)
        self.critic_target = Critic(s_dim, a_dim, n_agents)

        if self.config.use_cuda:
            self.policy.cuda()
            self.policy_target.cuda()
            self.critic.cuda()
            self.critic_target.cuda()

        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(),
                                                 lr=self.config.a_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.config.c_lr)

        hard_update(self.policy, self.policy_target)
        hard_update(self.critic, self.critic_target)

        self.random_process = OrnsteinUhlenbeckProcess(
            size=self.a_dim,
            theta=self.config.ou_theta,
            mu=self.config.ou_mu,
            sigma=self.config.ou_sigma)
        self.replay_buffer = list()
        self.epsilon = 1.
        self.depsilon = self.epsilon / self.config.epsilon_decay

        self.c_loss = None
        self.a_loss = None
        self.action_log = list()

    def choose_action(self, obs, noisy=True):
        obs = torch.Tensor([obs]).to(self.device)

        action = self.policy(obs).cpu().detach().numpy()[0]
        self.action_log.append(action)

        if noisy:
            for agent_idx in range(self.n_agents):
                pass
                # action[agent_idx] += self.epsilon * self.random_process.sample()
            self.epsilon -= self.depsilon
            self.epsilon = max(self.epsilon, 0.001)
        action = np.clip(action, -1., 1.)

        return action

    def reset(self):
        self.random_process.reset_states()
        self.action_log.clear()

    def prep_train(self):
        self.policy.train()
        self.critic.train()
        self.policy_target.train()
        self.critic_target.train()

    def prep_eval(self):
        self.policy.eval()
        self.critic.eval()
        self.policy_target.eval()
        self.critic_target.eval()

    def random_action(self):
        return np.random.uniform(low=-1, high=1, size=(self.n_agents, 2))

    def memory(self, s, a, r, s_, done):
        self.replay_buffer.append((s, a, r, s_, done))

        if len(self.replay_buffer) >= self.config.memory_length:
            self.replay_buffer.pop(0)

    def get_batches(self):
        experiences = random.sample(self.replay_buffer, self.config.batch_size)

        state_batches = np.array([_[0] for _ in experiences])
        action_batches = np.array([_[1] for _ in experiences])
        reward_batches = np.array([_[2] for _ in experiences])
        next_state_batches = np.array([_[3] for _ in experiences])
        done_batches = np.array([_[4] for _ in experiences])

        return state_batches, action_batches, reward_batches, next_state_batches, done_batches

    def train(self):

        state_batches, action_batches, reward_batches, next_state_batches, done_batches = self.get_batches(
        )

        state_batches = torch.Tensor(state_batches).to(self.device)
        action_batches = torch.Tensor(action_batches).to(self.device)
        reward_batches = torch.Tensor(reward_batches).reshape(
            self.config.batch_size, self.n_agents, 1).to(self.device)
        next_state_batches = torch.Tensor(next_state_batches).to(self.device)
        done_batches = torch.Tensor(
            (done_batches == False) * 1).reshape(self.config.batch_size,
                                                 self.n_agents,
                                                 1).to(self.device)

        target_next_actions = self.policy_target.forward(next_state_batches)
        target_next_q = self.critic_target.forward(next_state_batches,
                                                   target_next_actions)
        main_q = self.critic(state_batches, action_batches)
        '''
        How to concat each agent's Q value?
        '''
        #target_next_q = target_next_q
        #main_q = main_q.mean(dim=1)
        '''
        Reward Norm
        '''
        # reward_batches = (reward_batches - reward_batches.mean(dim=0)) / reward_batches.std(dim=0) / 1024

        # Critic Loss
        self.critic.zero_grad()
        baselines = reward_batches + done_batches * self.config.gamma * target_next_q
        loss_critic = torch.nn.MSELoss()(main_q, baselines.detach())
        loss_critic.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()

        # Actor Loss
        self.policy.zero_grad()
        clear_action_batches = self.policy.forward(state_batches)
        loss_actor = -self.critic.forward(state_batches,
                                          clear_action_batches).mean()
        loss_actor += (clear_action_batches**2).mean() * 1e-3
        loss_actor.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
        self.policy_optimizer.step()

        # This is for logging
        self.c_loss = loss_critic.item()
        self.a_loss = loss_actor.item()

        soft_update(self.policy, self.policy_target, self.config.tau)
        soft_update(self.critic, self.critic_target, self.config.tau)

    def get_loss(self):
        return self.c_loss, self.a_loss

    def get_action_std(self):
        return np.array(self.action_log).std(axis=-1).mean()
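
# --- Helper sketch (not part of the original example) ---
# BiCNet above, and several of the agents below, call hard_update()/soft_update()
# without defining them. A minimal sketch of what such helpers usually look like.
# The (source, target) argument order matches this example's calls; note that
# Example #4 passes (target, source), so the real projects define them differently.
import torch


def hard_update(source: torch.nn.Module, target: torch.nn.Module) -> None:
    """Copy every parameter of `source` into `target`."""
    for s_param, t_param in zip(source.parameters(), target.parameters()):
        t_param.data.copy_(s_param.data)


def soft_update(source: torch.nn.Module, target: torch.nn.Module,
                tau: float) -> None:
    """Polyak averaging: theta_target <- tau*theta_source + (1 - tau)*theta_target."""
    for s_param, t_param in zip(source.parameters(), target.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)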
Example #3
class DDPGAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise)
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic,
                 hidden_dim_actor=120, hidden_dim_critic=64,
                 lr_actor=0.01, lr_critic=0.01, batch_size=64,
                 max_episode_len=100, tau=0.02, gamma=0.99,
                 agent_name='one', discrete_action=False):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        self.policy = Actor(num_in_pol, num_out_pol,
                                 hidden_dim=hidden_dim_actor,
                                 discrete_action=discrete_action)
        self.critic = Critic(num_in_pol, 1, num_out_pol,
                             hidden_dim=hidden_dim_critic)
        self.target_policy = Actor(num_in_pol, num_out_pol,
                                   hidden_dim=hidden_dim_actor,
                                   discrete_action=discrete_action)
        self.target_critic = Critic(num_in_pol, 1, num_out_pol,
                                    hidden_dim=hidden_dim_critic)
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr_actor)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr_critic,
                                     weight_decay=0)
        
        self.policy = self.policy.float()
        self.critic = self.critic.float()
        self.target_policy = self.target_policy.float()
        self.target_critic = self.target_critic.float()

        self.agent_name = agent_name
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        #self.replay_buffer = ReplayBuffer(1e7)
        self.replay_buffer = ReplayBufferOption(500000,self.batch_size,12)
        self.max_replay_buffer_len = batch_size * max_episode_len
        self.replay_sample_index = None
        self.niter = 0
        self.eps = 5.0
        self.eps_decay = 1/(250*5)

        self.exploration = OUNoise(num_out_pol)
        self.discrete_action = discrete_action

        self.num_history = 2
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []

    def reset_noise(self):
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def act(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs : Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        #obs = obs.reshape(1,48)
        state = Variable(torch.Tensor(obs),requires_grad=False)

        self.policy.eval()
        with torch.no_grad():
            action = self.policy(state)
        self.policy.train()
        # continuous action
        if explore:
            action += Variable(Tensor(self.eps * self.exploration.sample()),requires_grad=False)
            action = torch.clamp(action, min=-1, max=1)
        return action

    def step(self, agent_id, state, action, reward, next_state, done,t_step):
        
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)

        #self.replay_buffer.add(state, action, reward, next_state, done)
        if t_step % self.num_history == 0:
            # Save experience / reward
            
            self.replay_buffer.add(self.states, self.actions, self.rewards, self.next_states, self.dones)
            self.states = []
            self.actions = []
            self.rewards = []
            self.next_states = []
            self.dones = []

        # Learn, if enough samples are available in memory
        if len(self.replay_buffer) > self.batch_size:
            
            obs, acs, rews, next_obs, don = self.replay_buffer.sample()     
            self.update(agent_id ,obs,  acs, rews, next_obs, don,t_step)
        


    def update(self, agent_id, obs, acs, rews, next_obs, dones, t_step, logger=None):
    
        obs = torch.from_numpy(obs).float()
        acs = torch.from_numpy(acs).float()
        rews = torch.from_numpy(rews[:,agent_id]).float()
        next_obs = torch.from_numpy(next_obs).float()
        dones = torch.from_numpy(dones[:,agent_id]).float()

        acs = acs.view(-1,2)
                
        # --------- update critic ------------ #        
        self.critic_optimizer.zero_grad()
        
        all_trgt_acs = self.target_policy(next_obs) 
    
        target_value = (rews + self.gamma *
                        self.target_critic(next_obs,all_trgt_acs) *
                        (1 - dones)) 
        
        actual_value = self.critic(obs,acs)
        vf_loss = MSELoss(actual_value, target_value.detach())

        # Minimize the loss
        vf_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
        self.critic_optimizer.step()

        # --------- update actor --------------- #
        self.policy_optimizer.zero_grad()

        if self.discrete_action:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = curr_pol_out


        pol_loss = -self.critic(obs,curr_pol_vf_in).mean()
        #pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
        self.policy_optimizer.step()

        self.update_all_targets()
        self.eps -= self.eps_decay
        self.eps = max(self.eps, 0)
        

        if logger is not None:
            logger.add_scalars('agent%s/losses' % self.agent_name,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        
        soft_update(self.critic, self.target_critic, self.tau)
        soft_update(self.policy, self.target_policy, self.tau)
   
    def get_params(self):
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
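
# --- Usage sketch (not part of the original example) ---
# How get_params()/load_params() above could be used for checkpointing, assuming
# the module-level classes the example already imports (Actor, Critic, OUNoise,
# ReplayBufferOption). The sizes and file name are made up for illustration only.
import torch

agent = DDPGAgent(num_in_pol=24, num_out_pol=2, num_in_critic=26)

# Persist every network and optimizer state in a single file.
torch.save(agent.get_params(), "ddpg_agent_one.pt")

# Later, restore the state into a freshly constructed agent.
restored = DDPGAgent(num_in_pol=24, num_out_pol=2, num_in_critic=26)
restored.load_params(torch.load("ddpg_agent_one.pt"))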
Example #4
class DDPG:
    def __init__(self, beta, epsilon, learning_rate, gamma, tau, hidden_size_dim0, hidden_size_dim1, num_inputs, action_space, train_mode, alpha, replay_size,
                 optimizer, two_player, normalize_obs=True, normalize_returns=False, critic_l2_reg=1e-2):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.backends.cudnn.enabled = False
            self.Tensor = torch.cuda.FloatTensor
        else:
            self.device = torch.device('cpu')
            self.Tensor = torch.FloatTensor

        self.alpha = alpha
        self.train_mode = train_mode

        self.num_inputs = num_inputs
        self.action_space = action_space
        self.critic_l2_reg = critic_l2_reg

        self.actor = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
        self.adversary = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
        if self.train_mode:
            self.actor_target = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.actor_bar = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.actor_outer = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            if(optimizer == 'SGLD'):
                self.actor_optim = SGLD(self.actor.parameters(), lr=1e-4, noise=epsilon, alpha=0.999)
            elif(optimizer == 'RMSprop'):
                self.actor_optim = RMSprop(self.actor.parameters(), lr=1e-4, alpha=0.999)
            else:
                self.actor_optim = ExtraAdam(self.actor.parameters(), lr=1e-4)

            self.critic = Critic(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.critic_target = Critic(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.critic_optim = Adam(self.critic.parameters(), lr=1e-3, weight_decay=critic_l2_reg)

            self.adversary_target = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.adversary_bar = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            self.adversary_outer = Actor(hidden_size_dim0, hidden_size_dim1, self.num_inputs, self.action_space).to(self.device)
            if(optimizer == 'SGLD'):
                self.adversary_optim = SGLD(self.adversary.parameters(), lr=1e-4, noise=epsilon, alpha=0.999)
            elif(optimizer == 'RMSprop'):
                self.adversary_optim = RMSprop(self.adversary.parameters(), lr=1e-4, alpha=0.999)
            else:
                self.adversary_optim = ExtraAdam(self.adversary.parameters(), lr=1e-4)
			
            hard_update(self.adversary_target, self.adversary)  # Make sure target is with the same weight
            hard_update(self.actor_target, self.actor)  # Make sure target is with the same weight
            hard_update(self.critic_target, self.critic)

        self.gamma = gamma
        self.tau = tau
        self.beta = beta
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.normalize_observations = normalize_obs
        self.normalize_returns = normalize_returns
        self.optimizer = optimizer
        self.two_player = two_player
        if self.normalize_observations:
            self.obs_rms = RunningMeanStd(shape=num_inputs)
        else:
            self.obs_rms = None

        if self.normalize_returns:
            self.ret_rms = RunningMeanStd(shape=1)
            self.ret = 0
            self.cliprew = 10.0
        else:
            self.ret_rms = None

        self.memory = ReplayMemory(replay_size)
       
    def eval(self):
        self.actor.eval()
        self.adversary.eval()
        if self.train_mode:
            self.critic.eval()

    def train(self):
        self.actor.train()
        self.adversary.train()
        if self.train_mode:
            self.critic.train()

    def select_action(self, state, action_noise=None, param_noise=None, mdp_type='mdp'):
        state = normalize(Variable(state).to(self.device), self.obs_rms, self.device)

        if mdp_type != 'mdp':
            
            if(self.optimizer == 'SGLD' and self.two_player):
                mu = self.actor_outer(state)
            else:
                mu = self.actor(state)
            mu = mu.data
            if action_noise is not None:
                mu += self.Tensor(action_noise()).to(self.device)

            mu = mu.clamp(-1, 1) * (1 - self.alpha)

            if(self.optimizer == 'SGLD' and self.two_player):
                adv_mu = self.adversary_outer(state)
            else:
                adv_mu = self.adversary(state)
            adv_mu = adv_mu.data.clamp(-1, 1) * self.alpha
            mu += adv_mu
            
        else:
 
            if(self.optimizer == 'SGLD' and self.two_player):
                mu = self.actor_outer(state)
            else:
                mu = self.actor(state)

            mu = mu.data
            if action_noise is not None:
                mu += self.Tensor(action_noise()).to(self.device)

            mu = mu.clamp(-1, 1)

        return mu

    def update_robust_non_flip(self, state_batch, action_batch, reward_batch, mask_batch, next_state_batch,
                      mdp_type, robust_update_type):
        # TRAIN CRITIC
        if robust_update_type == 'full':            
            next_action_batch = (1 - self.alpha) * self.actor_target(next_state_batch) \
                                    + self.alpha * self.adversary_target(next_state_batch)

            next_state_action_values = self.critic_target(next_state_batch, next_action_batch)
            expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

            self.critic_optim.zero_grad()
            state_action_batch = self.critic(state_batch, action_batch)

            value_loss = F.mse_loss(state_action_batch, expected_state_action_batch)
            value_loss.backward()
            self.critic_optim.step()
            value_loss = value_loss.item()
        else:
            value_loss = 0
        
        # TRAIN ADVERSARY
        self.adversary_optim.zero_grad() 
        with torch.no_grad():
            if(self.optimizer == 'SGLD' and self.two_player):
                real_action = self.actor_outer(next_state_batch)
            else: 
                real_action = self.actor_target(next_state_batch)
        action = (1 - self.alpha) * real_action + self.alpha * self.adversary(next_state_batch)
        adversary_loss = self.critic(state_batch, action)
        adversary_loss = adversary_loss.mean()
        adversary_loss.backward()
        self.adversary_optim.step()
        adversary_loss = adversary_loss.item()
            
        # TRAIN ACTOR
        self.actor_optim.zero_grad()
        with torch.no_grad():
            if(self.optimizer == 'SGLD' and self.two_player):
                adversary_action = self.adversary_outer(next_state_batch)
            else:
                adversary_action = self.adversary_target(next_state_batch)
        action = (1 - self.alpha) * self.actor(next_state_batch) + self.alpha * adversary_action
        policy_loss = -self.critic(state_batch, action)
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()
        policy_loss = policy_loss.item()
           
        return value_loss, policy_loss, adversary_loss
  
    def update_robust_flip(self, state_batch, action_batch, reward_batch, mask_batch, next_state_batch, adversary_update,
                      mdp_type, robust_update_type):
        # TRAIN CRITIC
        if robust_update_type == 'full':
           
            next_action_batch = (1 - self.alpha) * self.actor_target(next_state_batch) \
                                    + self.alpha * self.adversary_target(next_state_batch)

            next_state_action_values = self.critic_target(next_state_batch, next_action_batch)
            expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

            self.critic_optim.zero_grad()

            state_action_batch = self.critic(state_batch, action_batch)

            value_loss = F.mse_loss(state_action_batch, expected_state_action_batch)
            value_loss.backward()
            self.critic_optim.step()
            value_loss = value_loss.item()
        else:
            value_loss = 0

        if adversary_update:
            # TRAIN ADVERSARY
            self.adversary_optim.zero_grad()
           
            with torch.no_grad():
                real_action = self.actor_target(next_state_batch)
            action = (1 - self.alpha) * real_action + self.alpha * self.adversary(next_state_batch)
            adversary_loss = self.critic(state_batch, action)

            adversary_loss = adversary_loss.mean()
            adversary_loss.backward()
            self.adversary_optim.step()
            adversary_loss = adversary_loss.item()
            policy_loss = 0
        else:
            # TRAIN ACTOR
            self.actor_optim.zero_grad()
            with torch.no_grad():
                adversary_action = self.adversary_target(next_state_batch)
            action = (1 - self.alpha) * self.actor(next_state_batch) + self.alpha * adversary_action
            policy_loss = -self.critic(state_batch, action)

            policy_loss = policy_loss.mean()
            policy_loss.backward()
            self.actor_optim.step()
            policy_loss = policy_loss.item()
            adversary_loss = 0

        return value_loss, policy_loss, adversary_loss
  
    def update_non_robust(self, state_batch, action_batch, reward_batch, mask_batch, next_state_batch):
        
        # TRAIN CRITIC

        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch, next_action_batch)

        expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

        self.critic_optim.zero_grad()
        state_action_batch = self.critic(state_batch, action_batch)
        value_loss = F.mse_loss(state_action_batch, expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()

        # TRAIN ACTOR
        self.actor_optim.zero_grad()
        action = self.actor(next_state_batch)
        policy_loss = -self.critic(state_batch, action)
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()
        policy_loss = policy_loss.item()
        adversary_loss = 0

        return value_loss.item(), policy_loss, adversary_loss

    def store_transition(self, state, action, mask, next_state, reward):
        B = state.shape[0]
        for b in range(B):
            self.memory.push(state[b], action[b], mask[b], next_state[b], reward[b])
            if self.normalize_observations:
                self.obs_rms.update(state[b].cpu().numpy())
            if self.normalize_returns:
                self.ret = self.ret * self.gamma + reward[b]
                self.ret_rms.update(np.array([self.ret]))
                if mask[b] == 0:  # if terminal is True
                    self.ret = 0

    def update_parameters(self, batch_size, sgld_outer_update, mdp_type='mdp', exploration_method='mdp'):
        transitions = self.memory.sample(batch_size)
        batch = Transition(*zip(*transitions))

        if mdp_type != 'mdp':
            robust_update_type = 'full'
        elif exploration_method != 'mdp':
            robust_update_type = 'adversary'
        else:
            robust_update_type = None

        state_batch = normalize(Variable(torch.stack(batch.state)).to(self.device), self.obs_rms, self.device)
        action_batch = Variable(torch.stack(batch.action)).to(self.device)
        reward_batch = normalize(Variable(torch.stack(batch.reward)).to(self.device).unsqueeze(1), self.ret_rms, self.device)
        mask_batch = Variable(torch.stack(batch.mask)).to(self.device).unsqueeze(1)
        next_state_batch = normalize(Variable(torch.stack(batch.next_state)).to(self.device), self.obs_rms, self.device)

        if self.normalize_returns:
            reward_batch = torch.clamp(reward_batch, -self.cliprew, self.cliprew)

        value_loss = 0
        policy_loss = 0
        adversary_loss = 0
        
        if robust_update_type is not None:
           
            _value_loss, _policy_loss, _adversary_loss = self.update_robust_non_flip(
                state_batch, action_batch, reward_batch, mask_batch,
                next_state_batch, mdp_type, robust_update_type)
            value_loss += _value_loss
            policy_loss += _policy_loss
            adversary_loss += _adversary_loss
           
        if robust_update_type != 'full':
            _value_loss, _policy_loss, _adversary_loss = self.update_non_robust(state_batch, action_batch,
                                                                                reward_batch,
                                                                                mask_batch, next_state_batch)
            value_loss += _value_loss
            policy_loss += _policy_loss
            adversary_loss += _adversary_loss
        
        if(self.optimizer == 'SGLD' and self.two_player):   
            self.sgld_inner_update()
        self.soft_update()
        if(sgld_outer_update and self.optimizer == 'SGLD' and self.two_player):
            self.sgld_outer_update()

        return value_loss, policy_loss, adversary_loss

    def initialize(self):
        hard_update(self.actor_bar, self.actor_outer)
        hard_update(self.adversary_bar, self.adversary_outer)
        hard_update(self.actor, self.actor_outer)
        hard_update(self.adversary, self.adversary_outer)

    def sgld_inner_update(self): #target source
        sgld_update(self.actor_bar, self.actor, self.beta)
        sgld_update(self.adversary_bar, self.adversary, self.beta)

    def sgld_outer_update(self): #target source
        sgld_update(self.actor_outer, self.actor_bar, self.beta)
        sgld_update(self.adversary_outer, self.adversary_bar, self.beta)

    def soft_update(self):
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.adversary_target, self.adversary, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)
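
# --- Helper sketch (not part of the original example) ---
# This example and Example #9 rely on RunningMeanStd and a normalize() helper that
# are not shown. A minimal sketch of the usual OpenAI-baselines-style running
# statistics, written as an assumption about what the project provides.
import numpy as np
import torch


class RunningMeanStd:
    """Track a running mean and variance of observed values."""

    def __init__(self, shape=(), epsilon=1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        x = np.atleast_2d(x)
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total


def normalize(x, rms, device):
    """Standardize a tensor with the running statistics; pass through if rms is None."""
    if rms is None:
        return x
    mean = torch.as_tensor(rms.mean, dtype=x.dtype, device=device)
    std = torch.as_tensor(np.sqrt(rms.var), dtype=x.dtype, device=device)
    return (x - mean) / (std + 1e-8)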
Example #5
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, random_seed):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            random_seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(random_seed)

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

        self.noise = OUNoise((action_size), random_seed)

        # Initialize the targets with the same weights as the local networks (noted on Slack to make a big difference)
        self.hard_update(self.actor_target, self.actor_local)
        self.hard_update(self.critic_target, self.critic_local)

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        self.memory.add(states, actions, rewards, next_states, dones)

        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)

    def act(self, states, add_noise=True):
        """Returns actions for given state as per current policy."""
        states = torch.from_numpy(states).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            actions = self.actor_local(states).cpu().data.numpy()
        self.actor_local.train()

        if add_noise:
            actions += self.noise.sample()

        return np.clip(actions, -1, 1)

    def reset(self):
        """Noise reset."""
        self.noise.reset()
        self.i_step = 0

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def hard_update(self, target, source):
        for target_param, source_param in zip(target.parameters(),
                                              source.parameters()):
            target_param.data.copy_(source_param.data)
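
# --- Helper sketch (not part of the original example) ---
# OUNoise is used throughout these listings but never shown. A minimal sketch of a
# typical Ornstein-Uhlenbeck noise process; the real classes differ per example
# (Example #1, for instance, passes a decay_period and calls get_action()), so the
# constructor arguments and defaults here are assumptions.
import numpy as np


class OUNoise:
    """Ornstein-Uhlenbeck process producing temporally correlated exploration noise."""

    def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        """Reset the internal state to the long-run mean."""
        self.state = self.mu.copy()

    def sample(self):
        """Advance the process one step and return the new state as noise."""
        x = self.state
        dx = self.theta * (self.mu - x) \
            + self.sigma * self.rng.standard_normal(self.mu.shape)
        self.state = x + dx
        return self.state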
Example #6
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self,
                 state_size,
                 action_size,
                 aid=0,
                 num_agents=2,
                 seed=1234):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            aid (int): index of this agent
            num_agents (int): number of agents (used by the critic)
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size, seed=seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  seed=seed).to(device)
        self.actor_optimizer = Adam(self.actor_local.parameters(), lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size,
                                   action_size,
                                   num_agents=num_agents,
                                   seed=seed).to(device)
        self.critic_target = Critic(state_size,
                                    action_size,
                                    num_agents=num_agents,
                                    seed=seed).to(device)
        self.critic_optimizer = Adam(self.critic_local.parameters(),
                                     lr=LR_CRITIC,
                                     weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise(action_size, seed=seed)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise: action += self.noise.sample()
        return np.clip(action, -1, 1)

    def reset(self):
        self.noise.reset()

    def update_targets(self, tau):
        soft_update(self.critic_local, self.critic_target, tau)
        soft_update(self.actor_local, self.actor_target, tau)

    def update_critic(self, states, actions, next_states, next_actions,
                      rewards, dones):
        Q_targets_next = self.critic_target(next_states, next_actions)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (GAMMA * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

    def update_actor(self, states, actions):
        actor_loss = -self.critic_local(states, actions).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
Example #7
class DDPGAgent:
    def __init__(self, state_size=24, action_size=2, seed=1, num_agents=2):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.num_agents = num_agents

        # DDPG specific configuration
        hidden_size = 512
        self.CHECKPOINT_FOLDER = './'

        # Defining networks
        self.actor = Actor(state_size, hidden_size, action_size).to(device)
        self.actor_target = Actor(state_size, hidden_size, action_size).to(device)

        self.critic = Critic(state_size, self.action_size, hidden_size, 1).to(device)
        self.critic_target = Critic(state_size, self.action_size, hidden_size, 1).to(device)

        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

        # Noise
        self.noises = OUNoise((num_agents, action_size), seed)

        # Initialize replay buffer
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)

    def act(self, state, add_noise=True):
        '''
        Returns action to be taken based on state provided as the input
        '''
        state = torch.from_numpy(state).float().to(device)
        actions = np.zeros((self.num_agents, self.action_size))

        self.actor.eval()

        with torch.no_grad():
            for agent_num, agent_state in enumerate(state):
                action = self.actor(agent_state).cpu().data.numpy()
                actions[agent_num, :] = action

        self.actor.train()

        if add_noise:
            actions += self.noises.sample()

        return np.clip(actions, -1, 1)

    def reset(self):
        self.noises.reset()

    def learn(self, experiences):
        '''
        Trains the actor critic network using experiences
        '''
        states, actions, rewards, next_states, dones = experiences

        # Update Critic
        actions_next = self.actor_target(next_states)
        # print(next_states.shape, actions_next.shape)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + GAMMA*Q_targets_next*(1-dones)

        Q_expected = self.critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()

        # Update Actor
        actions_pred = self.actor(states)
        actor_loss = -self.critic(states, actions_pred).mean()

        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()

        # Update the target networks
        self.soft_update(self.critic, self.critic_target)
        self.soft_update(self.actor, self.actor_target)


    def soft_update(self, model, model_target):
        tau = TAU
        for target_param, local_param in zip(model_target.parameters(), model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    def step(self, state, action, reward, next_state, done):
        '''
        Adds experience to memory and learns if the memory contains sufficient samples
        '''
        for i in range(self.num_agents):
            self.memory.add(state[i, :], action[i, :], reward[i], next_state[i, :], done[i])

        if len(self.memory) > BATCH_SIZE:
            # print("Now Learning")
            experiences = self.memory.sample()
            self.learn(experiences)

    def checkpoint(self):
        '''
        Saves the actor critic network on disk
        '''
        torch.save(self.actor.state_dict(), self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth')
        torch.save(self.critic.state_dict(), self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth')

    def load(self):
        '''
        Loads the actor critic network from disk
        '''
        self.actor.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth'))
        self.actor_target.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth'))
        self.critic.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth'))
        self.critic_target.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth'))
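
# --- Helper sketch (not part of the original example) ---
# ReplayBuffer(action_size, buffer_size, batch_size, seed) with add(), sample() and
# __len__() is assumed by several of these examples but never defined. A minimal
# sketch of the usual fixed-size buffer; the field names and the batched-tensor
# layout returned by sample() are assumptions. sample() optionally accepts a batch
# size, which Example #8 below expects.
import random
from collections import deque, namedtuple

import numpy as np
import torch

Experience = namedtuple("Experience",
                        ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Fixed-size buffer that stores experience tuples and samples random minibatches."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size=None):
        batch = random.sample(self.memory, k=batch_size or self.batch_size)
        states = torch.from_numpy(np.vstack([e.state for e in batch])).float()
        actions = torch.from_numpy(np.vstack([e.action for e in batch])).float()
        rewards = torch.from_numpy(np.vstack([e.reward for e in batch])).float()
        next_states = torch.from_numpy(
            np.vstack([e.next_state for e in batch])).float()
        dones = torch.from_numpy(
            np.vstack([e.done for e in batch]).astype(np.uint8)).float()
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)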
Example #8
class Agent():
    def __init__(self,
                 state_size,
                 action_size,
                 replay_memory,
                 random_seed=0,
                 nb_agent=20,
                 bs=128,
                 gamma=0.99,
                 tau=1e-3,
                 lr_actor=1e-4,
                 lr_critic=1e-4,
                 wd_actor=0,
                 wd_critic=0,
                 clip_actor=None,
                 clip_critic=None,
                 update_interval=20,
                 update_times=10):

        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(random_seed)
        self.nb_agent = nb_agent
        self.bs = bs
        self.update_interval = update_interval
        self.update_times = update_times
        self.timestep = 0

        self.gamma = gamma
        self.tau = tau
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.wd_critic = wd_critic
        self.wd_actor = wd_actor
        self.clip_critic = clip_critic
        self.clip_actor = clip_actor
        self.actor_losses = []
        self.critic_losses = []

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=self.lr_actor,
                                          weight_decay=self.wd_actor)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=self.lr_critic,
                                           weight_decay=self.wd_critic)

        # Noise process
        self.noise = OUNoise((self.nb_agent, action_size), random_seed)

        # Replay memory
        self.memory = replay_memory

    def step(self, states, actions, rewards, next_states, dones):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        #increment timestep
        self.timestep += 1

        # Save experience / reward
        for state, action, reward, next_state, done in zip(
                states, actions, rewards, next_states, dones):
            self.memory.add(state, action, reward, next_state, done)

        # Learn, if enough samples are available in memory
        if self.timestep % self.update_interval == 0:
            for i in range(self.update_times):
                if len(self.memory) > self.bs:
                    experiences = self.memory.sample(self.bs)
                    self.learn(experiences)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -1, 1)

    def reset_noise(self):
        self.noise.reset()

    def learn(self, experiences):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self.clip_critic:
            torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(),
                                           self.clip_critic)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        if self.clip_actor:
            torch.nn.utils.clip_grad_norm_(self.actor_local.parameters(),
                                           self.clip_actor)
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target)
        self.soft_update(self.actor_local, self.actor_target)

        self.actor_losses.append(actor_loss.cpu().data.numpy())
        self.critic_losses.append(critic_loss.cpu().data.numpy())

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)
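
# --- Usage sketch (not part of the original example) ---
# This agent takes its replay memory as a constructor argument, so one buffer can be
# shared between several agents or runs. A minimal sketch of that wiring, assuming
# any buffer object with add(), sample(batch_size) and __len__() (for instance the
# ReplayBuffer sketched above) and the module-level names the example relies on;
# the sizes are made up.
shared_memory = ReplayBuffer(action_size=4, buffer_size=int(1e6),
                             batch_size=128, seed=0)

agent = Agent(state_size=33, action_size=4, replay_memory=shared_memory,
              nb_agent=20, bs=128, update_interval=20, update_times=10)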
Example #9
class DDPG:
    def __init__(self,
                 gamma,
                 tau,
                 hidden_size,
                 num_inputs,
                 action_space,
                 train_mode,
                 alpha,
                 replay_size,
                 normalize_obs=True,
                 normalize_returns=False,
                 critic_l2_reg=1e-2):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.backends.cudnn.enabled = False
            self.Tensor = torch.cuda.FloatTensor
        else:
            self.device = torch.device('cpu')
            self.Tensor = torch.FloatTensor

        self.alpha = alpha
        self.train_mode = train_mode

        self.num_inputs = num_inputs
        self.action_space = action_space
        self.critic_l2_reg = critic_l2_reg

        self.actor = Actor(hidden_size, self.num_inputs,
                           self.action_space).to(self.device)
        self.adversary = Actor(hidden_size, self.num_inputs,
                               self.action_space).to(self.device)
        if self.train_mode:
            self.actor_target = Actor(hidden_size, self.num_inputs,
                                      self.action_space).to(self.device)
            self.actor_perturbed = Actor(hidden_size, self.num_inputs,
                                         self.action_space).to(self.device)
            self.actor_optim = Adam(self.actor.parameters(), lr=1e-4)

            self.critic = Critic(hidden_size, self.num_inputs,
                                 self.action_space).to(self.device)
            self.critic_target = Critic(hidden_size, self.num_inputs,
                                        self.action_space).to(self.device)
            self.critic_optim = Adam(self.critic.parameters(),
                                     lr=1e-3,
                                     weight_decay=critic_l2_reg)

            self.adversary_target = Actor(hidden_size, self.num_inputs,
                                          self.action_space).to(self.device)
            self.adversary_perturbed = Actor(hidden_size, self.num_inputs,
                                             self.action_space).to(self.device)
            self.adversary_optim = Adam(self.adversary.parameters(), lr=1e-4)

            hard_update(
                self.adversary_target,
                self.adversary)  # Make sure target is with the same weight
            hard_update(self.actor_target,
                        self.actor)  # Make sure target is with the same weight
            hard_update(self.critic_target, self.critic)

        self.gamma = gamma
        self.tau = tau
        self.normalize_observations = normalize_obs
        self.normalize_returns = normalize_returns

        if self.normalize_observations:
            self.obs_rms = RunningMeanStd(shape=num_inputs)
        else:
            self.obs_rms = None

        if self.normalize_returns:
            self.ret_rms = RunningMeanStd(shape=1)
            self.ret = 0
            self.cliprew = 10.0
        else:
            self.ret_rms = None

        self.memory = ReplayMemory(replay_size)

    def eval(self):
        self.actor.eval()
        self.adversary.eval()
        if self.train_mode:
            self.critic.eval()

    def train(self):
        self.actor.train()
        self.adversary.train()
        if self.train_mode:
            self.critic.train()

    def select_action(self,
                      state,
                      action_noise=None,
                      param_noise=None,
                      mdp_type='mdp'):
        state = normalize(
            Variable(state).to(self.device), self.obs_rms, self.device)

        if mdp_type != 'mdp':
            if mdp_type == 'nr_mdp':
                if param_noise is not None:
                    mu = self.actor_perturbed(state)
                else:
                    mu = self.actor(state)
                mu = mu.data
                if action_noise is not None:
                    mu += self.Tensor(action_noise()).to(self.device)

                mu = mu.clamp(-1, 1) * (1 - self.alpha)

                if param_noise is not None:
                    adv_mu = self.adversary_perturbed(state)
                else:
                    adv_mu = self.adversary(state)

                adv_mu = adv_mu.data.clamp(-1, 1) * self.alpha

                mu += adv_mu
            else:  # mdp_type == 'pr_mdp':
                if np.random.rand() < (1 - self.alpha):
                    if param_noise is not None:
                        mu = self.actor_perturbed(state)
                    else:
                        mu = self.actor(state)
                    mu = mu.data
                    if action_noise is not None:
                        mu += self.Tensor(action_noise()).to(self.device)

                    mu = mu.clamp(-1, 1)
                else:
                    if param_noise is not None:
                        mu = self.adversary_perturbed(state)
                    else:
                        mu = self.adversary(state)

                    mu = mu.data.clamp(-1, 1)

        else:
            if param_noise is not None:
                mu = self.actor_perturbed(state)
            else:
                mu = self.actor(state)
            mu = mu.data
            if action_noise is not None:
                mu += self.Tensor(action_noise()).to(self.device)

            mu = mu.clamp(-1, 1)

        return mu

    def update_robust(self, state_batch, action_batch, reward_batch,
                      mask_batch, next_state_batch, adversary_update, mdp_type,
                      robust_update_type):
        # TRAIN CRITIC
        if robust_update_type == 'full':
            if mdp_type == 'nr_mdp':
                next_action_batch = (1 - self.alpha) * self.actor_target(next_state_batch) \
                                    + self.alpha * self.adversary_target(next_state_batch)

                next_state_action_values = self.critic_target(
                    next_state_batch, next_action_batch)
            else:  # mdp_type == 'pr_mdp':
                next_action_actor_batch = self.actor_target(next_state_batch)
                next_action_adversary_batch = self.adversary_target(
                    next_state_batch)

                next_state_action_values = (
                    (1 - self.alpha) * self.critic_target(next_state_batch, next_action_actor_batch)
                    + self.alpha * self.critic_target(next_state_batch, next_action_adversary_batch))

            expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

            self.critic_optim.zero_grad()

            state_action_batch = self.critic(state_batch, action_batch)

            value_loss = F.mse_loss(state_action_batch,
                                    expected_state_action_batch)
            value_loss.backward()
            self.critic_optim.step()
            value_loss = value_loss.item()
        else:
            value_loss = 0

        if adversary_update:
            # TRAIN ADVERSARY
            self.adversary_optim.zero_grad()

            if mdp_type == 'nr_mdp':
                with torch.no_grad():
                    real_action = self.actor_target(next_state_batch)
                action = (
                    1 - self.alpha
                ) * real_action + self.alpha * self.adversary(next_state_batch)
                adversary_loss = self.critic(state_batch, action)
            else:  # mdp_type == 'pr_mdp'
                action = self.adversary(next_state_batch)
                adversary_loss = self.critic(state_batch, action) * self.alpha

            adversary_loss = adversary_loss.mean()
            adversary_loss.backward()
            self.adversary_optim.step()
            adversary_loss = adversary_loss.item()
            policy_loss = 0
        else:
            if robust_update_type == 'full':
                # TRAIN ACTOR
                self.actor_optim.zero_grad()

                if mdp_type == 'nr_mdp':
                    with torch.no_grad():
                        adversary_action = self.adversary_target(
                            next_state_batch)
                    action = (1 - self.alpha) * self.actor(
                        next_state_batch) + self.alpha * adversary_action
                    policy_loss = -self.critic(state_batch, action)
                else:  # mdp_type == 'pr_mdp':
                    action = self.actor(next_state_batch)
                    policy_loss = -self.critic(state_batch, action) * (
                        1 - self.alpha)

                policy_loss = policy_loss.mean()
                policy_loss.backward()
                self.actor_optim.step()

                policy_loss = policy_loss.item()
                adversary_loss = 0
            else:
                policy_loss = 0
                adversary_loss = 0

        return value_loss, policy_loss, adversary_loss

    def update_non_robust(self, state_batch, action_batch, reward_batch,
                          mask_batch, next_state_batch):
        # TRAIN CRITIC

        next_action_batch = self.actor_target(next_state_batch)
        next_state_action_values = self.critic_target(next_state_batch,
                                                      next_action_batch)

        expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

        self.critic_optim.zero_grad()

        state_action_batch = self.critic(state_batch, action_batch)

        value_loss = F.mse_loss(state_action_batch,
                                expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()

        # TRAIN ACTOR

        self.actor_optim.zero_grad()

        action = self.actor(next_state_batch)

        policy_loss = -self.critic(state_batch, action)

        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        policy_loss = policy_loss.item()
        adversary_loss = 0

        return value_loss.item(), policy_loss, adversary_loss

    def store_transition(self, state, action, mask, next_state, reward):
        B = state.shape[0]
        for b in range(B):
            self.memory.push(state[b], action[b], mask[b], next_state[b],
                             reward[b])
            if self.normalize_observations:
                self.obs_rms.update(state[b].cpu().numpy())
            if self.normalize_returns:
                self.ret = self.ret * self.gamma + reward[b]
                self.ret_rms.update(np.array([self.ret]))
                if mask[b] == 0:  # if terminal is True
                    self.ret = 0

    def update_parameters(self,
                          batch_size,
                          mdp_type='mdp',
                          adversary_update=False,
                          exploration_method='mdp'):
        transitions = self.memory.sample(batch_size)
        batch = Transition(*zip(*transitions))

        if mdp_type != 'mdp':
            robust_update_type = 'full'
        elif exploration_method != 'mdp':
            robust_update_type = 'adversary'
        else:
            robust_update_type = None

        state_batch = normalize(
            Variable(torch.stack(batch.state)).to(self.device), self.obs_rms,
            self.device)
        action_batch = Variable(torch.stack(batch.action)).to(self.device)
        reward_batch = normalize(
            Variable(torch.stack(batch.reward)).to(self.device).unsqueeze(1),
            self.ret_rms, self.device)
        mask_batch = Variable(torch.stack(batch.mask)).to(
            self.device).unsqueeze(1)
        next_state_batch = normalize(
            Variable(torch.stack(batch.next_state)).to(self.device),
            self.obs_rms, self.device)

        if self.normalize_returns:
            reward_batch = torch.clamp(reward_batch, -self.cliprew,
                                       self.cliprew)

        value_loss = 0
        policy_loss = 0
        adversary_loss = 0
        if robust_update_type is not None:
            _value_loss, _policy_loss, _adversary_loss = self.update_robust(
                state_batch, action_batch, reward_batch, mask_batch,
                next_state_batch, adversary_update, mdp_type,
                robust_update_type)
            value_loss += _value_loss
            policy_loss += _policy_loss
            adversary_loss += _adversary_loss
        if robust_update_type != 'full':
            _value_loss, _policy_loss, _adversary_loss = self.update_non_robust(
                state_batch, action_batch, reward_batch, mask_batch,
                next_state_batch)
            value_loss += _value_loss
            policy_loss += _policy_loss
            adversary_loss += _adversary_loss

        self.soft_update()

        return value_loss, policy_loss, adversary_loss

    def soft_update(self):
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.adversary_target, self.adversary, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

    def perturb_actor_parameters(self, param_noise):
        """Apply parameter noise to actor model, for exploration"""
        hard_update(self.actor_perturbed, self.actor)
        params = self.actor_perturbed.state_dict()
        for name in params:
            if 'ln' in name:
                continue  # skip layer-norm parameters
            param = params[name]
            param += torch.randn(param.shape).to(
                self.device) * param_noise.current_stddev

        # Apply parameter noise to adversary model, for exploration
        hard_update(self.adversary_perturbed, self.adversary)
        params = self.adversary_perturbed.state_dict()
        for name in params:
            if 'ln' in name:
                continue  # skip layer-norm parameters
            param = params[name]
            param += torch.randn(param.shape).to(
                self.device) * param_noise.current_stddev
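
Example #9 calls hard_update, soft_update, and normalize without defining them. A minimal sketch of what these helpers are assumed to do, matching the call signatures used above: a hard parameter copy, Polyak averaging with tau, and whitening against a RunningMeanStd-style object (pass-through when it is None). The exact implementations behind the example are not shown here.

import numpy as np
import torch


def hard_update(target, source):
    """Copy every parameter of `source` into `target` (equivalent to tau = 1)."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def soft_update(target, source, tau):
    """Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


def normalize(x, stats, device):
    """Whiten `x` with the running statistics in `stats`; pass through when stats is None."""
    if stats is None:
        return x
    mean = torch.tensor(stats.mean, dtype=torch.float32, device=device)
    std = torch.tensor(np.sqrt(stats.var), dtype=torch.float32, device=device)
    return (x - mean) / (std + 1e-8)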
Example #10
0
class DDPG(Policy):
    def __init__(self,
                 gamma,
                 tau,
                 num_inputs,
                 action_space,
                 replay_size,
                 normalize_obs=True,
                 normalize_returns=False,
                 critic_l2_reg=1e-2,
                 num_outputs=1,
                 entropy_coeff=0.1,
                 action_coeff=0.1):

        super(DDPG, self).__init__(gamma=gamma,
                                   tau=tau,
                                   num_inputs=num_inputs,
                                   action_space=action_space,
                                   replay_size=replay_size,
                                   normalize_obs=normalize_obs,
                                   normalize_returns=normalize_returns)

        self.num_outputs = num_outputs
        self.entropy_coeff = entropy_coeff
        self.action_coeff = action_coeff
        self.critic_l2_reg = critic_l2_reg

        self.actor = Actor(self.num_inputs, self.action_space,
                           self.num_outputs).to(self.device)
        self.actor_target = Actor(self.num_inputs, self.action_space,
                                  self.num_outputs).to(self.device)
        self.actor_perturbed = Actor(self.num_inputs, self.action_space,
                                     self.num_outputs).to(self.device)
        self.actor_optim = Adam(self.actor.parameters(), lr=1e-4)

        self.critic = Critic(self.num_inputs + self.action_space.shape[0]).to(
            self.device)
        self.critic_target = Critic(self.num_inputs +
                                    self.action_space.shape[0]).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(),
                                 lr=1e-3,
                                 weight_decay=critic_l2_reg)

        hard_update(self.actor_target,
                    self.actor)  # Make sure target is with the same weight
        hard_update(self.critic_target, self.critic)

    def eval(self):
        self.actor.eval()
        self.critic.eval()

    def train(self):
        self.actor.train()
        self.critic.train()

    def policy(self, actor, state):
        return actor(state), None

    def update_critic(self, state_batch, action_batch, reward_batch,
                      mask_batch, next_state_batch):
        batch_size = state_batch.shape[0]
        with torch.no_grad():
            tiled_next_state_batch = self._tile(next_state_batch, 0,
                                                self.num_outputs)
            next_action_batch, _, next_probs, _ = self.actor_target(
                next_state_batch)

            # Evaluate each candidate next action, weight it by its probability,
            # and sum to get the expected next-state value.
            next_q = self.critic_target(
                tiled_next_state_batch,
                next_action_batch.view(batch_size * self.num_outputs, -1))[0]
            next_q = next_q.view(batch_size, self.num_outputs)
            next_state_action_values = (next_q * next_probs).sum(-1).unsqueeze(-1)

            expected_state_action_batch = reward_batch + self.gamma * mask_batch * next_state_action_values

        self.critic_optim.zero_grad()
        state_action_batch = self.critic(state_batch, action_batch)[0]
        value_loss = F.mse_loss(state_action_batch,
                                expected_state_action_batch)
        value_loss.backward()
        self.critic_optim.step()

        return value_loss.item()

    def update_actor(self, state_batch):
        batch_size = state_batch.shape[0]

        tiled_state_batch = self._tile(state_batch, 0, self.num_outputs)
        action_batch, _, probs, dist_entropy = self.actor(state_batch)

        # Expected Q-value under the actor's action distribution, negated for gradient ascent.
        q_values = self.critic_target(
            tiled_state_batch,
            action_batch.view(batch_size * self.num_outputs, -1))[0]
        policy_loss = -(q_values.view(batch_size, self.num_outputs) * probs).sum(-1)
        entropy_loss = dist_entropy * self.entropy_coeff

        action_mse = 0
        action_batch = action_batch.view(batch_size, self.num_outputs, -1)
        for idx1 in range(self.num_outputs):
            for idx2 in range(idx1 + 1, self.num_outputs):
                action_mse += (
                    (action_batch[:, idx1, :] - action_batch[:, idx2, :])**
                    2).mean() * self.action_coeff / self.num_outputs

        return policy_loss - entropy_loss - action_mse

    def soft_update(self):
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)
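
update_critic and update_actor in Example #10 rely on self._tile, which is presumably inherited from the Policy base class and is not shown. Assuming it repeats each row of a batch num_outputs times so that tiled states row-align with actions reshaped to (batch * num_outputs, -1), a one-line sketch of that assumed behavior:

import torch


def tile(x, dim, n_tile):
    """Interleaved repeat along `dim`: [a, b] -> [a, a, b, b] for n_tile = 2,
    so tile(states, 0, k) row-aligns with actions.view(batch * k, -1)."""
    return torch.repeat_interleave(x, n_tile, dim=dim)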
Example #11
0
class Agent():
    """Interacts with and learns from the environment."""
    
    def __init__(self, state_size, action_size, random_seed):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            random_seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(random_seed)
        self.t_step = 0
        
        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size, random_seed).to(device)
        self.actor_target = Actor(state_size, action_size, random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size, random_seed).to(device)
        self.critic_target = Critic(state_size, action_size, random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise(action_size, random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, random_seed)
    
    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1)
        if self.t_step % UPDATE_EVERY == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                for i in range(10):
                    experiences = self.memory.sample()
                    self.learn(experiences, GAMMA)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            action += 0.4 * self.noise.sample()
        return np.clip(action, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()
        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)                     

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
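
The Agent-style examples reference module-level constants (BUFFER_SIZE, BATCH_SIZE, GAMMA, TAU, LR_ACTOR, LR_CRITIC, WEIGHT_DECAY, UPDATE_EVERY), a device, and a ReplayBuffer, none of which appear in the snippets. A minimal sketch with typical values; the numbers and the buffer implementation are placeholders consistent with how they are called above, not the constants the original projects used.

import random
from collections import deque, namedtuple

import numpy as np
import torch

# Typical hyperparameter values for this kind of agent; placeholders, not the originals.
BUFFER_SIZE = int(1e6)   # replay buffer size
BATCH_SIZE = 128         # minibatch size
GAMMA = 0.99             # discount factor
TAU = 1e-3               # soft-update interpolation
LR_ACTOR = 1e-4          # actor learning rate
LR_CRITIC = 1e-3         # critic learning rate
WEIGHT_DECAY = 0.0       # critic L2 regularization
UPDATE_EVERY = 20        # learning interval in environment steps

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer of experience tuples, sampled uniformly (sketch)."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience", ["state", "action", "reward", "next_state", "done"])
        random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        experiences = random.sample(self.memory, k=self.batch_size)
        states = torch.from_numpy(np.vstack([e.state for e in experiences])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).float().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([e.next_state for e in experiences])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(device)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)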
Example #12
0
class Agent:
    def __init__(self, replay_buffer, noise, state_dim, action_dim, seed, fc1_units = 256, fc2_units = 128,
                 device="cpu", lr_actor=1e-4, lr_critic=1e-3, batch_size=128, discount=0.99, tau=1e-3):
        torch.manual_seed(seed)

        self.actor_local = Actor(state_dim, action_dim, fc1_units, fc2_units, seed).to(device)
        self.critic_local = Critic(state_dim, action_dim, fc1_units, fc2_units, seed).to(device)
        
        self.actor_optimizer = optim.Adam(params=self.actor_local.parameters(), lr=lr_actor)
        self.critic_optimizer = optim.Adam(params=self.critic_local.parameters(), lr=lr_critic)
        
        self.actor_target = Actor(state_dim, action_dim, fc1_units, fc2_units, seed).to(device)
        self.critic_target = Critic(state_dim, action_dim, fc1_units, fc2_units, seed).to(device)

        self.buffer = replay_buffer
        self.noise = noise
        self.device = device
        self.batch_size = batch_size
        self.discount = discount

        self.tau = tau

        Agent.hard_update(model_local=self.actor_local, model_target=self.actor_target)
        Agent.hard_update(model_local=self.critic_local, model_target=self.critic_target)

    def step(self, states, actions, rewards, next_states, dones):
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

        if self.buffer.size() >= self.batch_size:
            experiences = self.buffer.sample(self.batch_size)

            self.learn(self.to_tensor(experiences))

    def to_tensor(self, experiences):
        states, actions, rewards, next_states, dones = experiences

        states = torch.from_numpy(states).float().to(self.device)
        actions = torch.from_numpy(actions).float().to(self.device)
        rewards = torch.from_numpy(rewards).float().to(self.device)
        next_states = torch.from_numpy(next_states).float().to(self.device)
        dones = torch.from_numpy(dones.astype(np.uint8)).float().to(self.device)

        return states, actions, rewards, next_states, dones

    def act(self, states, add_noise=True):
        states = torch.from_numpy(states).float().to(device=self.device)
        self.actor_local.eval()
        with torch.no_grad():
            actions = self.actor_local(states).cpu().data.numpy()
        self.actor_local.train()

        if add_noise:
            actions += self.noise.sample()
        return np.clip(actions, -1, 1)

    def learn(self, experiences):
        states, actions, rewards, next_states, dones = experiences

        # Update critic
        next_actions = self.actor_target(next_states)
        q_target_next = self.critic_target(next_states, next_actions)
        q_target = rewards + self.discount * q_target_next * (1.0 - dones)
        q_local = self.critic_local(states, actions)
        critic_loss = F.mse_loss(input=q_local, target=q_target)

        self.critic_local.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        actor_objective = self.critic_local(states, self.actor_local(states)).mean()
        self.actor_local.zero_grad()
        (-actor_objective).backward()
        self.actor_optimizer.step()

        Agent.soft_update(model_local=self.critic_local, model_target=self.critic_target, tau=self.tau)
        Agent.soft_update(model_local=self.actor_local, model_target=self.actor_target, tau=self.tau)

    @staticmethod
    def soft_update(model_local, model_target, tau):
        for local_param, target_param in zip(model_local.parameters(), model_target.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    @staticmethod
    def hard_update(model_local, model_target):
        Agent.soft_update(model_local=model_local, model_target=model_target, tau=1.0)

    def reset(self):
        self.noise.reset()
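
For context, a hypothetical driver loop showing how an agent with this act/step/reset interface is usually run; env, its reset/step return shapes, and the episode bookkeeping are stand-ins for whatever environment the original code targets, not part of the examples above.

import numpy as np


def train(agent, env, n_episodes=200, max_t=1000):
    """Hypothetical episode loop; assumes env.reset() returns per-agent states
    and env.step(actions) returns (next_states, rewards, dones)."""
    scores = []
    for i_episode in range(1, n_episodes + 1):
        states = env.reset()      # shape: (num_agents, state_dim)
        agent.reset()             # reset the exploration noise process
        score = 0.0
        for _ in range(max_t):
            actions = agent.act(states, add_noise=True)
            next_states, rewards, dones = env.step(actions)
            agent.step(states, actions, rewards, next_states, dones)
            states = next_states
            score += float(np.mean(rewards))
            if np.any(dones):
                break
        scores.append(score)
    return scores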
Example #13
0
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, num_agents, random_seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            num_agents (int): number of agents
            random_seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(random_seed)
        self.eps = eps_start
        self.eps_decay = 1 / (eps_p * LEARN_NUM)  # set decay rate based on epsilon end target

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise((num_agents, action_size), random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

    def step(self, state, action, reward, next_state, done, agent_number,
             timestep):
        """Save experience in replay memory, and use random sample from buffer to learn."""

        # Save experience / reward
        self.memory.add(state, action, reward, next_state, done)
        # Learn, if enough samples are available in memory and at learning interval settings
        if len(self.memory) > BATCH_SIZE and timestep % 1 == 0:
            for _ in range(LEARN_NUM):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA, agent_number)

    def act(self, states, add_noise):
        """Returns actions for both agents as per current policy, given their respective states."""
        states = torch.from_numpy(states).float().to(device)
        actions = np.zeros((self.num_agents, self.action_size))
        self.actor_local.eval()
        with torch.no_grad():
            # get action for each agent and concatenate them
            for agent_num, state in enumerate(states):
                action = self.actor_local(state).cpu().data.numpy()
                actions[agent_num, :] = action
        self.actor_local.train()

        # add noise to actions
        if add_noise:
            actions += self.eps * self.noise.sample()
        actions = np.clip(actions, -1, 1)
        return actions

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma, agent_number):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        # Construct next actions vector relative to the agent
        if agent_number != 0:
            actions_next = torch.cat((actions[:, :2], actions_next), dim=1)
        else:
            actions_next = torch.cat((actions_next, actions[:, 2:]), dim=1)

        # Compute Q targets for current states (y_i)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        # Construct action prediction vector relative to each agent

        if agent_number != 0:
            actions_pred = torch.cat((actions[:, :2], actions_pred), dim=1)
        else:
            actions_pred = torch.cat((actions_pred, actions[:, 2:]), dim=1)

        # Compute actor loss
        actor_loss = -self.critic_local(states, actions_pred).mean()

        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

        # update noise decay parameter
        self.eps -= self.eps_decay
        self.eps = max(self.eps, eps_end)
        self.noise.reset()

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
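
Finally, every example assumes Actor and Critic network classes defined elsewhere, with constructor signatures that vary between snippets. A minimal sketch of the architecture the (state_size, action_size, seed) examples appear to expect: two hidden layers, a tanh-bounded actor, and a critic that concatenates the action after the first layer. Layer sizes and the fan-in initialization are common defaults, not the originals.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def hidden_init(layer):
    """Uniform fan-in initialization range often used for DDPG hidden layers."""
    fan_in = layer.weight.data.size(1)
    lim = 1.0 / np.sqrt(fan_in)
    return (-lim, lim)


class Actor(nn.Module):
    """Deterministic policy: state -> action in [-1, 1] (sketch)."""

    def __init__(self, state_size, action_size, seed, fc1_units=256, fc2_units=128):
        super().__init__()
        torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc3(x))


class Critic(nn.Module):
    """Q-network: (state, action) -> scalar value (sketch)."""

    def __init__(self, state_size, action_size, seed, fc1_units=256, fc2_units=128):
        super().__init__()
        torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units + action_size, fc2_units)
        self.fc3 = nn.Linear(fc2_units, 1)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, state, action):
        xs = F.relu(self.fc1(state))
        x = torch.cat((xs, action), dim=1)
        x = F.relu(self.fc2(x))
        return self.fc3(x)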