def learn(self):
        # Sample a batch from the replay buffer
        transitions = self.replay.sample()
        states, full_state, actions, rewards, next_states, next_full_state, dones = transpose_to_tensor(
            transitions, self.config.device)

        ### Update online critic model ###
        # Compute actions for next states with the target actor model
        with torch.no_grad():  # don't use gradients for target
            target_next_actions = [
                self.target_actor(next_states[:, i, :])
                for i in range(self.config.num_agents)
            ]

        target_next_actions = torch.cat(target_next_actions, dim=1)

        # Compute Q values for the next states and next actions with the target critic model
        with torch.no_grad():  # don't use gradients for target
            target_next_qs = self.target_critic(
                next_full_state.to(self.config.device),
                target_next_actions.to(self.config.device))

        # Compute the TD targets: summed agent rewards plus the discounted
        # target Q value, masked to zero once any agent's episode is done
        target_qs = (rewards.sum(1, keepdim=True) +
                     self.config.discount * target_next_qs *
                     (1 - dones.max(1, keepdim=True)[0]))

        # Compute Q values for the current states and actions with the online critic model
        actions = actions.view(actions.shape[0], -1)
        online_qs = self.online_critic(full_state.to(self.config.device),
                                       actions.to(self.config.device))

        # Compute and minimize the online critic loss
        online_critic_loss = F.mse_loss(online_qs, target_qs.detach())
        self.online_critic_opt.zero_grad()
        online_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_critic.parameters(), 1)
        self.online_critic_opt.step()

        ### Update online actor model ###
        # Compute actions for the current states with the online actor model
        online_actions = [
            self.online_actor(states[:, i, :])
            for i in range(self.config.num_agents)
        ]
        online_actions = torch.cat(online_actions, dim=1)
        # Compute the online actor loss with the online critic model
        online_actor_loss = -self.online_critic(
            full_state.to(self.config.device),
            online_actions.to(self.config.device)).mean()
        # Minimize the online actor loss
        self.online_actor_opt.zero_grad()
        online_actor_loss.backward()
        self.online_actor_opt.step()

        ### Update target critic and actor models ###
        soft_update(self.target_actor, self.online_actor,
                    self.config.target_mix)
        soft_update(self.target_critic, self.online_critic,
                    self.config.target_mix)
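Every example on this page relies on a `soft_update` (Polyak averaging) helper that is not shown. A minimal sketch of what these call sites appear to assume, written for PyTorch modules with the (target, source, tau) argument order used in Example 1 (note that some snippets, e.g. Examples 12, 19, and 27, pass the online network first and the target second):

import torch

def soft_update(target_net, source_net, tau):
    """Hypothetical Polyak-averaging helper matching the call sites above:
    each target parameter is nudged a fraction tau toward the corresponding
    online parameter, i.e. theta_target <- tau * theta_source + (1 - tau) * theta_target.
    """
    with torch.no_grad():
        for target_param, source_param in zip(target_net.parameters(),
                                              source_net.parameters()):
            target_param.data.copy_(tau * source_param.data +
                                    (1.0 - tau) * target_param.data)

With tau = 1 this is a hard copy of the online weights; the small values used in these examples (e.g. 1e-2 or 1e-3) make the target network trail the online network slowly, which keeps the bootstrapped targets stable.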
Example 2
    def optimize(self):
        """
		Samples a random batch from replay memory and performs optimization
		:return:
		"""
        s1, a1, r1, s2 = self.ram.sample(self.args.batch_size)

        s1 = Variable(torch.from_numpy(s1))
        a1 = Variable(torch.from_numpy(a1))
        r1 = Variable(torch.from_numpy(r1))
        s2 = Variable(torch.from_numpy(s2))

        # ---------------------- optimize critic ----------------------
        # Use target actor exploitation policy here for loss evaluation
        a2 = self.target_actor.forward(s2).detach()
        next_val = torch.squeeze(self.target_critic.forward(s2, a2).detach())
        # y_exp = r + gamma*Q'( s2, pi'(s2))
        y_expected = r1 + self.args.gamma * next_val
        # y_pred = Q( s1, a1)
        y_predicted = torch.squeeze(self.critic.forward(s1, a1))
        # compute critic loss, and update the critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # ---------------------- optimize actor ----------------------
        pred_a1 = self.actor.forward(s1)
        loss_actor = -1 * torch.sum(self.critic.forward(s1, pred_a1))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        utils.soft_update(self.target_actor, self.actor, self.args.tau)
        utils.soft_update(self.target_critic, self.critic, self.args.tau)
Example 3
    def optimize(self):
        s1, a1, r1, s2, done = self.ram.sample(self.batch_size)

        s1 = Variable(torch.tensor(s1)).cuda()
        a1 = Variable(torch.tensor(a1, dtype=torch.int64)).cuda()
        r1 = Variable(torch.tensor(r1)).cuda()
        s2 = Variable(torch.tensor(s2)).cuda()

        self.optimizer.zero_grad()

        # Double DQN target: pick the greedy next action with the online
        # (learning) network, then evaluate it with the target network
        action_predict = torch.argmax(self.learning_net.forward(s2), dim=1)
        r_predict = torch.squeeze(
            self.target_net.forward(s2).gather(1, action_predict.view(-1, 1)))
        r_predict = self.gamma * r_predict
        y_j = r1 + r_predict
        y_j = self.done_state_value(r1, y_j, done)

        # r_ : Q(s_j, a_j)
        r_ = self.learning_net.forward(s1)
        r_ = torch.squeeze(r_.gather(1, a1.view(-1, 1)))

        # loss: (y_j - Q(s_j, a_j))^2
        loss = self.loss_f(y_j, r_)

        loss.backward()
        self.optimizer.step()

        utils.soft_update(self.target_net, self.learning_net, self.tau)
        self.iter += 1
        return loss.cpu()
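Examples 3 and 8 also call a `done_state_value` helper that the page does not show. Assuming its job is to drop the bootstrapped term for terminal transitions, so that the target collapses to the immediate reward, a sketch might look like this (the tensor conversion of `done` is an assumption):

    def done_state_value(self, reward, target, done):
        # Hypothetical: where the episode ended, the target is just the reward,
        # since there is no successor state to bootstrap from.
        done_mask = torch.as_tensor(done, dtype=torch.bool, device=target.device)
        return torch.where(done_mask, reward, target)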
Example 4
    def learn(self):
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        if self.per:
            # batch with indices and priority weights
            batch, indices, weights = \
                self.memory.sample(self.batch_size)
        else:
            batch = self.memory.sample(self.batch_size)
            # set priority weights to 1 when we don't use PER.
            weights = 1.

        q1_loss, q2_loss, errors, mean_q1, mean_q2 =\
            self.calc_critic_loss(batch, weights)
        policy_loss, entropies = self.calc_policy_loss(batch, weights)

        update_params(self.q1_optim, self.critic.Q1, q1_loss, self.grad_clip)
        update_params(self.q2_optim, self.critic.Q2, q2_loss, self.grad_clip)
        update_params(self.policy_optim, self.policy, policy_loss,
                      self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies, weights)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.exp()

        if self.per:
            # update priority weights
            self.memory.update_priority(indices, errors.cpu().numpy())
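Example 4 hands the optimizer steps to an `update_params` helper that is likewise not shown. Given the `(optimizer, network, loss, grad_clip)` call signature above, with the network and clip value omitted for the entropy-temperature update, a plausible sketch is:

from torch.nn.utils import clip_grad_norm_

def update_params(optim, network, loss, grad_clip=None, retain_graph=False):
    # Hypothetical helper: zero gradients, backpropagate, optionally clip the
    # gradient norm of the given network, then take one optimizer step.
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)
    if grad_clip is not None and network is not None:
        clip_grad_norm_(network.parameters(), grad_clip)
    optim.step()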
Example 5
    def update(self, batch, update_actor=True):
        """Updates parameters of TD3 actor and critic given samples from the batch.

    Args:
       batch: A list of timesteps from environment.
       update_actor: a boolean variable, whether to perform a policy update.
    """
        obs = contrib_eager_python_tfe.Variable(
            np.stack(batch.obs).astype('float32'))
        action = contrib_eager_python_tfe.Variable(
            np.stack(batch.action).astype('float32'))
        next_obs = contrib_eager_python_tfe.Variable(
            np.stack(batch.next_obs).astype('float32'))
        mask = contrib_eager_python_tfe.Variable(
            np.stack(batch.mask).astype('float32'))

        if self.get_reward is not None:
            reward = self.get_reward(obs, action, next_obs)
        else:
            reward = contrib_eager_python_tfe.Variable(
                np.stack(batch.reward).astype('float32'))

        if self.use_td3:
            self._update_critic_td3(obs, action, next_obs, reward, mask)
        else:
            self._update_critic_ddpg(obs, action, next_obs, reward, mask)

        if self.critic_step.numpy() % self.policy_update_freq == 0:
            if update_actor:
                self._update_actor(obs, mask)
                soft_update(self.actor.variables, self.actor_target.variables,
                            self.tau)
            soft_update(self.critic.variables, self.critic_target.variables,
                        self.tau)
Example 6
    def optimize(self):
        """
		Samples a random batch from replay memory and performs optimization
		:return:
		"""
        s1, a1, r1, s2 = self.ram.sample(BATCH_SIZE)

        s1 = Variable(torch.from_numpy(s1))
        a1 = Variable(torch.from_numpy(a1))
        r1 = Variable(torch.from_numpy(r1))
        s2 = Variable(torch.from_numpy(s2))

        # ---------------------- optimize critic ----------------------
        a2 = self.target_actor.forward(s2).detach()
        next_val = torch.squeeze(self.target_critic.forward(s2, a2).detach())

        y_expected = r1 + GAMMA * next_val
        y_predicted = torch.squeeze(self.critic.forward(s1, a1))

        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # ---------------------- optimize actor ----------------------
        pred_a1 = self.actor.forward(s1)
        loss_actor = -1 * torch.sum(self.critic.forward(s1, pred_a1))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        utils.soft_update(self.target_actor, self.actor, TAU)
        utils.soft_update(self.target_critic, self.critic, TAU)
Example 7
    def learn(self, times=1):
        for i in range(times):
            if self.buffer.len < self.batch_size:
                return
            s, a, r, s_ = self.buffer.sample(self.batch_size)
            s = Variable(torch.from_numpy(s).float())
            a = Variable(torch.from_numpy(a).float())
            r = Variable(torch.from_numpy(r).float())
            s_ = Variable(torch.from_numpy(s_).float())

            #print (s)

            # for Critic
            #a_ = self.target_actor.forward(s_).detach().data.numpy()
            a_ = self.target_actor.forward(s_).detach()
            next_Q = self.target_critic.forward(s_, a_).detach()
            y = r + self.gamma * next_Q
            y_pred = self.critic.forward(s, a)

            loss_critic = F.smooth_l1_loss(y_pred, y)
            self.critic_optimizer.zero_grad()
            loss_critic.backward()
            self.critic_optimizer.step()

            # for Actor
            #pred_a = self.actor.forward(s).data.numpy()
            pred_a = self.actor.forward(s)
            loss_actor = -1 * torch.sum(self.critic.forward(s, pred_a))
            self.actor_optimizer.zero_grad()
            loss_actor.backward()
            self.actor_optimizer.step()

            utils.soft_update(self.target_actor, self.actor, self.tau)
            utils.soft_update(self.target_critic, self.critic, self.tau)
Example 8
    def optimize(self):
        s1, a1, r1, s2, done = self.ram.sample(self.batch_size)

        s1 = Variable(torch.tensor(s1)).cuda()
        a1 = Variable(torch.tensor(a1, dtype=torch.int64)).cuda()
        r1 = Variable(torch.tensor(r1)).cuda()
        s2 = Variable(torch.tensor(s2)).cuda()

        self.optimizer.zero_grad()

        # Vanilla DQN target: bootstrap from the maximum target-network Q value
        r_predict = self.gamma * torch.max(self.target_net.forward(s2), dim=1).values
        y_j = r1 + r_predict
        y_j = self.done_state_value(r1, y_j, done)

        # r_ : Q(s_j, a_j)
        r_ = self.learning_net.forward(s1)
        # mask = F.one_hot(torch.squeeze(a1), num_classes=self.act_dim)
        # mask = torch.tensor(mask.clone().detach(), dtype=torch.uint8).cuda()
        # r_ = torch.masked_select(r_, mask)
        r_ = torch.squeeze(r_.gather(1, a1.view(-1,1)))

        # loss: (y_j - Q(s_j, a_j))^2
        loss = self.loss_f(y_j, r_)

        loss.backward()
        self.optimizer.step()

        utils.soft_update(self.target_net, self.learning_net, self.tau)
        self.iter += 1
        return loss.cpu()
Example 9
    def update_policy(self):
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = self.memory.sample_batch(self.batch_size)
        state = to_tensor(np.array(state_batch), device=device)
        action = to_tensor(np.array(action_batch), device=device)
        next_state = to_tensor(np.array(next_state_batch), device=device)

        # compute target Q value
        next_q_value = self.critic_target([next_state, self.actor_target(next_state)])
        target_q_value = to_tensor(reward_batch, device=device) \
                         + self.discount * to_tensor((1 - terminal_batch.astype(np.float)), device=device) * next_q_value

        # Critic and Actor update
        self.critic.zero_grad()
        with torch.set_grad_enabled(True):
            q_values = self.critic([state, action])
            critic_loss = criterion(q_values, target_q_value.detach())
            critic_loss.backward()
            self.critic_optim.step()

        self.actor.zero_grad()
        with torch.set_grad_enabled(True):
            policy_loss = -self.critic([state.detach(), self.actor(state)]).mean()
            policy_loss.backward()
            self.actor_optim.step()

        # Target update
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)

        return to_numpy(-policy_loss), to_numpy(critic_loss), to_numpy(q_values.mean())
Example 10
    def learn(self, experiences, all_curr_pred_actions, all_next_pred_actions):
        
        agent_idx_device = torch.tensor(self.agent_idx).to(self.device)
        
        states, actions, rewards, next_states, dones = experiences

        rewards = rewards.index_select(1, agent_idx_device)
        dones = dones.index_select(1, agent_idx_device)
        
        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
                
        batch_size = next_states.shape[0]
        
        actions_next = torch.cat(all_next_pred_actions, dim=1).to(self.device)
        next_states = next_states.reshape(batch_size, -1)      
        
        with torch.no_grad():
            Q_targets_next = self.critic_target(next_states, actions_next)
        
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))
        
        # Compute critic loss
        states = states.reshape(batch_size, -1)
        actions = actions.reshape(batch_size, -1)
        
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets.detach())
        # Minimize the loss
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        
        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        self.actor_optim.zero_grad()
        predicted_actions = torch.cat([action if idx == self.agent_idx \
                   else action.detach()
                   for idx, action in enumerate(all_curr_pred_actions)],
                   dim=1).to(self.device)

        actor_loss = -self.critic_local(states, predicted_actions).mean()
        # minimize loss
        actor_loss.backward()
        self.actor_optim.step()
        
        al = actor_loss.cpu().detach().item()
        cl = critic_loss.cpu().detach().item()
        
        if self.tensorboard_writer is not None:            
            self.tensorboard_writer.add_scalar("agent{}/actor_loss".format(self.agent_idx), al, self.iteration)
            self.tensorboard_writer.add_scalar("agent{}/critic_loss".format(self.agent_idx), cl, self.iteration)
            self.tensorboard_writer.file_writer.flush()
            
        self.iteration += 1

        # ----------------------- update target networks ----------------------- #
        soft_update(self.critic_target, self.critic_local, self.tau)
        soft_update(self.actor_target, self.actor_local, self.tau)           
Example 11
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size=self.batch_size, device='cpu')

        #===============================Critic Update===============================
        with torch.no_grad():
            target = rewards + GAMMA * (1 - dones) * self.Q_target(
                (next_states, self.P_target(next_states)))
        Q = self.Q_online((states, actions))
        td_error = self.loss_td(target, Q)
        self.q_optimizer.zero_grad()
        td_error.backward()
        self.q_optimizer.step()

        #===============================Actor Update===============================
        q = self.Q_online((states, self.P_online(states)))
        loss_a = -torch.mean(q)
        self.p_optimizer.zero_grad()
        loss_a.backward()
        self.p_optimizer.step()

        #===============================Target Update===============================
        soft_update(self.Q_target, self.Q_online, tau=1e-2)
        soft_update(self.P_target, self.P_online, tau=1e-2)
Example 12
    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        soft_update(self.critic, self.target_critic, self.tau)
        soft_update(self.policy, self.target_policy, self.tau)
Example 13
    def update(self):
        if len(self.memory) < self.BATCH_SIZE:
            return

        # get training batch
        transitions = self.memory.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward).unsqueeze(1)
        next_state = torch.cat(batch.next_state)

        # update value network
        state_action = torch.cat((state_batch, action_batch), dim=1)
        state_action_value = self.value_network(state_action)

        next_action = self.action_target_network(next_state).detach()
        next_state_action = torch.cat((next_state, next_action), dim=1)
        next_state_action_value = self.value_target_network(
            next_state_action).detach()

        expected_state_action_value = (self.DISCOUNT *
                                       next_state_action_value) + reward_batch
        value_loss = self.criterion(state_action_value,
                                    expected_state_action_value)

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # update action network
        optim_action = self.action_network(state_batch)
        optim_state_action = torch.cat((state_batch, optim_action), dim=1)

        action_loss = -self.value_network(optim_state_action)
        action_loss = action_loss.mean()

        self.action_optimizer.zero_grad()
        action_loss.backward()
        self.action_optimizer.step()

        # update target network
        soft_update(self.value_target_network, self.value_network, 0.01)
        soft_update(self.action_target_network, self.action_network, 0.01)
Example 14
    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        qf1, qf2 = self.critic(state_batch, action_batch)  # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs


        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
Example 15
    def learn(self):
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        # Update the latent model.
        self.learn_latent()
        # Update policy and critic.
        self.learn_sac()
Example 16
    def train_critic(self):
        loss = self.td_critic_loss(self.critic)
        self.critic_optimizer.zero_grad()
        self.critic_upd_steps += 1
        writer.add_scalar("critic loss", loss, self.critic_upd_steps)
        loss.backward()
        self.critic_optimizer.step()
        soft_update(self.critic_target, self.critic, args.tau)
        return loss
Example 17
File: sac.py Project: km01/myrl
    def update_parameters(self, batch):

        obs, act, rew, done, obs_next = batch
        obs = torch.FloatTensor(obs).to(self.device)
        act = torch.FloatTensor(act).to(self.device)
        rew = torch.FloatTensor(rew).unsqueeze(-1).to(self.device)
        done = torch.BoolTensor(done).unsqueeze(-1).to(self.device)
        obs_next = torch.FloatTensor(obs_next).to(self.device)

        with torch.no_grad():
            next_v = self.value_target(obs_next).masked_fill(done, 0.)
            q_targ = rew + self.gamma * next_v

        self.critic_optim.zero_grad()
        q1, q2 = self.critic(obs, act)
        critic_loss = (q1 - q_targ).pow(2.).mul(0.5) + (
            q2 - q_targ).pow(2.).mul(0.5)
        critic_loss = critic_loss.mean()
        critic_loss.backward()
        self.critic_optim.step()
        with torch.no_grad():
            critic_loss = (torch.min(q1, q2) - q_targ).pow(2).mul(0.5).mean()

        self.policy_optim.zero_grad()
        policy = TanhGaussian(*self.policy(obs))
        action = policy.sample()
        log_pi = policy.log_prob(action, param_grad=False).sum(dim=-1,
                                                               keepdim=True)

        action_value = torch.min(*self.critic(obs, action))

        with torch.no_grad():
            v_targ = action_value - self.alpha * log_pi

        self.value_optim.zero_grad()
        v = self.value(obs)
        value_loss = (v - v_targ).pow(2.).mul(0.5)
        value_loss = value_loss.mean()
        value_loss.backward()
        self.value_optim.step()

        policy_loss = self.alpha * log_pi - action_value
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.policy_optim.step()

        soft_update(self.value_target, self.value, self.tau)
        loss_info = {
            'critic_loss': critic_loss.item(),
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'policy_entropy': -log_pi.mean().item()
        }
        return loss_info
Example 18
    def update_parameters(self):
        self.learning_steps += 1
        # Update Q networks.
        for i in range(self.Q_updates_per_step):
            # Sample a batch from memory.
            states, actions, rewards, next_states, masks = self.memory.sample(
                batch_size=self.batch_size)
            self.critic_optim.zero_grad()
            qf1_loss, qf2_loss = self.calc_critic_loss(states, actions,
                                                       rewards, next_states,
                                                       masks)
            qf1_loss.backward()
            qf2_loss.backward()
            clip_grad_norm_(self.critic_online.parameters(),
                            max_norm=self.max_grad_norm)
            self.critic_optim.step()

        policy_loss, kl, entropies, cross_entropies, policy_term = self.calc_policy_loss(
            states, self.learning_steps)
        # Update coefficents
        if self.limit_kl:
            self.alpha_optim.zero_grad()
            alpha_loss = self.log_alpha * (
                (self.target_kl - self.target_entropy) - cross_entropies)
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha_optim.zero_grad()
            rho_loss = self.log_rho * (entropies - self.target_entropy)
            rho_loss.backward()
            self.rho_optim.step()
            self.rho_optim.zero_grad()
        # Update policy networks.
        self.backup_policy()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        clip_grad_norm_(self.policy.parameters(), max_norm=self.max_grad_norm)
        self.policy_optim.step()

        # Soft update target critic network.
        soft_update(self.critic_target, self.critic_online, self.tau)
        # Log training information.
        if self.learning_steps % self.log_interval == 0 and self.learning_steps > 100:
            self.writer.add_scalar('loss/Q1',
                                   qf1_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('loss/policy',
                                   policy_loss.detach().item(),
                                   self.learning_steps)
            self.writer.add_scalar('stats/kl', kl, self.learning_steps)
            self.writer.add_scalar('stats/entropies', entropies,
                                   self.learning_steps)
            self.writer.add_scalar('stats/cross_entropies', cross_entropies,
                                   self.learning_steps)
Example 19
    def train(self):

        state_batches, action_batches, reward_batches, next_state_batches, done_batches = self.get_batches(
        )

        state_batches = torch.Tensor(state_batches).to(self.device)
        action_batches = torch.Tensor(action_batches).to(self.device)
        reward_batches = torch.Tensor(reward_batches).reshape(
            self.config.batch_size, self.n_agents, 1).to(self.device)
        next_state_batches = torch.Tensor(next_state_batches).to(self.device)
        done_batches = torch.Tensor(
            (done_batches == False) * 1).reshape(self.config.batch_size,
                                                 self.n_agents,
                                                 1).to(self.device)

        target_next_actions = self.policy_target.forward(next_state_batches)
        target_next_q = self.critic_target.forward(next_state_batches,
                                                   target_next_actions)
        main_q = self.critic(state_batches, action_batches)
        '''
        How to concat each agent's Q value?
        '''
        #target_next_q = target_next_q
        #main_q = main_q.mean(dim=1)
        '''
        Reward Norm
        '''
        # reward_batches = (reward_batches - reward_batches.mean(dim=0)) / reward_batches.std(dim=0) / 1024

        # Critic Loss
        self.critic.zero_grad()
        baselines = reward_batches + done_batches * self.config.gamma * target_next_q
        loss_critic = torch.nn.MSELoss()(main_q, baselines.detach())
        loss_critic.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()

        # Actor Loss
        self.policy.zero_grad()
        clear_action_batches = self.policy.forward(state_batches)
        loss_actor = -self.critic.forward(state_batches,
                                          clear_action_batches).mean()
        loss_actor += (clear_action_batches**2).mean() * 1e-3
        loss_actor.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
        self.policy_optimizer.step()

        # This is for logging
        self.c_loss = loss_critic.item()
        self.a_loss = loss_actor.item()

        soft_update(self.policy, self.policy_target, self.config.tau)
        soft_update(self.critic, self.critic_target, self.config.tau)
Example 20
    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)

        # Two Q-functions to mitigate positive bias in the policy improvement step
        qf1, qf2 = self.critic(state_batch, action_batch) 
        qf1_loss = F.mse_loss(qf1, next_q_value) 
        qf2_loss = F.mse_loss(qf2, next_q_value) 
  
        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() 
        
        self.critic_optim.zero_grad()
        qf1_loss.backward()
        self.critic_optim.step()

        self.critic_optim.zero_grad()
        qf2_loss.backward()
        self.critic_optim.step()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone() # For TensorboardX logs
         
        soft_update(self.critic_target, self.critic, self.tau)    
Example 21
File: sac.py Project: km01/myrl
    def update_parameters(self, batch):
        obs, act, rew, done, obs_next = batch
        obs = torch.FloatTensor(obs).to(self.device)
        act = torch.LongTensor(act).unsqueeze(-1).to(self.device)
        rew = torch.FloatTensor(rew).unsqueeze(-1).to(self.device)
        done = torch.BoolTensor(done).unsqueeze(-1).to(self.device)
        obs_next = torch.FloatTensor(obs_next).to(self.device)

        with torch.no_grad():
            next_policy = Cat(raw_base=self.policy(obs_next))
            next_q = torch.min(*self.critic_target(obs_next))
            next_eval = (next_policy.probs * next_q).sum(dim=-1, keepdim=True)
            next_entr = -(next_policy.probs * next_policy.logits).sum(
                dim=-1, keepdim=True)
            next_v = (next_eval + self.alpha * next_entr).masked_fill(done, 0.)
            q_targ = rew + self.gamma * next_v

        self.critic_optim.zero_grad()
        q1, q2 = self.critic(obs)

        q_pred = torch.min(q1, q2).detach()
        q1, q2 = q1.gather(dim=-1, index=act), q2.gather(dim=-1, index=act)
        critic_loss = (q1 - q_targ).pow(2.).mul(0.5) + (
            q2 - q_targ).pow(2.).mul(0.5)
        critic_loss = critic_loss.mean()
        critic_loss.backward()
        self.critic_optim.step()

        with torch.no_grad():
            critic_loss = (torch.min(q1, q2) - q_targ).pow(2.).mul(0.5).mean()

        self.policy_optim.zero_grad()
        policy = Cat(raw_base=self.policy(obs))
        policy_entr = -(policy.probs.detach() *
                        policy.logits).sum(dim=-1).mean()
        policy_eval = (policy.probs * q_pred).sum(dim=-1).mean()
        policy_loss = self.alpha * policy_entr - policy_eval
        policy_loss.backward()
        self.policy_optim.step()

        soft_update(self.critic_target, self.critic, self.tau)

        loss_info = {
            'critic_loss': critic_loss.item(),
            'policy_loss': policy_loss.item(),
            'policy_entr': policy_entr.item()
        }

        return loss_info
Example 22
    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones = experiences

        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss and backpropagate
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network
        soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
Example 23
    def update_parameters(self, memory, batch_size, updates):
        """

        :param memory:
        :param batch_size:
        :param updates:
        :return:
        """
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)

        states = torch.FloatTensor(state_batch).to(self.device)
        next_states = torch.FloatTensor(next_state_batch).to(self.device)
        actions = torch.FloatTensor(action_batch).to(self.device)
        rewards = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        done = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # UPDATE CRITIC #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.target_policy_net(next_states)
        q_targets_next = self.target_value_net(next_states,
                                               actions_next.detach())
        # Compute Q targets for current states (y_i)
        q_targets = rewards + (self.gamma * q_targets_next * (1.0 - done))
        # Compute critic loss
        q_expected = self.critic_net(states, actions)
        critic_loss = self.critic_loss(q_expected, q_targets)
        # Minimize the loss
        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # UPDATE ACTOR #
        # Compute actor loss
        actions_pred = self.actor_net(states)
        actor_loss = -self.critic_net(states, actions_pred).mean()
        # Maximize the expected return
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # UPDATE TARGET NETWORK #
        if updates % self.target_update == 0:
            soft_update(self.critic_net, self.target_value_net, self.tau)
            soft_update(self.actor_net, self.target_policy_net, self.tau)

        return actor_loss.item(), critic_loss.item()
Example 24
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size=self.batch_size, device=self.device)
        # discounted rewards
        # rewards = torch.from_numpy(discount((rewards.view(rewards.shape[0])).cpu().numpy())).float().to(self.device)

        ### debug shape : ok
        #===============================Critic Update===============================
        self.Q_online.train()
        Q = self.Q_online((states, actions))

        with torch.no_grad():  # don't need backprop for target value
            self.Q_target.eval()
            self.P_target.eval()
            target = rewards + self.gamma * (1 - dones) * self.Q_target(
                (next_states, self.P_target(next_states)))
        critic_loss_fn = torch.nn.MSELoss()
        critic_loss = critic_loss_fn(Q, target).mean()
        # update
        self.q_optimizer.zero_grad()
        critic_loss.backward()
        self.q_optimizer.step()
        # print("critic loss", critic_loss.item())

        #===============================Actor Update===============================
        # fix online_critic , update online_actor
        self.Q_online.eval()
        for p in self.Q_online.parameters():
            p.requires_grad = False
        for p in self.P_online.parameters():
            p.requires_grad = True
        policy_loss = -self.Q_online((states, self.P_online(states)))
        policy_loss = policy_loss.mean()
        self.p_optimizer.zero_grad()
        policy_loss.backward()
        self.p_optimizer.step()
        # print("policy loss", policy_loss.item())
        for p in self.Q_online.parameters():
            p.requires_grad = True
        #===============================Target Update===============================
        soft_update(self.Q_target, self.Q_online, tau=1e-3)
        soft_update(self.P_target, self.P_online, tau=1e-3)
        self.eps -= EPSILON_DECAY
        if self.eps <= 0:
            self.eps = 0
Example 25
    def optimize(self):
        """
		Samples a random batch from replay memory and performs optimization
		:return:
		"""
        s1, a1, r1, s2 = self.ram.sample(BATCH_SIZE)

        s1 = Variable(torch.from_numpy(s1))
        a1 = Variable(torch.from_numpy(a1))
        r1 = Variable(torch.from_numpy(r1))
        s2 = Variable(torch.from_numpy(s2))

        # ---------------------- optimize critic ----------------------
        # Use target actor exploitation policy here for loss evaluation
        a2 = self.target_actor.forward(s2).detach()
        next_val = torch.squeeze(self.target_critic.forward(s2, a2).detach())

        # y_exp = r + gamma*Q'( s2, pi'(s2))
        y_expected = r1 + GAMMA * next_val

        # y_pred = Q( s1, a1)
        y_predicted = torch.squeeze(self.critic.forward(s1, a1))

        # compute critic loss, and update the critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # ---------------------- optimize actor ----------------------
        pred_a1 = self.actor.forward(s1)

        # compute actor loss and update actor
        loss_actor = -1 * torch.sum(self.critic.forward(s1, pred_a1))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # updating target networks according to: y = TAU*x + (1 - TAU)*y
        utils.soft_update(self.target_actor, self.actor, TAU)
        utils.soft_update(self.target_critic, self.critic, TAU)

        if self.iter % 100 == 0:
            print('Iteration :- ', self.iter, ' Loss_actor :- ', loss_actor.data.numpy(),\
             ' Loss_critic :- ', loss_critic.data.numpy())
        self.iter += 1
Example 26
    def optimize(self):
        """
        Samples a random batch from replay memory and performs optimization
        :return:
        """
        if self.args.pri:
            s1, a1, r1, s2, tree_idx, weights = self.ram.sample(
                self.batch_size)
            weights = torch.from_numpy(weights).to(self.device)
        else:
            s1, a1, r1, s2 = self.ram.sample(self.batch_size)
        s1 = Variable(torch.from_numpy(s1)).to(self.device)
        a1 = Variable(torch.from_numpy(a1)).to(self.device)
        r1 = Variable(torch.from_numpy(r1)).to(self.device)
        s2 = Variable(torch.from_numpy(s2)).to(self.device)

        # ---------------------- optimize critic ----------------------
        #  Use target actor exploitation policy here for loss evaluation
        a2 = self.target_actor.forward(s2).detach()
        next_val = torch.squeeze(self.target_critic.forward(s2, a2).detach())
        y_expected = r1 + self.gamma * next_val
        y_predicted = torch.squeeze(self.critic.forward(s1, a1))
        if self.args.pri:
            td_error = torch.abs(y_predicted - y_expected)
            loss_critic = torch.sum(weights * td_error**2)
            self.ram.update_tree(tree_idx, td_error.detach().cpu().numpy())
        else:
            loss_critic = F.mse_loss(y_predicted, y_expected)

        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # ---------------------- optimize actor ----------------------
        pred_a1 = self.actor.forward(s1)
        loss_actor = -1 * torch.mean(self.critic.forward(s1, pred_a1))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # logging
        self.writer.add_scalar('loss/critic', loss_critic, self.iter)
        self.writer.add_scalar('loss/actor', loss_actor, self.iter)
        utils.soft_update(self.AC_T, self.AC, self.tau)
        self.iter += 1
Example 27
    def train(self):
        state_batches, action_batches, reward_batches, next_state_batches, done_batches = self.get_batches(
        )

        state_batches = Variable(torch.Tensor(state_batches).to(self.device))
        action_batches = Variable(
            torch.Tensor(action_batches).reshape(-1, 1).to(self.device))
        reward_batches = Variable(
            torch.Tensor(reward_batches).reshape(-1, 1).to(self.device))
        next_state_batches = Variable(
            torch.Tensor(next_state_batches).to(self.device))
        done_batches = Variable(
            torch.Tensor(
                (done_batches == False) * 1).reshape(-1, 1).to(self.device))

        target_next_actions = self.actor_target.forward(
            next_state_batches).detach()
        target_next_q = self.critic_target.forward(
            next_state_batches, target_next_actions).detach()

        main_q = self.critic(state_batches, action_batches)

        # Critic Loss
        self.critic.zero_grad()
        baselines = reward_batches + done_batches * self.config.gamma * target_next_q
        loss_critic = torch.nn.MSELoss()(main_q, baselines)
        loss_critic.backward()
        self.critic_optimizer.step()

        # Actor Loss
        self.actor.zero_grad()
        clear_action_batches = self.actor.forward(state_batches)
        loss_actor = (
            -self.critic.forward(state_batches, clear_action_batches)).mean()
        loss_actor.backward()
        self.actor_optimizer.step()

        # This is for logging
        self.c_loss = loss_critic.item()
        self.a_loss = loss_actor.item()

        soft_update(self.actor, self.actor_target, self.config.tau)
        soft_update(self.critic, self.critic_target, self.config.tau)
Example 28
    def _learn_from_memory(self):
        '''Learn from memory and update the parameters of the two networks.
        '''
        # Randomly sample transitions from memory
        trans_pieces = self.sample(self.batch_size)
        s0 = np.vstack([x.s0 for x in trans_pieces])
        a0 = np.array([x.a0 for x in trans_pieces])
        r1 = np.array([x.reward for x in trans_pieces])
        # is_done = np.array([x.is_done for x in trans_pieces])
        s1 = np.vstack([x.s1 for x in trans_pieces])

        # Optimize the critic network parameters
        a1 = self.target_actor.forward(s1).detach()
        next_val = torch.squeeze(self.target_critic.forward(s1, a1).detach())

        # y_exp = r + gamma*Q'( s2, pi'(s2))

        y_expected = torch.from_numpy(r1).type(
            torch.FloatTensor) + self.gamma * next_val
        y_expected = y_expected.type(torch.FloatTensor)
        # y_pred = Q( s1, a1)
        a0 = torch.from_numpy(a0)  # convert to a Tensor
        y_predicted = torch.squeeze(self.critic.forward(s0, a0))
        # compute critic loss, and update the critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network parameters; the objective is to increase Q
        pred_a0 = self.actor.forward(s0)  # using a0 directly does not converge
        # Gradient ascent (descent on the negated value), using the state-value
        # estimate as the policy objective
        loss_actor = -1 * torch.sum(self.critic.forward(s0, pred_a0))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Soft-update the target network parameters
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
        return (loss_critic.item(), loss_actor.item())
Example 29
    def _learn_from_memory(self):
        '''Learn from experience and update the parameters of the two networks.
        '''
        # Randomly sample transitions from memory
        trans_pieces = self.sample(self.batch_size)
        s0 = np.vstack([x.s0 for x in trans_pieces])
        a0 = np.array([x.a0 for x in trans_pieces])
        r1 = np.array([x.reward for x in trans_pieces])
        s1 = np.vstack([x.s1 for x in trans_pieces])

        # Optimize the critic network parameters by minimizing the loss
        a1 = self.target_actor.forward(s1).detach()
        next_val = torch.squeeze(self.target_critic.forward(s1, a1).detach())
        r1 = torch.from_numpy(r1)
        # r1 = r1.type(torch.DoubleTensor)
        next_val = next_val.type(torch.DoubleTensor)
        y_expected = r1 + self.gamma * next_val
        y_expected = y_expected.type(torch.FloatTensor)

        a0 = torch.from_numpy(a0)  # convert to a Tensor
        y_predicted = torch.squeeze(self.critic.forward(s0, a0))

        # Minimize the loss and update the critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network parameters; the objective is to increase Q
        pred_a0 = self.actor.forward(s0)  # using a0 directly does not converge
        # Gradient ascent (descent on the negated value), using the state-value
        # estimate as the policy objective
        loss_actor = -1 * torch.sum(self.critic.forward(s0, pred_a0))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Soft-update the target network parameters
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
        return (loss_critic.item(), loss_actor.item())
Example 30
    def replay_buffer_training(self, sample, train_results, n):

        s, a, r, t, stag = [sample[k] for k in ['s', 'a', 'r', 't', 'stag']]

        self.train_mode()

        with torch.no_grad():
            pi_tag = self.pi_target(stag)

            noise = (torch.randn_like(pi_tag) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)
            pi_tag = (pi_tag + noise).clamp(-1, 1)

            q_target_1 = self.q_target_1(stag, pi_tag)
            q_target_2 = self.q_target_2(stag, pi_tag)

        q_target = torch.min(q_target_1, q_target_2)
        g = r + (1 - t) * self.gamma**self.n_steps * q_target

        qa = self.q_net_1(s, a)
        loss_q = F.mse_loss(qa, g, reduction='mean')

        self.optimizer_q_1.zero_grad()
        loss_q.backward()
        if self.clip_q:
            nn.utils.clip_grad_norm_(self.q_net_1.parameters(), self.clip_q)
        self.optimizer_q_1.step()

        qa = self.q_net_2(s, a)
        loss_q = F.mse_loss(qa, g, reduction='mean')

        self.optimizer_q_2.zero_grad()
        loss_q.backward()
        if self.clip_q:
            nn.utils.clip_grad_norm_(self.q_net_2.parameters(), self.clip_q)
        self.optimizer_q_2.step()

        if not n % self.td3_delayed_policy_update:

            pi = self.pi_net(s)

            v = self.q_net_1(s, pi)
            loss_p = (-v).mean()

            self.optimizer_p.zero_grad()
            loss_p.backward()
            if self.clip_p:
                nn.utils.clip_grad_norm_(self.pi_net.parameters(), self.clip_p)
            self.optimizer_p.step()

            train_results['scalar']['objective'].append(float(-loss_p))

            soft_update(self.pi_net, self.pi_target, self.tau)
            soft_update(self.q_net_1, self.q_target_1, self.tau)
            soft_update(self.q_net_2, self.q_target_2, self.tau)

        train_results['scalar']['loss_q'].append(float(loss_q))

        return train_results