def tensorify(self):
        """Method to save experiences to drive
			   Parameters:
				   None
			   Returns:
				   None
		   """
        self.referesh()  #Referesh first

        if self.__len__() > 1:

            self.sT = torch.tensor(np.vstack(self.s))
            self.nsT = torch.tensor(np.vstack(self.ns))
            self.aT = torch.tensor(np.vstack(self.a))
            self.rT = torch.tensor(np.vstack(self.r))
            self.doneT = torch.tensor(np.vstack(self.done))
            self.global_rewardT = torch.tensor(np.vstack(self.global_reward))
            if self.buffer_gpu:
                self.sT = self.sT.cuda()
                self.nsT = self.nsT.cuda()
                self.aT = self.aT.cuda()
                self.rT = self.rT.cuda()
                self.doneT = self.doneT.cuda()
                self.global_rewardT = self.global_rewardT.cuda()

            #Prioritized indices update
            self.top_r = list(
                np.argsort(np.vstack(self.r), axis=0)[-int(len(self.s) / 10):])
            self.top_g = list(
                np.argsort(np.vstack(self.global_reward),
                           axis=0)[-int(len(self.s) / 10):])

            #Update Stats
            compute_stats(self.rT, self.rstats)
            compute_stats(self.global_rewardT, self.gstats)
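A quick standalone illustration of the prioritized-indices slice above: argsort along the batch axis and keep the indices of the roughly top-10% rewards. Toy data, not the buffer class; the max(1, ...) guard is my addition so the slice is never empty:

import numpy as np

r = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4, 0.6, 0.0]   # toy per-transition rewards
k = max(1, int(len(r) / 10))                              # roughly the top 10% of transitions
top_r = list(np.argsort(np.vstack(r), axis=0)[-k:])       # indices of the k largest rewards
print(top_r)                                              # -> [array([1])]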
Example #2
	def update_parameters(self, state_batch, next_state_batch, action_batch, reward_batch, done_batch, global_reward, num_epoch=1, **kwargs):
		"""Runs a step of Bellman upodate and policy gradient using a batch of experiences

			 Parameters:
				  state_batch (tensor): Current States
				  next_state_batch (tensor): Next States
				  action_batch (tensor): Actions
				  reward_batch (tensor): Rewards
				  done_batch (tensor): Done batch
				  num_epoch (int): Number of learning iterations to run with the same data

			 Returns:
				   None

		 """

		if isinstance(state_batch, list):
			state_batch = torch.cat(state_batch)
			next_state_batch = torch.cat(next_state_batch)
			action_batch = torch.cat(action_batch)
			reward_batch = torch.cat(reward_batch)
			done_batch = torch.cat(done_batch)
			global_reward = torch.cat(global_reward)

		for _ in range(num_epoch):
			########### CRITIC UPDATE ####################

			#Compute next q-val, next_v and target
			with torch.no_grad():
				#Policy Noise
				policy_noise = np.random.normal(0, kwargs['policy_noise'], (action_batch.size()[0], action_batch.size()[1]))
				policy_noise = torch.clamp(torch.Tensor(policy_noise), -kwargs['policy_noise_clip'], kwargs['policy_noise_clip'])

				#Compute next action_batch
				next_action_batch = self.policy_target.clean_action(next_state_batch, return_only_action=True) + (policy_noise.cuda() if self.use_gpu else policy_noise)
				next_action_batch = torch.clamp(next_action_batch, -1, 1)

				#Compute Q-val and value of next state masking by done
				q1, q2 = self.critic_target.forward(next_state_batch, next_action_batch)
				q1 = (1 - done_batch) * q1
				q2 = (1 - done_batch) * q2
				#next_val = (1 - done_batch) * next_val

				#Select which q to use as next-q (depends on algo)
				if self.algo_name == 'TD3' or self.algo_name == 'TD3_actor_min': next_q = torch.min(q1, q2)
				elif self.algo_name == 'DDPG': next_q = q1
				elif self.algo_name == 'TD3_max': next_q = torch.max(q1, q2)

				#Compute target q and target val
				target_q = reward_batch + (self.gamma * next_q)
				#if self.args.use_advantage: target_val = reward_batch + (self.gamma * next_val)

			if self.actualize:
				##########Actualization Network Update
				current_Ascore = self.ANetwork.forward(state_batch, action_batch)
				utils.compute_stats(current_Ascore, self.alz_score)
				target_Ascore = (self.actualize_lr) * (global_reward * 10.0) + (1 - self.actualize_lr) * current_Ascore.detach()
				actualize_loss = self.loss(target_Ascore, current_Ascore).mean()



			self.critic_optim.zero_grad()
			current_q1, current_q2 = self.critic.forward((state_batch), (action_batch))
			utils.compute_stats(current_q1, self.q)

			dt = self.loss(current_q1, target_q)
			# if self.args.use_advantage:
			#     dt = dt + self.loss(current_val, target_val)
			#     utils.compute_stats(current_val, self.val)

			if self.algo_name == 'TD3' or self.algo_name == 'TD3_max': dt = dt + self.loss(current_q2, target_q)
			utils.compute_stats(dt, self.q_loss)

			# if self.args.critic_constraint:
			#     if dt.item() > self.args.critic_constraint_w:
			#         dt = dt * (abs(self.args.critic_constraint_w / dt.item()))
			dt.backward()

			self.critic_optim.step()
			self.num_critic_updates += 1

			if self.actualize:
				self.actualize_optim.zero_grad()
				actualize_loss.backward()
				self.actualize_optim.step()


			#Delayed Actor Update
			if self.num_critic_updates % kwargs['policy_ups_freq'] == 0:

				actor_actions = self.policy.clean_action(state_batch, return_only_action=False)

				# # Trust Region constraint
				# if self.args.trust_region_actor:
				#     with torch.no_grad(): old_actor_actions = self.actor_target.forward(state_batch)
				#     actor_actions = action_batch - old_actor_actions


				Q1, Q2 = self.critic.forward(state_batch, actor_actions)

				# if self.args.use_advantage: policy_loss = -(Q1 - val)
				policy_loss = -Q1

				utils.compute_stats(-policy_loss,self.policy_loss)
				policy_loss = policy_loss.mean()

				###Actualize Policy Update
				if self.actualize:
					A1 = self.ANetwork.forward(state_batch, actor_actions)
					utils.compute_stats(A1, self.alz_policy)
					policy_loss += -A1.mean()*0.1



				self.policy_optim.zero_grad()



				policy_loss.backward(retain_graph=True)
				#nn.utils.clip_grad_norm_(self.actor.parameters(), 10)
				# if self.args.action_loss:
				#     action_loss = torch.abs(actor_actions-0.5)
				#     utils.compute_stats(action_loss, self.action_loss)
				#     action_loss = action_loss.mean() * self.args.action_loss_w
				#     action_loss.backward()
				#     #if self.action_loss[-1] > self.policy_loss[-1]: self.args.action_loss_w *= 0.9 #Decay action_w loss if action loss is larger than policy gradient loss
				self.policy_optim.step()


			# if self.args.hard_update:
			#     if self.num_critic_updates % self.args.hard_update_freq == 0:
			#         if self.num_critic_updates % self.args.policy_ups_freq == 0: self.hard_update(self.actor_target, self.actor)
			#         self.hard_update(self.critic_target, self.critic)


			if self.num_critic_updates % kwargs['policy_ups_freq'] == 0: utils.soft_update(self.policy_target, self.policy, self.tau)
			utils.soft_update(self.critic_target, self.critic, self.tau)

			self.total_update += 1
			if self.agent_id == 0:
				self.tracker.update([self.q['mean'], self.q_loss['mean'], self.policy_loss['mean'],self.alz_score['mean'], self.alz_policy['mean']] ,self.total_update)
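For reference, once the smoothed next action has been passed through the target critics, the TD3-style target computed in the no_grad block above reduces to the small helper below. A standalone sketch with hypothetical names, not a method of the class:

import torch

def td3_target(reward, done, next_q1, next_q2, gamma=0.99):
	# Zero out bootstrapping at terminal transitions, then take the pessimistic (min) Q estimate
	next_q = (1 - done) * torch.min(next_q1, next_q2)
	return reward + gamma * next_q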
Example #3
	def update_parameters(self, state_batch, next_state_batch, action_batch, reward_batch, done_batch, agent_id, num_epoch=1, **kwargs):
		"""Runs a step of Bellman upodate and policy gradient using a batch of experiences

			 Parameters:
				  state_batch (tensor): Current States
				  next_state_batch (tensor): Next States
				  action_batch (tensor): Actions
				  reward_batch (tensor): Rewards
				  done_batch (tensor): Done batch
				  num_epoch (int): Number of learning iterations to run with the same data

			 Returns:
				   None

		 """

		if isinstance(state_batch, list):
			state_batch = torch.cat(state_batch)
			next_state_batch = torch.cat(next_state_batch)
			action_batch = torch.cat(action_batch)
			reward_batch = torch.cat(reward_batch)
			done_batch = torch.cat(done_batch)
		batch_size = len(state_batch)

		for _ in range(num_epoch):
			########### CRITIC UPDATE ####################

			#Compute next q-val, next_v and target
			with torch.no_grad():


				#Compute next action_batch
				next_action_batch = torch.cat([self.policy_target.clean_action(next_state_batch[:, id, :], id) for id in range(self.num_agents)], 1)
				if self.algo_name == 'TD3':
					# Policy Noise
					policy_noise = np.random.normal(0, kwargs['policy_noise'], (action_batch.size()[0], action_batch.size()[1] * action_batch.size()[2]))
					policy_noise = torch.clamp(torch.Tensor(policy_noise), -kwargs['policy_noise_clip'], kwargs['policy_noise_clip'])
					next_action_batch += policy_noise.cuda() if self.use_gpu else policy_noise
				next_action_batch = torch.clamp(next_action_batch, -1, 1)

				#Compute Q-val and value of next state masking by done

				q1, q2 = self.critics_target[agent_id].forward(next_state_batch.view(batch_size, -1), next_action_batch)
				q1 = (1 - done_batch) * q1
				q2 = (1 - done_batch) * q2
				#next_val = (1 - done_batch) * next_val

				#Select which q to use as next-q (depends on algo)
				if self.algo_name == 'TD3':next_q = torch.min(q1, q2)
				elif self.algo_name == 'DDPG': next_q = q1

				#Compute target q and target val
				target_q = reward_batch[:,agent_id].unsqueeze(1) + (self.gamma * next_q)
				#if self.args.use_advantage: target_val = reward_batch + (self.gamma * next_val)



			self.critic_optims[agent_id].zero_grad()
			current_q1, current_q2 = self.critics[agent_id].forward((state_batch.view(batch_size, -1)), (action_batch.view(batch_size, -1)))
			utils.compute_stats(current_q1, self.q)

			dt = self.loss(current_q1, target_q)
			# if self.args.use_advantage:
			#     dt = dt + self.loss(current_val, target_val)
			#     utils.compute_stats(current_val, self.val)

			if self.algo_name == 'TD3': dt = dt + self.loss(current_q2, target_q)
			utils.compute_stats(dt, self.q_loss)

			# if self.args.critic_constraint:
			#     if dt.item() > self.args.critic_constraint_w:
			#         dt = dt * (abs(self.args.critic_constraint_w / dt.item()))
			dt.backward()

			self.critic_optims[agent_id].step()
			self.num_critic_updates += 1

			#Delayed Actor Update
			if self.num_critic_updates % kwargs['policy_ups_freq'] == 0 or self.algo_name == 'DDPG':

				agent_action = self.policy.clean_action(state_batch[:,agent_id,:], agent_id)
				joint_action = action_batch.clone()
				joint_action[:,agent_id,:] = agent_action[:]

				#print(np.max(torch.abs(joint_action - action_batch).detach().cpu().numpy()), np.max(torch.abs(joint_action[:,agent_id,:] - agent_action).detach().cpu().numpy()))
				# # Trust Region constraint
				# if self.args.trust_region_actor:
				#     with torch.no_grad(): old_actor_actions = self.actor_target.forward(state_batch)
				#     actor_actions = action_batch - old_actor_actions


				Q1, Q2 = self.critics[agent_id].forward(state_batch.view(batch_size, -1), joint_action.view(batch_size, -1))

				# if self.args.use_advantage: policy_loss = -(Q1 - val)
				policy_loss = -Q1

				utils.compute_stats(-policy_loss,self.policy_loss)
				policy_loss = policy_loss.mean()


				self.policy_optim.zero_grad()



				policy_loss.backward(retain_graph=True)
				#nn.utils.clip_grad_norm_(self.actor.parameters(), 10)
				# if self.args.action_loss:
				#     action_loss = torch.abs(actor_actions-0.5)
				#     utils.compute_stats(action_loss, self.action_loss)
				#     action_loss = action_loss.mean() * self.args.action_loss_w
				#     action_loss.backward()
				#     #if self.action_loss[-1] > self.policy_loss[-1]: self.args.action_loss_w *= 0.9 #Decay action_w loss if action loss is larger than policy gradient loss
				self.policy_optim.step()


			# if self.args.hard_update:
			#     if self.num_critic_updates % self.args.hard_update_freq == 0:
			#         if self.num_critic_updates % self.args.policy_ups_freq == 0: self.hard_update(self.actor_target, self.actor)
			#         self.hard_update(self.critic_target, self.critic)


			if self.num_critic_updates % kwargs['policy_ups_freq'] == 0 or self.algo_name == 'DDPG': utils.soft_update(self.policy_target, self.policy, self.tau)
			for critic, critic_target in zip(self.critics, self.critics_target):
				utils.soft_update(critic_target, critic, self.tau)

			self.total_update += 1
			if self.agent_id == 0:
				self.tracker.update([self.q['mean'], self.q_loss['mean'], self.policy_loss['mean']] ,self.total_update)
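The delayed actor update above uses the centralized-critic trick of re-inserting only the learning agent's differentiable action into a clone of the joint action before querying the critic. A toy sketch of that reshaping with made-up sizes (random tensors stand in for the class's policy output):

import torch

batch_size, num_agents, action_dim = 4, 3, 2
action_batch = torch.rand(batch_size, num_agents, action_dim)   # sampled joint actions
agent_action = torch.rand(batch_size, action_dim)               # stands in for self.policy.clean_action(...)
agent_id = 1

joint_action = action_batch.clone()                             # other agents' actions stay fixed
joint_action[:, agent_id, :] = agent_action                     # in the class, only this slice carries the policy gradient
critic_input = joint_action.view(batch_size, -1)                # flatten to (batch, num_agents * action_dim)
print(critic_input.shape)                                       # torch.Size([4, 6])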
Example #4
	def update_parameters(self, state_batch, next_state_batch, action_batch, reward_batch, mask_batch, updates, **ignore):
		# state_batch = torch.FloatTensor(state_batch)
		# next_state_batch = torch.FloatTensor(next_state_batch)
		# action_batch = torch.FloatTensor(action_batch)
		# reward_batch = torch.FloatTensor(reward_batch)
		# mask_batch = torch.FloatTensor(np.float32(mask_batch))

		# reward_batch = reward_batch.unsqueeze(1)  # reward_batch = [batch_size, 1]
		# mask_batch = mask_batch.unsqueeze(1)  # mask_batch = [batch_size, 1]

		"""
		Use two Q-functions to mitigate positive bias in the policy improvement step that is known
		to degrade performance of value based methods. Two Q-functions also significantly speed
		up training, especially on harder tasks.
		"""
		expected_q1_value, expected_q2_value = self.critic(state_batch, action_batch)
		new_action, log_prob, _, mean, log_std = self.policy.noisy_action(state_batch, return_only_action=False)
		utils.compute_stats(expected_q1_value, self.q)


		if self.policy_type == "Gaussian":
			"""
			Including a separate function approximator for the soft value can stabilize training.
			"""
			expected_value = self.value(state_batch)
			utils.compute_stats(expected_value, self.val)
			target_value = self.value_target(next_state_batch)
			next_q_value = reward_batch + mask_batch * self.gamma * target_value  # reward_scale * r(st,at) + γ * V_target(st+1)
		else:
			"""
			There is no need in principle to include a separate function approximator for the state value.
			We use a target critic network for the deterministic policy and drop the value network entirely.
			"""
			next_state_action, _, _, _, _, = self.policy.noisy_action(next_state_batch, return_only_action=False)
			target_critic_1, target_critic_2 = self.critic_target(next_state_batch, next_state_action)
			target_critic = torch.min(target_critic_1, target_critic_2)
			next_q_value = reward_batch + mask_batch * self.gamma * target_critic  # reward_scale * r(st,at) + γ * Q_target(st+1)

		"""
		Soft Q-function parameters can be trained to minimize the soft Bellman residual
		JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
		∇JQ = ∇Q(st,at)(Q(st,at) - r(st,at) - γV(target)(st+1))
		"""
		q1_value_loss = self.soft_q_criterion(expected_q1_value, next_q_value.detach())
		q2_value_loss = self.soft_q_criterion(expected_q2_value, next_q_value.detach())
		utils.compute_stats(q1_value_loss, self.q_loss)
		q1_new, q2_new = self.critic(state_batch, new_action)
		expected_new_q_value = torch.min(q1_new, q2_new)

		if self.policy_type == "Gaussian":
			"""
			Including a separate function approximator for the soft value can stabilize training and is convenient to 
			train simultaneously with the other networks
			Update the V towards the min of two Q-functions in order to reduce overestimation bias from function approximation error.
			JV = 𝔼st~D[0.5(V(st) - (𝔼at~π[Qmin(st,at) - log π(at|st)]))^2]
			∇JV = ∇V(st)(V(st) - Q(st,at) + logπ(at|st))
			"""
			next_value = expected_new_q_value - (self.alpha * log_prob)
			value_loss = self.value_criterion(expected_value, next_value.detach())
			utils.compute_stats(value_loss, self.value_loss)
		else:
			pass

		"""
		Reparameterization trick is used to get a low variance estimator
		f(εt;st) = action sampled from the policy
		εt is an input noise vector, sampled from some fixed distribution
		Jπ = 𝔼st∼D,εt∼N[logπ(f(εt;st)|st)−Q(st,f(εt;st))]
		∇Jπ =∇log π + ([∇at log π(at|st) − ∇at Q(st,at)])∇f(εt;st)
		"""
		policy_loss = ((self.alpha * log_prob) - expected_new_q_value)
		utils.compute_stats(policy_loss, self.policy_loss)
		policy_loss = policy_loss.mean()

		# Regularization Loss
		mean_loss = 0.001 * mean.pow(2)
		std_loss = 0.001 * log_std.pow(2)
		utils.compute_stats(mean_loss, self.mean_loss)
		utils.compute_stats(std_loss, self.std_loss)
		mean_loss = mean_loss.mean()
		std_loss = std_loss.mean()


		policy_loss += mean_loss + std_loss

		self.critic_optim.zero_grad()
		q1_value_loss.backward()
		self.critic_optim.step()

		self.critic_optim.zero_grad()
		q2_value_loss.backward()
		self.critic_optim.step()

		if self.policy_type == "Gaussian":
			self.value_optim.zero_grad()
			value_loss.backward()
			self.value_optim.step()
		else:
			value_loss = torch.tensor(0.)

		self.policy_optim.zero_grad()
		policy_loss.backward()
		self.policy_optim.step()

		self.total_update += 1
		if self.agent_id == 0:
			self.tracker.update([self.q['mean'], self.q_loss['mean'], self.val['mean'], self.value_loss['mean']
								, self.policy_loss['mean'], self.mean_loss['mean'], self.std_loss['mean']], self.total_update)

		"""
		We update the target weights to match the current value function weights periodically
		Update target parameter after every n(args.target_update_interval) updates
		"""
		if updates % self.target_update_interval == 0 and self.policy_type == "Deterministic":
			utils.soft_update(self.critic_target, self.critic, self.tau)

		elif updates % self.target_update_interval == 0 and self.policy_type == "Gaussian":
			utils.soft_update(self.value_target, self.value, self.tau)
		return value_loss.item(), q1_value_loss.item(), q2_value_loss.item(), policy_loss.item()
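For the Gaussian branch above, the state-value target is min(Q1, Q2) - alpha * log_prob, matching the JV objective quoted in the docstring. A minimal numeric sketch with toy tensors (not the class's networks):

import torch

alpha = 0.2
q_min = torch.tensor([[1.5], [0.8]])        # min of the two Q heads at (s, a ~ pi)
log_prob = torch.tensor([[-1.2], [-0.3]])   # log pi(a|s) for the sampled actions
value_target = q_min - alpha * log_prob     # regression target for the soft value network
print(value_target)                         # tensor([[1.7400], [0.8600]])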
Example #5
    def update_parameters(self,
                          state_batch,
                          next_state_batch,
                          action_batch,
                          reward_batch,
                          done_batch,
                          global_reward,
                          agent_id,
                          num_epoch=1,
                          **kwargs):
        """Runs a step of Bellman upodate and policy gradient using a batch of experiences



		 """

        if isinstance(state_batch, list):
            state_batch = torch.cat(state_batch)
            next_state_batch = torch.cat(next_state_batch)
            action_batch = torch.cat(action_batch)
            reward_batch = torch.cat(reward_batch)
            done_batch = torch.cat(done_batch)
            global_reward = torch.cat(global_reward)

        for _ in range(num_epoch):
            ########### CRITIC UPDATE ####################

            #Compute next q-val, next_v and target
            with torch.no_grad():

                #Policy Noise
                policy_noise = np.random.normal(
                    0, kwargs['policy_noise'],
                    (action_batch.size()[0], action_batch.size()[1]))
                policy_noise = torch.clamp(torch.Tensor(policy_noise),
                                           -kwargs['policy_noise_clip'],
                                           kwargs['policy_noise_clip'])

                #Compute next action_batch
                next_action_batch = self.policy_target.clean_action(
                    next_state_batch, agent_id) + (policy_noise.cuda()
                                                   if self.use_gpu
                                                   else policy_noise)
                next_action_batch = torch.clamp(next_action_batch, -1, 1)

                #Compute Q-val and value of next state masking by done
                q1, q2 = self.critic_target.forward(next_state_batch,
                                                    next_action_batch)
                q1 = (1 - done_batch) * q1
                q2 = (1 - done_batch) * q2

                #Select which q to use as next-q (depends on algo)
                if self.algo_name == 'TD3': next_q = torch.min(q1, q2)
                elif self.algo_name == 'DDPG': next_q = q1

                #Compute target q and target val
                target_q = reward_batch + (self.gamma * next_q)

            self.critic_optim.zero_grad()
            current_q1, current_q2 = self.critic.forward((state_batch),
                                                         (action_batch))
            utils.compute_stats(current_q1, self.q)

            dt = self.loss(current_q1, target_q)

            if self.algo_name == 'TD3':
                dt = dt + self.loss(current_q2, target_q)
            utils.compute_stats(dt, self.q_loss)
            dt.backward()

            self.critic_optim.step()
            self.num_critic_updates += 1

            #Delayed Actor Update
            if self.num_critic_updates % kwargs['policy_ups_freq'] == 0:

                actor_actions = self.policy.clean_action(state_batch, agent_id)
                Q1, Q2 = self.critic.forward(state_batch, actor_actions)

                # if self.args.use_advantage: policy_loss = -(Q1 - val)
                policy_loss = -Q1

                utils.compute_stats(-policy_loss, self.policy_loss)
                policy_loss = policy_loss.mean()

                self.policy_optim.zero_grad()

                policy_loss.backward(retain_graph=True)
                self.policy_optim.step()

            if self.num_critic_updates % kwargs['policy_ups_freq'] == 0:
                utils.soft_update(self.policy_target, self.policy, self.tau)
            utils.soft_update(self.critic_target, self.critic, self.tau)

            self.total_update += 1
            if self.agent_id == 0:
                self.tracker.update([
                    self.q['mean'], self.q_loss['mean'],
                    self.policy_loss['mean']
                ], self.total_update)
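All of these examples delegate target-network tracking to utils.soft_update, whose body is not shown here. With the (target, source, tau) argument order used above, Polyak averaging is conventionally implemented along these lines (a sketch, not the project's actual utility):

def soft_update(target, source, tau):
    # target <- (1 - tau) * target + tau * source, parameter by parameter
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(t_param.data * (1.0 - tau) + s_param.data * tau)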
Example #6
    def update_parameters(self,
                          state_batch,
                          next_state_batch,
                          action_batch,
                          reward_batch,
                          done_batch,
                          global_reward,
                          agent_id,
                          num_epoch=1,
                          **kwargs):
        """Runs a step of Bellman upodate and policy gradient using a batch of experiences

			 Parameters:
				  state_batch (tensor): Current States
				  next_state_batch (tensor): Next States
				  action_batch (tensor): Actions
				  reward_batch (tensor): Rewards
				  done_batch (tensor): Done batch
				  num_epoch (int): Number of learning iterations to run with the same data

			 Returns:
				   None

		 """

        if isinstance(state_batch, list):
            state_batch = torch.cat(state_batch)
            next_state_batch = torch.cat(next_state_batch)
            action_batch = torch.cat(action_batch)
            reward_batch = torch.cat(reward_batch)
            done_batch = torch.cat(done_batch)
            global_reward = torch.cat(global_reward)
        batch_size = len(state_batch)

        for _ in range(num_epoch):
            ########### CRITIC UPDATE ####################

            #Compute next q-val, next_v and target
            with torch.no_grad():

                #Compute next action_batch
                next_action_batch = torch.cat([
                    self.policy_target.clean_action(next_state_batch[:, id, :],
                                                    id)
                    for id in range(self.num_agents)
                ], 1)
                if self.algo_name == 'TD3':
                    # Policy Noise
                    policy_noise = np.random.normal(
                        0, kwargs['policy_noise'],
                        (action_batch.size()[0],
                         action_batch.size()[1] * action_batch.size()[2]))
                    policy_noise = torch.clamp(torch.Tensor(policy_noise),
                                               -kwargs['policy_noise_clip'],
                                               kwargs['policy_noise_clip'])
                    next_action_batch += policy_noise.cuda(
                    ) if self.use_gpu else policy_noise
                next_action_batch = torch.clamp(next_action_batch, -1, 1)

                #Compute Q-val and value of next state masking by done
                q1, q2 = self.critics_target[agent_id].forward(
                    next_state_batch.view(batch_size, -1), next_action_batch)
                q1 = (1 - done_batch) * q1
                q2 = (1 - done_batch) * q2
                #next_val = (1 - done_batch) * next_val

                #Select which q to use as next-q (depends on algo)
                if self.algo_name == 'TD3': next_q = torch.min(q1, q2)
                elif self.algo_name == 'DDPG': next_q = q1

                #Compute target q and target val
                target_q = reward_batch[:, agent_id].unsqueeze(1) + (
                    self.gamma * next_q)

            self.critic_optims[agent_id].zero_grad()
            current_q1, current_q2 = self.critics[agent_id].forward(
                (state_batch.view(batch_size, -1)),
                (action_batch.view(batch_size, -1)))
            utils.compute_stats(current_q1, self.q)

            dt = self.loss(current_q1, target_q)

            if self.algo_name == 'TD3':
                dt = dt + self.loss(current_q2, target_q)
            utils.compute_stats(dt, self.q_loss)
            dt.backward()

            self.critic_optims[agent_id].step()
            self.num_critic_updates += 1

            #Delayed Actor Update
            if self.num_critic_updates % kwargs[
                    'policy_ups_freq'] == 0 or self.algo_name == 'DDPG':

                agent_action = self.policy.clean_action(
                    state_batch[:, agent_id, :], agent_id)
                joint_action = action_batch.clone()
                joint_action[:, agent_id, :] = agent_action[:]

                Q1, Q2 = self.critics[agent_id].forward(
                    state_batch.view(batch_size, -1),
                    joint_action.view(batch_size, -1))
                policy_loss = -Q1

                utils.compute_stats(-policy_loss, self.policy_loss)
                policy_loss = policy_loss.mean()

                self.policy_optim.zero_grad()

                policy_loss.backward(retain_graph=True)
                self.policy_optim.step()

            if self.num_critic_updates % kwargs[
                    'policy_ups_freq'] == 0 or self.algo_name == 'DDPG':
                utils.soft_update(self.policy_target, self.policy, self.tau)
            for critic, critic_target in zip(self.critics,
                                             self.critics_target):
                utils.soft_update(critic_target, critic, self.tau)

            self.total_update += 1
            if self.agent_id == 0:
                self.tracker.update([
                    self.q['mean'], self.q_loss['mean'],
                    self.policy_loss['mean']
                ], self.total_update)
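Finally, a hedged usage sketch for the TD3/DDPG-style variants above, kept as comments because the agent and replay-buffer objects are hypothetical; the keyword names match the kwargs the methods read, and the positional order follows the last variant (the other variants differ slightly):

# s, ns, a, r, done, g = replay_buffer.sample(batch_size)   # tensors shaped as the chosen variant expects
# agent.update_parameters(s, ns, a, r, done, g, agent_id,
#                         num_epoch=1,
#                         policy_noise=0.2,          # std of the target-policy smoothing noise
#                         policy_noise_clip=0.5,     # clamp applied to that noise
#                         policy_ups_freq=2)         # delayed actor / target-update frequency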