Example #1
class TestActor(unittest.TestCase):
    def setUp(self):
        self.state_dim = 10

        self.actor = Actor(action_dim=3,
                           state_dim=self.state_dim,
                           fc1_units=10,
                           fc2_units=10,
                           seed=0)

    def test_forward(self):
        n = 2
        batch = torch.tensor(np.random.random_sample((n, self.state_dim)),
                             dtype=torch.float)

        actions_1, probs_1, _ = self.actor.forward(batch)
        actions_2, probs_2, _ = self.actor.forward(batch, actions_1.view(-1))
        np.testing.assert_array_equal(probs_1.cpu().detach().numpy(),
                                      probs_2.cpu().detach().numpy())
Example #2
class TestActor(unittest.TestCase):
    def setUp(self):
        self.state_dim = (2, 80, 80)

        self.actor = Actor(action_dim=3)

    def test_forward(self):
        n = 2
        batch = torch.tensor(np.random.random_sample((n, ) + self.state_dim),
                             dtype=torch.float)

        actions, probs = self.actor.forward(batch)
        self.assertEqual((n, 1), actions.size())
        self.assertEqual((n, 1), probs.size())
Example #3
class DDPG:
    def __init__(self, state_dim, action_dim):
        self.critic = Critic(state_dim, action_dim).to(device)
        self.target_c = copy.deepcopy(self.critic)

        self.actor = Actor(state_dim).to(device)
        self.target_a = copy.deepcopy(self.actor)

        self.optimizer_c = optim.Adam(self.critic.parameters(), lr=LR)
        self.optimizer_a = optim.Adam(self.actor.parameters(), lr=LR)

    def act(self, state):
        state = torch.from_numpy(np.array(state)).float().to(device)
        return self.actor.forward(state).detach().squeeze(0).cpu().numpy()

    def update(self, batch):
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.from_numpy(np.array(states)).float().to(device)
        actions = torch.from_numpy(np.array(actions)).float().to(device)
        rewards = torch.from_numpy(
            np.array(rewards)).float().to(device).unsqueeze(1)
        next_states = torch.from_numpy(
            np.array(next_states)).float().to(device)
        dones = torch.from_numpy(np.array(dones, dtype=np.float32)).to(device).unsqueeze(1)

        Q_current = self.critic(states, actions)
        Q_next = self.target_c(next_states,
                               self.target_a(next_states).detach())
        # mask out the bootstrap term for terminal transitions
        y = (rewards + GAMMA * Q_next * (1 - dones)).detach()

        ##################Update critic#######################
        loss_c = F.mse_loss(y, Q_current)
        self.optimizer_c.zero_grad()
        loss_c.backward()
        self.optimizer_c.step()

        ##################Update actor#######################
        loss_a = -self.critic.forward(states, self.actor(states)).mean()
        self.optimizer_a.zero_grad()
        loss_a.backward()
        self.optimizer_a.step()

        ##################Update targets#######################
        for target_pr, pr in zip(self.target_a.parameters(),
                                 self.actor.parameters()):
            target_pr.data.copy_(TAU * pr.data + (1 - TAU) * target_pr.data)

        for target_pr, pr in zip(self.target_c.parameters(),
                                 self.critic.parameters()):
            target_pr.data.copy_(TAU * pr.data + (1 - TAU) * target_pr.data)
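The DDPG class above relies on module-level names (LR, GAMMA, TAU, device) and imports that the excerpt does not show. A minimal sketch of that setup, with illustrative placeholder values rather than the project's own, could look like this:

# Hypothetical module-level setup assumed by the DDPG class above.
# Values are illustrative placeholders, not taken from the original project.
import copy

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

LR = 1e-3       # learning rate shared by both optimizers
GAMMA = 0.99    # discount factor
TAU = 1e-3      # soft-update interpolation factor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")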
Example #4
class TestActor(unittest.TestCase):
    def setUp(self):
        self.state_dim = 24
        self.action_dim = 2

        self.actor = Actor(state_dim=self.state_dim,
                           action_dim=self.action_dim,
                           fc1_units=64,
                           fc2_units=64,
                           seed=0)

    def test_forward(self):
        n = 2
        states = torch.tensor(np.random.random_sample((n, self.state_dim)),
                              dtype=torch.float)

        actions = self.actor.forward(states)
        self.assertEqual((n, self.action_dim), actions.size())
Example #5
        Q_S = Critic_T.forward(next_state, numpy.array([S_action])).item()
        if not is_done:
            Y.append(reward + GAMMA * Q_S)
        else:
            Y.append(reward)

    Q_s = critic.forward(numpy.array(states), numpy.array(actions))
    loss_critic = LOSS(torch.Tensor(Y), Q_s)
    Critic_optim.zero_grad()
    loss_critic.backward()
    Critic_optim.step()

    S = numpy.array(states)
    loss_actor = -critic.forward(S, actor.forward(S)).mean()
    Actor_optim.zero_grad()
    loss_actor.backward()
    Actor_optim.step()


episodes_cnt = 0
while (episodes_cnt <= MAX_EPISODES):
    episodes_cnt += 1
    frame = env.reset()
    is_done = False
    step_cnt = 0
    while (not is_done and step_cnt < MAX_EPISODES_STEPS):
        step_cnt += 1
        actor.eval()
        action = actor.forward(frame)[0].item()
        new_frame, reward, is_done, _ = env.step([action])
        BUFFER.append((frame, action, reward, new_frame, is_done))
        if len(BUFFER) > 64:
            train()
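This excerpt is truncated and leans on module-level names such as GAMMA, LOSS, BUFFER, MAX_EPISODES, MAX_EPISODES_STEPS, and env; the actor, critic, and their optimizers are built elsewhere in the original file. A rough, hypothetical sketch of the kind of setup it assumes:

# Hypothetical setup for the truncated training script above; the actor, critic,
# and their optimizers are assumed to be constructed elsewhere in the file.
from collections import deque

import gym
import torch

GAMMA = 0.99
MAX_EPISODES = 1000
MAX_EPISODES_STEPS = 1000
LOSS = torch.nn.MSELoss()            # assumed to be an MSE-style criterion
BUFFER = deque(maxlen=100_000)       # replay memory of (s, a, r, s', done) tuples
env = gym.make("Pendulum-v0")        # placeholder env; the loop above expects the old 4-tuple step API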
Example #6
File: ddpg.py Project: YuanyeMa/RL
class DDPGAgent:
    def __init__(self,
                 plot=True,
                 seed=1,
                 env: gym.Env = None,
                 batch_size=128,
                 learning_rate_actor=0.001,
                 learning_rate_critic=0.001,
                 weight_decay=0.01,
                 gamma=0.999):

        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        self.batch_size = batch_size
        self.learning_rate_actor = learning_rate_actor
        self.learning_rate_critic = learning_rate_critic
        self.weight_decay = weight_decay
        self.gamma = gamma
        self.tau = 0.001

        self._to_tensor = util.to_tensor
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.actor = Actor(self.state_dim, self.action_dim).to(self.device)
        self.target_actor = Actor(self.state_dim,
                                  self.action_dim).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                self.learning_rate_actor,
                                                weight_decay=self.weight_decay)

        self.critic = Critic(self.state_dim, self.action_dim).to(self.device)
        self.target_critic = Critic(self.state_dim,
                                    self.action_dim).to(self.device)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(),
            self.learning_rate_critic,
            weight_decay=self.weight_decay)

        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)
        self.t = 0

    def _learn_from_memory(self, memory):
        '''Learn from the replay memory and update the parameters of both networks.
        '''
        # Randomly sample transitions from the replay memory
        trans_pieces = memory.sample(self.batch_size)
        s0 = np.vstack([x.state for x in trans_pieces])
        a0 = np.vstack([x.action for x in trans_pieces])
        r1 = np.vstack([x.reward for x in trans_pieces])
        s1 = np.vstack([x.next_state for x in trans_pieces])
        terminal_batch = np.vstack([x.is_done for x in trans_pieces])

        # Optimize the critic network parameters
        s1 = self._to_tensor(s1, device=self.device)
        s0 = self._to_tensor(s0, device=self.device)

        next_q_values = self.target_critic.forward(
            state=s1, action=self.target_actor.forward(s1)).detach()
        target_q_batch = self._to_tensor(r1, device=self.device) + \
            self.gamma*self._to_tensor(terminal_batch.astype(np.float32), device=self.device)*next_q_values
        q_batch = self.critic.forward(s0,
                                      self._to_tensor(a0, device=self.device))

        # Compute the critic loss and update the critic network parameters
        loss_critic = F.mse_loss(q_batch, target_q_batch)
        #self.critic_optimizer.zero_grad()
        self.critic.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Backpropagation: use the critic's value estimate of the state as the policy objective
        loss_actor = -self.critic.forward(s0, self.actor.forward(s0))  # gradient ascent on Q
        loss_actor = loss_actor.mean()
        self.actor.zero_grad()
        #self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Soft-update the target network parameters
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
        return (loss_critic.item(), loss_actor.item())

    def learning(self, memory):
        self.actor.train()
        return self._learn_from_memory(memory)

    def save_models(self, episode_count):
        torch.save(self.target_actor.state_dict(),
                   './Models/' + str(episode_count) + '_actor.pt')
        torch.save(self.target_critic.state_dict(),
                   './Models/' + str(episode_count) + '_critic.pt')

    def load_models(self, episode):
        self.actor.load_state_dict(
            torch.load('./Models/' + str(episode) + '_actor.pt'))
        self.critic.load_state_dict(
            torch.load('./Models/' + str(episode) + '_critic.pt'))
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)
        print('Models loaded successfully')
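This agent calls hard_update and soft_update helpers that are not shown. A minimal sketch consistent with the call sites above (signatures inferred from usage, not from the original util module):

# Minimal sketch of the hard_update / soft_update helpers this agent calls;
# signatures are inferred from the call sites above.
import torch.nn as nn


def hard_update(target: nn.Module, source: nn.Module) -> None:
    """Copy every parameter of `source` into `target` verbatim."""
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(s.data)


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target."""
    for t, s in zip(target.parameters(), source.parameters()):
        t.data.copy_(tau * s.data + (1.0 - tau) * t.data)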
Example #7
def main(args):

    with open(args.data_dir+'/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    # required to map between integer-value sentences and real sentences
    w2i, i2w = vocab['w2i'], vocab['i2w']

    # make sure our models for the VAE and Actor exist
    if not os.path.exists(args.load_vae):
        raise FileNotFoundError(args.load_vae)

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
    )

    model.load_state_dict(
        torch.load(args.load_vae, map_location=lambda storage, loc: storage))
    model.eval()
    print("vae model loaded from %s"%(args.load_vae))

    # to run in constraint mode, we need the trained generator
    if args.constraint_mode:
        if not os.path.exists(args.load_actor):
            raise FileNotFoundError(args.load_actor)

        actor = Actor(
            dim_z=args.latent_size, dim_model=2048, num_labels=args.n_tags)
        actor.load_state_dict(
            torch.load(args.load_actor, map_location=lambda storage, loc:storage))
        actor.eval()
        print("actor model loaded from %s"%(args.load_actor))

    if torch.cuda.is_available():
        model = model.cuda()
        if args.constraint_mode:
            actor = actor.cuda() # TODO: to(self.devices)

    if args.sample:
        print('*** SAMPLE Z: ***')
        # get samples from the prior
        sample_sents, z = model.inference(n=args.num_samples)
        sample_sents, sample_tags = get_sents_and_tags(sample_sents, i2w, w2i)
        pickle_it(z.cpu().numpy(), 'samples/z_sample_n{}.pkl'.format(args.num_samples))
        pickle_it(sample_sents, 'samples/sents_sample_n{}.pkl'.format(args.num_samples))
        pickle_it(sample_tags, 'samples/tags_sample_n{}.pkl'.format(args.num_samples))
        print(*sample_sents, sep='\n')

        if args.constraint_mode:

            print('*** SAMPLE Z_PRIME: ***')
            # get samples from the prior, conditioned via the actor
            all_tags_sample_prime = []
            all_sents_sample_prime = {}
            all_z_sample_prime = {}
            for i, condition in enumerate(LABELS):

                # binary vector denoting each of the PHRASE_TAGS
                labels = torch.Tensor(condition).repeat(args.num_samples, 1).cuda()

                # take z and manipulate using the actor to generate z_prime
                z_prime = actor.forward(z, labels)

                sample_sents_prime, z_prime = model.inference(
                    z=z_prime, n=args.num_samples)
                sample_sents_prime, sample_tags_prime = get_sents_and_tags(
                    sample_sents_prime, i2w, w2i)
                print('conditioned on: {}'.format(condition))
                print(*sample_sents_prime, sep='\n')
                all_tags_sample_prime.append(sample_tags_prime)
                all_sents_sample_prime[LABEL_NAMES[i]] = sample_sents_prime
                all_z_sample_prime[LABEL_NAMES[i]] = z_prime.data.cpu().numpy()
            pickle_it(all_tags_sample_prime, 'samples/tags_sample_prime_n{}.pkl'.format(args.num_samples))
            pickle_it(all_sents_sample_prime, 'samples/sents_sample_prime_n{}.pkl'.format(args.num_samples))
            pickle_it(all_z_sample_prime, 'samples/z_sample_prime_n{}.pkl'.format(args.num_samples))

    if args.interpolate:
        # get random samples from the latent space
        z1 = torch.randn([args.latent_size]).numpy()
        z2 = torch.randn([args.latent_size]).numpy()
        z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=args.num_samples-2)).float())

        print('*** INTERP Z: ***')
        interp_sents, _ = model.inference(z=z)
        interp_sents, interp_tags = get_sents_and_tags(interp_sents, i2w, w2i)
        pickle_it(z.cpu().numpy(), 'samples/z_interp_n{}.pkl'.format(args.num_samples))
        pickle_it(interp_sents, 'samples/sents_interp_n{}.pkl'.format(args.num_samples))
        pickle_it(interp_tags, 'samples/tags_interp_n{}.pkl'.format(args.num_samples))
        print(*interp_sents, sep='\n')

        if args.constraint_mode:
            print('*** INTERP Z_PRIME: ***')
            all_tags_interp_prime = []
            all_sents_interp_prime = {}
            all_z_interp_prime = {}

            for i, condition in enumerate(LABELS):

                # binary vector denoting each of the PHRASE_TAGS
                labels = torch.Tensor(condition).repeat(args.num_samples, 1).cuda()

                # z prime conditioned on this particular binary variable
                z_prime = actor.forward(z, labels)

                interp_sents_prime, z_prime = model.inference(
                    z=z_prime, n=args.num_samples)
                interp_sents_prime, interp_tags_prime = get_sents_and_tags(
                    interp_sents_prime, i2w, w2i)
                print('conditioned on: {}'.format(condition))
                print(*interp_sents_prime, sep='\n')
                all_tags_interp_prime.append(interp_tags_prime)
                all_sents_interp_prime[LABEL_NAMES[i]] = interp_sents_prime
                all_z_interp_prime[LABEL_NAMES[i]] = z_prime.data.cpu().numpy()

            pickle_it(all_tags_interp_prime, 'samples/tags_interp_prime_n{}.pkl'.format(args.num_samples))
            pickle_it(all_sents_interp_prime, 'samples/sents_interp_prime_n{}.pkl'.format(args.num_samples))
            pickle_it(all_z_interp_prime, 'samples/z_interp_prime_n{}.pkl'.format(args.num_samples))

    import IPython; IPython.embed()
Example #8
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, random_seed):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            random_seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(random_seed)

        ####self.num_agents = num_agents

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size,
                                 random_seed).to(device)
        self.actor_target = Actor(state_size, action_size,
                                  random_seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size,
                                   random_seed).to(device)
        self.critic_target = Critic(state_size, action_size,
                                    random_seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise(action_size, random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

    def step(self, state, action, reward, next_state, done):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        # Save experience / reward
        self.memory.add(state, action, reward, next_state, done)

        # Learn, if enough samples are available in memory
        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            action += self.noise.sample()

        return np.clip(action, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target.forward(next_states)
        Q_targets_next = self.critic_target.forward(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local.forward(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)

        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local.forward(states)

        actor_loss = -self.critic_local.forward(states, actions_pred).mean()

        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
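The agent above depends on an OUNoise helper that is not shown. Below is a minimal Ornstein-Uhlenbeck noise sketch consistent with the OUNoise(size, seed) / sample() / reset() usage; the theta and sigma defaults are conventional choices, not values taken from this project.

# Sketch of an Ornstein-Uhlenbeck noise process, inferred from the call sites above.
import copy
import random

import numpy as np


class OUNoise:
    def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        random.seed(seed)
        self.reset()

    def reset(self):
        # Reset the internal state to the long-running mean.
        self.state = copy.copy(self.mu)

    def sample(self):
        # dx = theta * (mu - x) + sigma * N(0, I); then x <- x + dx
        dx = self.theta * (self.mu - self.state) \
            + self.sigma * np.random.standard_normal(self.mu.shape)
        self.state = self.state + dx
        return self.state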
Example #9
class Agent():
    def __init__(self, state_size, action_size, num_agents, seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            num_agents (int): number of agents
            seed (int): random seed
        """
        self.seed = random.seed(seed)
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents

        #Actor Network
        self.actor_local = Actor(state_size, action_size, seed).to(device)
        self.actor_target = Actor(state_size, action_size, seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        #Critic Network
        self.critic_local = Critic(state_size, action_size, seed).to(device)
        self.critic_target = Critic(state_size, action_size, seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise((num_agents, action_size), seed)

        # Replay memory
        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, seed)

    def act(self, state, add_noise=True):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        action = np.zeros((self.num_agents, self.action_size))
        self.actor_local.eval()  # set module to evaluation mode
        with torch.no_grad():
            for agent_idx, state_ in enumerate(state):
                action[agent_idx, :] = self.actor_local.forward(
                    state_).cpu().data.numpy()
        self.actor_local.train()  # reset it back to training mode

        if add_noise:
            action += self.noise.sample()

        return np.clip(action, -1, 1)  # restrict the output boundary -1, 1

    def reset(self):
        self.noise.reset()

    def step(self, state, action, reward, next_state, done, timeStep):
        """Save experience in replay memory, and use random sample from buffer to updateWeight_local."""
        for i in range(self.num_agents):
            self.memory.add(state[i, :], action[i, :], reward[i],
                            next_state[i, :], done[i])
        if len(self.memory) > BATCH_SIZE and timeStep % 2 == 0:
            self.updateWeight_local(self.memory.sample(), GAMMA)

    def updateWeight_local(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
           actor_target(state) -> action
           critic_target(state, action) -> Q-value

        Params
        ======
           experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
           gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models

        next_actions = self.actor_target(next_states)  # Sarsa?
        Q_target_next = self.critic_target.forward(next_states, next_actions)
        Q_target = rewards + gamma * Q_target_next * (1 - dones)
        Q_local = self.critic_local.forward(states, actions)
        critic_loss = F.mse_loss(Q_local, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local.forward(states)
        actor_loss = -self.critic_local(
            states,
            actions_pred).mean()  # negative sign: maximize return via gradient ascent on Q
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.updateWeight_target(self.critic_local, self.critic_target, TAU)
        self.updateWeight_target(self.actor_local, self.actor_target, TAU)

    def updateWeight_target(self, local_model, target_model, tau):
        """Soft update TARGET model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
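The ReplayBuffer used above is also not shown. A minimal sketch matching the ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, seed) constructor and the add(), sample(), and len() calls; the original project's implementation may differ in detail:

# Minimal replay buffer sketch, inferred from the call sites above.
import random
from collections import deque, namedtuple

import numpy as np
import torch

Experience = namedtuple("Experience",
                        ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    def __init__(self, buffer_size, batch_size, seed):
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self):
        batch = random.sample(self.memory, k=self.batch_size)

        def stack(xs):
            return torch.from_numpy(np.vstack(xs)).float()

        states = stack([e.state for e in batch])
        actions = stack([e.action for e in batch])
        rewards = stack([e.reward for e in batch])
        next_states = stack([e.next_state for e in batch])
        dones = stack([float(e.done) for e in batch])
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)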
Example #10
class DDPG:
    def __init__(self, args):
        """
            init function
            Args:
                - args: class with args parameter
        """
        self.state_size = args.state_size
        self.action_size = args.action_size
        self.bs = args.bs
        self.gamma = args.gamma
        self.epsilon = args.epsilon
        self.tau = args.tau
        self.discrete = args.discrete
        self.randomer = OUNoise(args.action_size)
        self.buffer = ReplayBuffer(args.max_buff)

        self.actor = Actor(self.state_size, self.action_size)
        self.actor_target = Actor(self.state_size, self.action_size)
        self.actor_opt = AdamW(self.actor.parameters(), args.lr_actor)

        self.critic = Critic(self.state_size, self.action_size)
        self.critic_target = Critic(self.state_size, self.action_size)
        self.critic_opt = AdamW(self.critic.parameters(), args.lr_critic)

        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)

    def reset(self):
        """
            reset noise and model
        """
        self.randomer.reset()

    def get_action(self, state):
        """
            get distribution of action
            Args:
                - state: list, shape == [state_size]
        """
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        action = self.actor(state).detach()
        action = action.squeeze(0).numpy()
        action += self.epsilon * self.randomer.noise()
        action = np.clip(action, -1.0, 1.0)
        return action

    def learning(self):
        """
            learn models
        """
        s1, a1, r1, t1, s2 = self.buffer.sample_batch(self.bs)
        # invert terminal flags: 1 for non-terminal, 0 for terminal transitions
        t1 = 1 - t1
        s1 = torch.tensor(s1, dtype=torch.float)
        a1 = torch.tensor(a1, dtype=torch.float)
        r1 = torch.tensor(r1, dtype=torch.float)
        t1 = torch.tensor(t1, dtype=torch.float)
        s2 = torch.tensor(s2, dtype=torch.float)

        a2 = self.actor_target(s2).detach()
        q2 = self.critic_target(s2, a2).detach()
        q2_plus_r = r1[:, None] + t1[:, None] * self.gamma * q2
        q1 = self.critic.forward(s1, a1)

        # critic gradient
        critic_loss = nn.MSELoss()
        loss_critic = critic_loss(q1, q2_plus_r)
        self.critic_opt.zero_grad()
        loss_critic.backward()
        self.critic_opt.step()

        # actor gradient
        pred_a = self.actor.forward(s1)
        loss_actor = (-self.critic.forward(s1, pred_a)).mean()
        self.actor_opt.zero_grad()
        loss_actor.backward()
        self.actor_opt.step()

        # Only the actor and critic receive gradient updates (via actor_opt.step()
        # and critic_opt.step()); the target networks are only soft-updated below.
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

        return loss_actor.item(), loss_critic.item()
Example #11
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, seed, num_agents=20):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        print("Running on: " + str(device))

        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(seed)
        self.eps = EPS_START
        self.eps_decay = 0.0005
        # Actor network
        self.actor_local = Actor(state_size, action_size, seed).to(device)
        self.actor_target = Actor(state_size, action_size, seed).to(device)
        self.actor_optim = optim.Adam(self.actor_local.parameters(),
                                      lr=LR_ACTOR)

        # Critic network
        self.critic_local = Critic(state_size, action_size, seed).to(device)
        self.critic_target = Critic(state_size, action_size, seed).to(device)
        self.critic_optim = optim.Adam(self.critic_local.parameters(),
                                       lr=LR_CRITIC)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

        self.noise = OUNoise((num_agents, action_size), seed)

    def step(self, state, action, reward, next_state, done, agent_id):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        self.t_step += 1
        # Learn every UPDATE_EVERY time steps.
        if (self.t_step % UPDATE_EVERY) == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                for _ in range(LEARN_NUM):
                    experiences = self.memory.sample()
                    self.learn(experiences, GAMMA, agent_id)

    def act(self, states, add_noise=True):
        """Returns actions for given state as per current policy."""
        states = torch.from_numpy(states).float().to(device)
        actions = np.zeros((self.num_agents, self.action_size))

        self.actor_local.eval()
        with torch.no_grad():
            for i, state in enumerate(states):
                actions[i, :] = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()

        if add_noise:
            actions += self.eps * self.noise.sample()
        return np.clip(actions, -1, 1)

    def learn(self, experiences, gamma, agent_id):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ------------------- update critic network ------------------- #
        target_actions = self.actor_target.forward(next_states)
        # Construct next actions vector relative to the agent
        if agent_id == 0:
            target_actions = torch.cat((target_actions, actions[:, 2:]), dim=1)
        else:
            target_actions = torch.cat((actions[:, :2], target_actions), dim=1)

        next_critic_value = self.critic_target.forward(next_states,
                                                       target_actions)
        critic_value = self.critic_local.forward(states, actions)
        # Q targets for current state
        # If the episode is over, the reward from the future state will not be incorporated
        Q_targets = rewards + (gamma * next_critic_value * (1 - dones))

        critic_loss = F.mse_loss(critic_value, Q_targets)
        # Minimizing loss
        self.critic_local.train()
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optim.step()

        self.critic_local.eval()

        # ------------------- update actor network ------------------- #
        self.actor_local.train()
        self.actor_optim.zero_grad()
        mu = self.actor_local.forward(states)
        # Construct mu vector relative to each agent
        if agent_id == 0:
            mu = torch.cat((mu, actions[:, 2:]), dim=1)
        else:
            mu = torch.cat((actions[:, :2], mu), dim=1)

        actor_loss = -self.critic_local(states, mu).mean()
        actor_loss.backward()
        self.actor_optim.step()

        self.actor_local.eval()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

        # update noise decay parameter
        self.eps -= self.eps_decay
        self.eps = max(self.eps, EPS_FINAL)
        self.noise.reset()

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def reset(self):
        self.noise.reset()
Example #12
class Agent():
    def __init__(self,
                 action_space_shape,
                 observation_space_shape,
                 n_train_steps=50 * 1000000,
                 replay_memory_size=1000000,
                 k=3):

        # Cuda
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Hyperparameters - dynamic
        self.action_space_shape = action_space_shape
        self.observation_space_shape = observation_space_shape
        self.k = k
        self.observation_input_shape = multiply_tuple(
            self.observation_space_shape, self.k)
        self.n_train_steps = n_train_steps
        self.replay_memory_size = replay_memory_size
        self.replay_memory = deque(maxlen=self.replay_memory_size)

        # Hyperparameters - static
        self.training_start_time_step = 1000  # Minimum: k * minibatch_size == 3 * 64 = 192
        self.gamma = 0.99  # For reward discount
        self.tau = 0.001  # For soft update

        # Hyperparameters - Ornstein_Uhlenbeck_noise
        self.theta = 0.15
        self.sigma = 0.2
        self.Ornstein_Uhlenbeck_noise = OUNoise(
            action_space_shape=self.action_space_shape,
            theta=self.theta,
            sigma=self.sigma)

        # Hyperparameters - NN model
        self.minibatch_size = 64  # For training NN
        self.lr_actor = 10e-4
        self.lr_critic = 10e-3
        self.weight_decay_critic = 10e-2

        # Parameters - etc
        self.action = None
        self.time_step = 0
        self.train_step = 0
        self.train_complete = False

        # Modules
        self.actor = Actor(
            action_space_shape=self.action_space_shape,
            observation_space_shape=self.observation_input_shape).to(
                self.device)
        self.critic = Critic(
            action_space_shape=self.action_space_shape,
            observation_space_shape=self.observation_input_shape).to(
                self.device)
        self.actor_hat = copy.deepcopy(self.actor)
        self.critic_hat = copy.deepcopy(self.critic)

        self.optimizer_actor = optim.Adam(self.actor.parameters(),
                                          lr=self.lr_actor)
        self.optimizer_critic = optim.Adam(
            self.critic.parameters(),
            lr=self.lr_critic,
            weight_decay=self.weight_decay_critic)

        # Operations
        self.mode('train')

    def reset(self, observation):

        self.previous_observation = torch.tensor([observation] * self.k).to(
            dtype=torch.float,
            device=self.device).view(self.observation_input_shape)
        self.observation_buffer = list()
        self.reward = torch.tensor([0])  # Tensor form for compatibility
        self.Ornstein_Uhlenbeck_noise.reset()

        # Since replay memory is somewhat full, we can decrease waiting time for sufficient data to fill in the replay memory.
        self.training_start_time_step = max(
            0, self.training_start_time_step - self.time_step)
        self.time_step = 0
        # Don't reset replay_memory
        # self.replay_memory = deque(maxlen = self.replay_memory_size)

    def mode(self, mode):

        # Store under a separate attribute so this method is not shadowed.
        self._mode = mode
        if self._mode == 'train':
            pass
        elif self._mode == 'test':
            pass
        else:
            assert False, 'mode not specified'

    def wakeup(self):

        # Frame skipping
        # Observe & select a new action every k-th frame;
        # otherwise, skip the frame and repeat the previous action.
        if self.time_step % self.k == 0:
            return True
        else:
            return False

    def act(self):

        if self.wakeup() == True:
            self.action = self.actor.forward(
                self.previous_observation) + torch.as_tensor(
                    self.Ornstein_Uhlenbeck_noise(),
                    dtype=torch.float,
                    device=self.device)

        self.time_step += 1

        # Return numpy version
        return self.action.detach().cpu().numpy()

    def observe(self, observation, reward):

        if self.wakeup() == True:

            # Append observation
            self.observation_buffer.append(observation)
            self.new_observation = torch.tensor(self.observation_buffer).to(
                dtype=torch.float,
                device=self.device).view(self.observation_input_shape)

            # Add reward
            self.reward += reward

            # Store transition in replay memory
            # If memory size exceeds, the oldest memory is popped (deque property)
            # wrap self.action with torch.tensor() to reset requires_grad = False
            self.replay_memory.append(
                (self.previous_observation, self.action.clone().detach(),
                 self.reward, self.new_observation
                 ))  # self.action.new_tensor() == self.action.clone().detach()

            # The new observation will be the previous observation next time
            self.previous_observation = self.new_observation

            # Empty observation buffer, reset reward
            self.observation_buffer = list()
            self.reward = torch.tensor([0])  # Tensor form for compatibility

        else:

            self.observation_buffer.append(observation)
            self.reward += reward

    def random_sample_data(self):

        memory_size = len(self.replay_memory)

        # state, action, reward, state_next
        s_i = list()
        a_i = list()
        r_i = list()
        s_i_1 = list()

        # Random Sample transitions, append them into np arrays
        random_index = np.random.choice(
            memory_size, size=self.minibatch_size, replace=False
        )  # random_index: [0,5,4,9, ...] // "replace = False" makes the indices exclusive.

        for index in random_index:

            # Random sample transitions, 'minibatch' times
            s, a, r, s_1 = self.replay_memory[index]
            s_i.append(
                s
            )  # s_i Equivalent to [self.replay_memory[index][0] for index in random_index]
            a_i.append(a)
            r_i.append(r)
            s_i_1.append(s_1)

        s_i = torch.stack(s_i).to(dtype=torch.float, device=self.device)
        a_i = torch.stack(a_i).to(dtype=torch.float, device=self.device)
        r_i = torch.stack(r_i).to(dtype=torch.float, device=self.device)
        s_i_1 = torch.stack(s_i_1).to(dtype=torch.float, device=self.device)

        return s_i, a_i, r_i, s_i_1

    def train(self):

        if self.wakeup(
        ) == True and self.time_step >= self.training_start_time_step:
            # 1. Sample random minibatch of transitions from replay memory
            # state, action, reward, state_next
            s_i, a_i, r_i, s_i_1 = self.random_sample_data(
            )  # **minibatch info included in "self"

            # 2. Set y_i
            y_i = (r_i + self.gamma * self.critic_hat.forward(
                s_i_1, self.actor_hat.forward(s_i_1))).detach()

            # 3. Calculate Loss
            self.optimizer_critic.zero_grad()
            critic_loss = F.mse_loss(y_i, self.critic.forward(s_i, a_i))

            # 4. Update Critic
            critic_loss.backward()
            self.optimizer_critic.step()

            # 5. Update Actor
            self.optimizer_actor.zero_grad()
            critic_Q_mean = -self.critic.forward(
                s_i, self.actor.forward(s_i)).mean()
            critic_Q_mean.backward()
            self.optimizer_actor.step()

            # 6. Soft-update the target networks
            for target_param, param in zip(self.critic_hat.parameters(),
                                           self.critic.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)
            for target_param, param in zip(self.actor_hat.parameters(),
                                           self.actor.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            # 7. Increment train step.
            # If train step meets its scheduled training steps, change "train_complete" status
            self.train_step += 1
            if self.train_step >= self.n_train_steps:
                self.train_complete = True
Example #13
class Agent():
    def __init__(self, state_size, action_size, num_agents, seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            num_agents (int): number of agents
            seed (int): random seed
        """
        self.seed = random.seed(seed)
        self.state_size = state_size  # 24
        self.action_size = action_size  # 2
        self.num_agents = num_agents  # 2
        self.eps = eps_start

        #Actor Network: State -> Action
        self.actor_local = Actor(state_size, action_size, seed).to(device)
        self.actor_target = Actor(state_size, action_size, seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        #Critic Network: State1 x State2 x Action1 x Action2 ... -> Qvalue
        self.critic_local = Critic(state_size * num_agents,
                                   action_size * num_agents, seed).to(device)
        self.critic_target = Critic(state_size * num_agents,
                                    action_size * num_agents, seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise(action_size, seed)

        # Replay memory
        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, seed)

    def act(self, state, add_noise):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()  # set module to evaluation mode
        with torch.no_grad():
            action = self.actor_local.forward(state).cpu().data.numpy()
        self.actor_local.train()  # reset it back to training mode

        if add_noise:
            action += self.noise.sample() * self.eps

        return np.clip(action, -1, 1)  # restrict the output boundary -1, 1

    def reset(self):
        self.noise.reset()

    def step(self, state, action, reward, next_state, done, timestep,
             agent_index):
        """Save experience in replay memory, and use random sample from buffer to updateWeight_local."""
        # for i in range(self.num_agents):
        self.memory.add(state, action, reward, next_state, done)
        if len(self.memory) > BATCH_SIZE and timestep % UPDATE_FREQUENCY == 0:
            self.updateWeight_local(agent_index, self.memory.sample(), GAMMA)

    def updateWeight_local(self, agent_index, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
           actor_target(state) -> action
           critic_target(state, action) -> Q-value

        Params
        ======
           experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
           gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        # states: (batchsize, 24x2)
        # actions: (batchsize, 2x2)
        # rewards: (batchsize, 1x2)
        # next_states: (batchsize, 24x2)
        # dones: (batchsize, 1x2)
        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models

        self_next_actions = self.actor_target(
            next_states[:, self.state_size * agent_index:self.state_size *
                        (agent_index + 1)])  # actor acts on its own observation
        notSelf_actions = actions[:, self.action_size *
                                  (1 - agent_index):self.action_size *
                                  (2 - agent_index)]  # the other agent's actions
        if agent_index == 0:  # concat order follows the agent index
            next_actions = torch.cat((self_next_actions, notSelf_actions),
                                     dim=1).to(device)  # index 0 -> self first
        else:
            next_actions = torch.cat((notSelf_actions, self_next_actions),
                                     dim=1).to(device)  # index 1 -> self second

        Q_target_next = self.critic_target.forward(
            next_states,
            next_actions)  # critic sees both agents' observations and actions
        Q_target = rewards + gamma * Q_target_next * (1 - dones)
        Q_local = self.critic_local.forward(states, actions)
        critic_loss = F.mse_loss(Q_local, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        self_actions_pred = self.actor_local.forward(
            states[:, self.state_size * agent_index:self.state_size *
                   (agent_index + 1)])  # actor acts on its own observation
        notSelf_actions = actions[:, self.action_size *
                                  (1 - agent_index):self.action_size *
                                  (2 - agent_index)]  # the other agent's actions
        if agent_index == 0:
            actions_pred = torch.cat((self_actions_pred, notSelf_actions),
                                     dim=1).to(device)
        else:
            actions_pred = torch.cat((notSelf_actions, self_actions_pred),
                                     dim=1).to(device)

        actor_loss = -self.critic_local(
            states,
            actions_pred).mean()  # negative sign: maximize return via gradient ascent on Q
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_local.parameters(), 1)
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.updateWeight_target(self.critic_local, self.critic_target, TAU)
        self.updateWeight_target(self.actor_local, self.actor_target, TAU)

        # Update epsilon noise value
        self.eps = self.eps - (1 / eps_decay)
        if self.eps < eps_end:
            self.eps = eps_end

    def updateWeight_target(self, local_model, target_model, tau):
        """Soft update TARGET model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
Example #14
class DDPGAgent:

    def __init__(self, env, gamma, tau, buffer_maxlen, critic_learning_rate, actor_learning_rate):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # hyperparameters
        self.env = env
        self.gamma = gamma
        self.tau = tau

        # initialize actor and critic networks
        self.critic = Critic(self.obs_dim, self.action_dim).to(self.device)
        self.critic_target = Critic(self.obs_dim, self.action_dim).to(self.device)

        self.actor = Actor(self.obs_dim, self.action_dim).to(self.device)
        self.actor_target = Actor(self.obs_dim, self.action_dim).to(self.device)

        # Copy critic target parameters
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)

        # optimizers
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        self.noise = OUNoise(self.env.action_space)

    def get_action(self, obs):
        state = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
        action = self.actor.forward(state)
        action = action.squeeze(0).cpu().detach().numpy()

        return action

    def update(self, batch_size):
        state_batch, action_batch, reward_batch, next_state_batch, masks = self.replay_buffer.sample(batch_size)
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        masks = torch.FloatTensor(masks).to(self.device)

        curr_Q = self.critic.forward(state_batch, action_batch)
        next_actions = self.actor_target.forward(next_state_batch)
        next_Q = self.critic_target.forward(next_state_batch, next_actions.detach())
        expected_Q = reward_batch + self.gamma * next_Q

        # update critic
        q_loss = F.mse_loss(curr_Q, expected_Q.detach())

        self.critic_optimizer.zero_grad()
        q_loss.backward()
        self.critic_optimizer.step()

        # update actor
        policy_loss = -self.critic.forward(state_batch, self.actor.forward(state_batch)).mean()

        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))

        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))
Example #15
class DDPGAgent:

    def __init__(self, env, agent_id, actor_lr=1e-4, critic_lr=1e-3, gamma=0.99, tau=1e-2):
        self.env = env
        self.agent_id = agent_id
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gamma = gamma
        self.tau = tau

        self.device = "cpu"
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.device = "cuda"

        self.obs_dim = self.env.observation_space[agent_id].shape[0]
        self.action_dim = self.env.action_space[agent_id].n
        self.num_agents = self.env.n

        self.critic_input_dim = int(np.sum([env.observation_space[agent].shape[0] for agent in range(env.n)]))
        self.actor_input_dim = self.obs_dim

        self.critic = CentralizedCritic(self.critic_input_dim, self.action_dim * self.num_agents).to(self.device)
        self.critic_target = CentralizedCritic(self.critic_input_dim, self.action_dim * self.num_agents).to(self.device)
        self.actor = Actor(self.actor_input_dim, self.action_dim).to(self.device)
        self.actor_target = Actor(self.actor_input_dim, self.action_dim).to(self.device)

        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)
        
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data)
        
        self.MSELoss = nn.MSELoss()
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

    def get_action(self, state):
        state = autograd.Variable(torch.from_numpy(state).float().squeeze(0)).to(self.device)
        action = self.actor.forward(state)
        action = self.onehot_from_logits(action)

        return action
    
    def onehot_from_logits(self, logits, eps=0.0):
        # get best (according to current policy) actions in one-hot form
        argmax_acs = (logits == logits.max(0, keepdim=True)[0]).float()
        if eps == 0.0:
            return argmax_acs
        # get random actions in one-hot form
        rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice(
            range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False)
        # chooses between best and random actions using epsilon greedy
        return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in
                            enumerate(torch.rand(logits.shape[0]))])
    
    def update(self, indiv_reward_batch, indiv_obs_batch, global_state_batch, global_actions_batch, global_next_state_batch, next_global_actions):
        """
        indiv_reward_batch      : only rewards of agent i
        indiv_obs_batch         : only observations of agent i
        global_state_batch      : observations of all agents are concatenated
        global actions_batch    : actions of all agents are concatenated
        global_next_state_batch : observations of all agents are concatenated
        next_global_actions     : actions of all agents are concatenated
        """
        indiv_reward_batch = torch.FloatTensor(indiv_reward_batch).to(self.device)
        indiv_reward_batch = indiv_reward_batch.view(indiv_reward_batch.size(0), 1).to(self.device) 
        indiv_obs_batch = torch.FloatTensor(indiv_obs_batch).to(self.device)          
        global_state_batch = torch.FloatTensor(global_state_batch).to(self.device)    
        global_actions_batch = torch.stack(global_actions_batch).to(self.device)      
        global_next_state_batch = torch.FloatTensor(global_next_state_batch).to(self.device)
        next_global_actions = next_global_actions

        # update critic        
        self.critic_optimizer.zero_grad()
        
        curr_Q = self.critic.forward(global_state_batch, global_actions_batch)
        next_Q = self.critic_target.forward(global_next_state_batch, next_global_actions)
        estimated_Q = indiv_reward_batch + self.gamma * next_Q
        
        critic_loss = self.MSELoss(curr_Q, estimated_Q.detach())
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()

        # update actor
        self.actor_optimizer.zero_grad()

        policy_loss = -self.critic.forward(global_state_batch, global_actions_batch).mean()
        curr_pol_out = self.actor.forward(indiv_obs_batch)
        policy_loss += -(curr_pol_out**2).mean() * 1e-3 
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optimizer.step()
    
    def target_update(self):
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))

        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))