Example #1
class Agent:
    def __init__(self,
                 load_checkpoint,
                 checkpoint_file,
                 env,
                 n_states,
                 n_actions,
                 mem_size=10**6,
                 batch_size=256,
                 n_hid1=256,
                 n_hid2=256,
                 lr=3e-4,
                 gamma=0.99,
                 tau=5e-3,
                 reward_scale=2):

        self.load_checkpoint = load_checkpoint

        self.max_action = float(env.action_space.high[0])
        self.low_action = float(env.action_space.low[0])

        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.reward_scale = reward_scale

        self.memory_counter = 0
        self.memory = ReplayMemory(mem_size, n_states, n_actions)

        self.actor = ActorNetwork(n_states,
                                  n_actions,
                                  n_hid1,
                                  n_hid2,
                                  self.max_action,
                                  lr,
                                  checkpoint_file,
                                  name='_actor')
        self.critic_1 = CriticNetwork(n_states,
                                      n_actions,
                                      n_hid1,
                                      n_hid2,
                                      lr,
                                      checkpoint_file,
                                      name='_critic1')
        self.critic_2 = CriticNetwork(n_states,
                                      n_actions,
                                      n_hid1,
                                      n_hid2,
                                      lr,
                                      checkpoint_file,
                                      name='_critic2')

        self.value_net = ValueNetwork(n_states,
                                      n_hid1,
                                      n_hid2,
                                      lr,
                                      checkpoint_file,
                                      name='_value')
        self.target_value_net = ValueNetwork(n_states,
                                             n_hid1,
                                             n_hid2,
                                             lr,
                                             checkpoint_file,
                                             name='_value_target')

        # tau=1 performs an exact copy of the networks to the respective targets
        # self.update_network_parameters(tau=1)
        self.update_network_parameters(self.value_net,
                                       self.target_value_net,
                                       tau=1)
        # self.update_network_parameters_phil(tau=1)

    def store_transition(self, obs, action, reward, obs_, done):
        self.memory.store_transition(obs, action, reward, obs_, done)

    def sample_transitions(self):
        state_batch, action_batch, reward_batch, new_state_batch, done_batch = self.memory.sample_buffer(
            self.batch_size)
        # all networks share the same device, so the actor's device works for every tensor
        state_batch = torch.tensor(state_batch,
                                   dtype=torch.float).to(self.actor.device)
        action_batch = torch.tensor(action_batch,
                                    dtype=torch.float).to(self.actor.device)
        reward_batch = torch.tensor(reward_batch,
                                    dtype=torch.float).to(self.actor.device)
        new_state_batch = torch.tensor(new_state_batch,
                                       dtype=torch.float).to(self.actor.device)
        done_batch = torch.tensor(done_batch).to(self.actor.device)
        return state_batch, action_batch, reward_batch, new_state_batch, done_batch

    def update_network_parameters(self, network, target_network, tau=None):
        if tau is None:
            tau = self.tau
        # Polyak soft update: target <- tau * online + (1 - tau) * target
        for par, target_par in zip(network.parameters(),
                                   target_network.parameters()):
            target_par.data.copy_(tau * par.data + (1 - tau) * target_par.data)

    def choose_action(self, obs):
        obs = torch.tensor([obs], dtype=torch.float).to(self.actor.device)
        actions, _ = self.actor.sample_normal(obs, reparametrize=False)
        return actions.cpu().detach().numpy()[0]

    def learn_phil(self):
        if self.memory.mem_counter < self.batch_size:
            return

        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)

        reward = torch.tensor(reward,
                              dtype=torch.float).to(self.critic_1.device)
        done = torch.tensor(done).to(self.critic_1.device)
        state_ = torch.tensor(new_state,
                              dtype=torch.float).to(self.critic_1.device)
        state = torch.tensor(state, dtype=torch.float).to(self.critic_1.device)
        action = torch.tensor(action,
                              dtype=torch.float).to(self.critic_1.device)

        value = self.value_net(state).view(-1)
        value_ = self.target_value_net(state_).view(-1)
        value_[done] = 0.0

        actions, log_probs = self.actor.sample_normal(state,
                                                      reparametrize=False)
        # actions, log_probs = self.actor.sample_mvnormal(state, reparameterize=False)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1.forward(state, actions)
        q2_new_policy = self.critic_2.forward(state, actions)
        critic_value = torch.min(q1_new_policy, q2_new_policy)
        critic_value = critic_value.view(-1)

        self.value_net.optimizer.zero_grad()
        value_target = critic_value - log_probs
        value_loss = 0.5 * (F.mse_loss(value, value_target))
        value_loss.backward(retain_graph=True)
        self.value_net.optimizer.step()

        actions, log_probs = self.actor.sample_normal(state,
                                                      reparametrize=True)
        # actions, log_probs = self.actor.sample_mvnormal(state, reparameterize=False)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1.forward(state, actions)
        q2_new_policy = self.critic_2.forward(state, actions)
        critic_value = torch.min(q1_new_policy, q2_new_policy)
        critic_value = critic_value.view(-1)

        actor_loss = log_probs - critic_value
        actor_loss = torch.mean(actor_loss)
        self.actor.optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        self.actor.optimizer.step()

        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()
        q_hat = self.reward_scale * reward + self.gamma * value_
        q1_old_policy = self.critic_1.forward(state, action).view(-1)
        q2_old_policy = self.critic_2.forward(state, action).view(-1)
        critic_1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic_2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

        critic_loss = critic_1_loss + critic_2_loss
        critic_loss.backward()
        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()
        self.update_network_parameters(self.value_net, self.target_value_net,
                                       self.tau)
        # self.update_network_parameters_phil()

    def learn(self):
        if self.memory.mem_counter < self.batch_size:
            return

        state_batch, action_batch, reward_batch, new_state_batch, done_batch = \
            self.sample_transitions()
        # state_batch, action_batch, reward_batch, new_state_batch, done_batch = \
        #     self.memory.sample_buffer(self.batch_size)
        #
        # reward_batch = torch.tensor(reward_batch, dtype=torch.float).to(self.critic_1.device)
        # done_batch = torch.tensor(done_batch).to(self.critic_1.device)
        # new_state_batch = torch.tensor(new_state_batch, dtype=torch.float).to(self.critic_1.device)
        # state_batch = torch.tensor(state_batch, dtype=torch.float).to(self.critic_1.device)
        # action_batch = torch.tensor(action_batch, dtype=torch.float).to(self.critic_1.device)
        '''Compute Value Network loss'''
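        # soft value target: V(s) = E[min(Q1, Q2)(s, a) - log pi(a|s)], a ~ pi,
        # estimated here with a single sampled action (no reparameterization needed)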
        self.value_net.optimizer.zero_grad()
        val = self.value_net(state_batch).view(-1)
        val_ = self.target_value_net(new_state_batch).view(-1)
        val_[done_batch] = 0.0

        actions, log_probs = self.actor.sample_normal(state_batch,
                                                      reparametrize=False)
        log_probs = log_probs.view(-1)
        q1 = self.critic_1(state_batch, actions)
        q2 = self.critic_2(state_batch, actions)
        q = torch.min(q1, q2).view(-1)
        value_target = q - log_probs
        value_loss = 0.5 * F.mse_loss(val, value_target)

        value_loss.backward(retain_graph=True)
        self.value_net.optimizer.step()
        # val = val - q + log_prob
        '''Compute Actor loss'''
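        # actor objective: minimize E[log pi(a|s) - min(Q1, Q2)(s, a)], with the action
        # drawn via the reparameterization trick so gradients flow through the sample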
        self.actor.optimizer.zero_grad()
        # here we need to reparametrize and thus use rsample to make the distribution differentiable
        # because the log prob of the chosen action will be part of our loss.
        actions, log_probs = self.actor.sample_normal(state_batch,
                                                      reparametrize=True)
        log_probs = log_probs.view(-1)
        q1 = self.critic_1(state_batch, actions)
        q2 = self.critic_2(state_batch, actions)
        q = torch.min(q1, q2).view(-1)
        actor_loss = log_probs - q
        actor_loss = torch.mean(actor_loss)

        actor_loss.backward(retain_graph=True)
        self.actor.optimizer.step()
        '''Compute Critic loss'''
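        # critic target: q_hat = reward_scale * r + gamma * V_target(s'), with the
        # target value zeroed for terminal states; both critics regress to the same q_hat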
        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()
        # recompute the target value of the next states for the critic update
        val_ = self.target_value_net(new_state_batch).view(-1)
        val_[done_batch] = 0.0
        q_hat = self.reward_scale * reward_batch + self.gamma * val_
        q1_old_policy = self.critic_1(state_batch, action_batch).view(-1)
        q2_old_policy = self.critic_2(state_batch, action_batch).view(-1)
        critic_1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic_2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

        critic_loss = critic_1_loss + critic_2_loss
        critic_loss.backward()
        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()

        self.update_network_parameters(self.value_net, self.target_value_net,
                                       self.tau)
        # self.update_network_parameters_phil()

    def save_models(self):
        self.actor.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()
        self.value_net.save_checkpoint()
        self.target_value_net.save_checkpoint()

    def load_models(self):
        self.actor.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()
        self.value_net.load_checkpoint()
        self.target_value_net.load_checkpoint()

    def update_network_parameters_phil(self, tau=None):
        if tau is None:
            tau = self.tau

        target_value_params = self.target_value_net.named_parameters()
        value_params = self.value_net.named_parameters()

        target_value_state_dict = dict(target_value_params)
        value_state_dict = dict(value_params)

        for name in value_state_dict:
            value_state_dict[name] = tau*value_state_dict[name].clone() + \
                    (1-tau)*target_value_state_dict[name].clone()

        self.target_value_net.load_state_dict(value_state_dict)
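
A minimal training-loop sketch for the SAC agent above, assuming a classic Gym-style environment and that ReplayMemory and the network classes are importable as in the example. The environment id, episode count, and checkpoint path are illustrative placeholders, not part of the original code.

# Usage sketch (assumed classic gym API; names below are illustrative)
import gym

env = gym.make('Pendulum-v1')
agent = Agent(load_checkpoint=False,
              checkpoint_file='tmp/sac_pendulum',
              env=env,
              n_states=env.observation_space.shape[0],
              n_actions=env.action_space.shape[0])

for episode in range(250):
    obs = env.reset()
    done, score = False, 0
    while not done:
        action = agent.choose_action(obs)            # sample from the current policy
        obs_, reward, done, info = env.step(action)
        agent.store_transition(obs, action, reward, obs_, done)
        agent.learn()                                # one gradient step per environment step
        score += reward
        obs = obs_
    print(f'episode {episode} score {score:.1f}')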
Example #2
class Agent():
    def __init__(self,
                 load_checkpoint,
                 checkpoint_file,
                 env,
                 n_states,
                 n_actions,
                 update_actor_interval=2,
                 warmup=1000,
                 mem_size=10**6,
                 batch_size=100,
                 n_hid1=400,
                 n_hid2=300,
                 lr_alpha=1e-3,
                 lr_beta=1e-3,
                 gamma=0.99,
                 tau=5e-3,
                 noise_mean=0,
                 noise_sigma=0.1):

        self.load_checkpoint = load_checkpoint
        self.checkpoint_file = checkpoint_file
        # needed for clamping in the learn function
        self.env = env
        self.max_action = float(env.action_space.high[0])
        self.low_action = float(env.action_space.low[0])

        self.n_actions = n_actions
        # tracks how many times learn() has been called, used to schedule delayed actor updates
        self.learn_step_counter = 0
        # counts action selections so we know when the warmup period has ended
        self.time_step = 0
        self.update_actor_interval = update_actor_interval
        self.warmup = warmup
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.noise_mean = noise_mean
        self.noise_sigma = noise_sigma

        self.actor = TD3ActorNetwork(n_states,
                                     n_actions,
                                     n_hid1,
                                     n_hid2,
                                     lr_alpha,
                                     checkpoint_file,
                                     name='actor')
        self.target_actor = TD3ActorNetwork(n_states,
                                            n_actions,
                                            n_hid1,
                                            n_hid2,
                                            lr_alpha,
                                            checkpoint_file,
                                            name='target_actor')

        self.critic_1 = TD3CriticNetwork(n_states,
                                         n_actions,
                                         n_hid1,
                                         n_hid2,
                                         lr_beta,
                                         checkpoint_file,
                                         name='critic_1')
        self.critic_2 = TD3CriticNetwork(n_states,
                                         n_actions,
                                         n_hid1,
                                         n_hid2,
                                         lr_beta,
                                         checkpoint_file,
                                         name='critic_2')
        self.target_critic_1 = TD3CriticNetwork(n_states,
                                                n_actions,
                                                n_hid1,
                                                n_hid2,
                                                lr_beta,
                                                checkpoint_file,
                                                name='target_critic_1')
        self.target_critic_2 = TD3CriticNetwork(n_states,
                                                n_actions,
                                                n_hid1,
                                                n_hid2,
                                                lr_beta,
                                                checkpoint_file,
                                                name='target_critic_2')

        self.memory = ReplayMemory(mem_size, n_states, n_actions)

        # tau=1 performs an exact copy of the networks to the respective targets
        # self.update_network_parameters(tau=1)
        self.update_network_parameters(self.actor, self.target_actor, tau=1)
        self.update_network_parameters(self.critic_1,
                                       self.target_critic_1,
                                       tau=1)
        self.update_network_parameters(self.critic_2,
                                       self.target_critic_2,
                                       tau=1)

    def choose_action(self, obs):
        if self.time_step < self.warmup:
            self.time_step += 1
            action = torch.tensor(self.env.action_space.sample())
        else:
            obs = torch.tensor(obs, dtype=torch.float).to(self.actor.device)
            action = self.actor(obs)

            # exploratory noise, scaled wrt action scale (max_action)
            noise = torch.tensor(
                np.random.normal(self.noise_mean,
                                 self.noise_sigma * self.max_action,
                                 size=self.n_actions)).to(self.actor.device)
            action += noise
        action = torch.clamp(action, self.low_action, self.max_action)
        return action.cpu().detach().numpy()

    def choose_action_eval(self, obs):
        obs = torch.tensor(obs, dtype=torch.float).to(self.actor.device)
        action = self.actor(obs)
        action = torch.clamp(action, self.low_action, self.max_action)
        return action.cpu().detach().numpy()

    def store_transition(self, obs, action, reward, obs_, done):
        self.memory.store_transition(obs, action, reward, obs_, done)

    def sample_transitions(self):
        state_batch, action_batch, reward_batch, new_state_batch, done_batch = self.memory.sample_buffer(
            self.batch_size)
        # all networks share the same device, so the actor's device works for every tensor
        state_batch = torch.tensor(state_batch,
                                   dtype=torch.float).to(self.actor.device)
        action_batch = torch.tensor(action_batch,
                                    dtype=torch.float).to(self.actor.device)
        reward_batch = torch.tensor(reward_batch,
                                    dtype=torch.float).to(self.actor.device)
        new_state_batch = torch.tensor(new_state_batch,
                                       dtype=torch.float).to(self.actor.device)
        done_batch = torch.tensor(done_batch).to(self.actor.device)
        return state_batch, action_batch, reward_batch, new_state_batch, done_batch

    def __copy_param(self, net_param_1, net_param_2, tau):
        # a.copy_(b) reads the content of b and copies it into a
        for par, target_par in zip(net_param_1, net_param_2):
            with torch.no_grad():
                val_to_copy = tau * par.weight + (1 - tau) * target_par.weight
                target_par.weight.copy_(val_to_copy)
                if target_par.bias is not None:
                    val_to_copy = tau * par.bias + (1 - tau) * target_par.bias
                    target_par.bias.copy_(val_to_copy)

    def update_network_parameters(self, network, target_network, tau=None):
        if tau is None:
            tau = self.tau
        # Polyak soft update: target <- tau * online + (1 - tau) * target
        for par, target_par in zip(network.parameters(),
                                   target_network.parameters()):
            target_par.data.copy_(tau * par.data + (1 - tau) * target_par.data)

        #
        # # TODO: check equivalence with Phil's method
        # # during the class initialization we call this method with tau=1, to perform an exact copy of the nets to targets
        # # then when we call this without specifying tau, we use the field stored
        # if tau is None:
        #     tau = self.tau
        #
        # actor_params = self.actor.children()
        # target_actor_params = self.target_actor.children()
        # self.__copy_param(actor_params, target_actor_params, tau)
        #
        # critic_params1 = self.critic_1.children()
        # target_critic_1_params = self.target_critic_1.children()
        # self.__copy_param(critic_params1, target_critic_1_params, tau)
        #
        # critic_params2 = self.critic_2.children()
        # target_critic_2_params = self.target_critic_2.children()
        # self.__copy_param(critic_params2, target_critic_2_params, tau)

    def learn(self):
        self.learn_step_counter += 1

        # skip learning until the replay memory holds at least one full batch
        if self.memory.mem_counter < self.batch_size:
            return
        state_batch, action_batch, reward_batch, new_state_batch, done_batch = \
            self.sample_transitions()
        # clip the noise to +-0.5 as per the paper; to be verified when the min and max actions are not symmetric (e.g. -2 and 1)
        noise = torch.tensor(
            np.clip(
                np.random.normal(self.noise_mean, 0.2, size=self.n_actions),
                -0.5, 0.5)).to(self.actor.device)
        target_next_action = torch.clamp(
            self.target_actor(new_state_batch) + noise, self.low_action,
            self.max_action)
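        # target policy smoothing: the clipped noise added above regularizes the
        # target action before the twin target critics evaluate it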

        target_q1_ = self.target_critic_1(new_state_batch, target_next_action)
        target_q2_ = self.target_critic_2(new_state_batch, target_next_action)
        # take the minimum q-value for every element in the batch
        target_q_ = torch.min(target_q1_, target_q2_)
        target_q_[done_batch] = 0.0
        target = target_q_.view(-1)  # probably not needed
        target = reward_batch + self.gamma * target
        target = target.view(self.batch_size, 1)  # probably not needed

        q_val1 = self.critic_1(state_batch, action_batch)
        q_val2 = self.critic_2(state_batch, action_batch)

        critic_loss1 = F.mse_loss(q_val1, target)
        critic_loss2 = F.mse_loss(q_val2, target)
        critic_loss = critic_loss1 + critic_loss2

        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()
        critic_loss.backward()
        #critic_loss1.backward()
        #critic_loss2.backward()

        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()

        # delayed policy update: only update the actor and the targets every
        # update_actor_interval critic updates
        if self.learn_step_counter % self.update_actor_interval == 0:
            action = self.actor(state_batch)
            # compute actor loss proportional to the estimated value from q1 given state, action pairs, where the action
            # is recomputed using the new policy
            actor_loss = -torch.mean(self.critic_1(state_batch, action))

            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            self.update_network_parameters(self.actor, self.target_actor,
                                           self.tau)
            self.update_network_parameters(self.critic_1, self.target_critic_1,
                                           self.tau)
            self.update_network_parameters(self.critic_2, self.target_critic_2,
                                           self.tau)

    def save_models(self):
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()
        self.target_critic_1.save_checkpoint()
        self.target_critic_2.save_checkpoint()

    def load_models(self):
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()
        self.target_critic_1.load_checkpoint()
        self.target_critic_2.load_checkpoint()
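
Both agents rely on the same Polyak soft update, target = tau * online + (1 - tau) * target, with tau=1 producing an exact copy at initialization. A self-contained sketch of that rule on two toy linear layers (the layer sizes and names are illustrative only, not part of the example code):

# Polyak averaging sketch on toy layers (illustrative)
import torch
import torch.nn as nn

online, target = nn.Linear(4, 2), nn.Linear(4, 2)

def soft_update(network, target_network, tau):
    # target <- tau * online + (1 - tau) * target
    for par, target_par in zip(network.parameters(), target_network.parameters()):
        target_par.data.copy_(tau * par.data + (1 - tau) * target_par.data)

soft_update(online, target, tau=1.0)   # hard copy, as done in the constructors above
print(torch.allclose(online.weight, target.weight))  # True
soft_update(online, target, tau=5e-3)  # slow tracking, as done after each learn() step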
class DDPGAgent():
    def __init__(self,
                 load_checkpoint,
                 n_states,
                 n_actions,
                 checkpoint_file,
                 mem_size=10**6,
                 batch_size=64,
                 n_hid1=400,
                 n_hid2=300,
                 alpha=1e-4,
                 beta=1e-3,
                 gamma=0.99,
                 tau=0.99):
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

        self.actor = ActorNetwork(n_states,
                                  n_actions,
                                  n_hid1,
                                  n_hid2,
                                  alpha,
                                  checkpoint_file,
                                  name='actor')
        self.critic = CriticNetwork(n_states,
                                    n_actions,
                                    n_hid1,
                                    n_hid2,
                                    beta,
                                    checkpoint_file,
                                    name='critic')

        self.actor_target = ActorNetwork(n_states,
                                         n_actions,
                                         n_hid1,
                                         n_hid2,
                                         alpha,
                                         checkpoint_file,
                                         name='actor_target')
        self.critic_target = CriticNetwork(n_states,
                                           n_actions,
                                           n_hid1,
                                           n_hid2,
                                           beta,
                                           checkpoint_file,
                                           name='critic_target')

        self.noise = OUActionNoise(mu=np.zeros(n_actions))
        self.memory = ReplayMemory(mem_size, n_states, n_actions)
        self.update_network_parameters_phil(tau=1)
        if load_checkpoint:
            self.actor.eval()
        self.load_checkpoint = load_checkpoint

    def reset_noise(self):
        self.noise.reset()

    def __copy_param(self, net_param_1, net_param_2, tau):
        # a.copy_(b) reads the content of b and copies it into a
        for par, target_par in zip(net_param_1, net_param_2):
            with torch.no_grad():
                val_to_copy = tau * par.weight + (1 - tau) * target_par.weight
                target_par.weight.copy_(val_to_copy)
                if target_par.bias is not None:
                    val_to_copy = tau * par.bias + (1 - tau) * target_par.bias
                    target_par.bias.copy_(val_to_copy)

    def update_network_parameters(self, tau=None):
        # TODO: check equivalence with Phil's method
        # during the class initialization we call this method with tau=1, to perform an exact copy of the nets to targets
        # then when we call this without specifying tau, we use the field stored
        if tau is None:
            tau = self.tau

        actor_params = self.actor.children()
        actor_target_params = self.actor_target.children()
        self.__copy_param(actor_params, actor_target_params, tau)

        critic_params = self.critic.children()
        critic_target_params = self.critic_target.children()
        self.__copy_param(critic_params, critic_target_params, tau)

    def choose_action(self, obs):
        # switch the actor to eval mode for action selection; this matters for layers
        # that track batch statistics (batchnorm) or drop units (dropout), while plain
        # layer norm behaves the same in train and eval mode
        self.actor.eval()
        obs = torch.tensor(obs, dtype=torch.float).to(self.actor.device)
        # compute actions
        mu = self.actor(obs)
        # add some random noise for exploration
        mu_prime = mu
        if not self.load_checkpoint:
            mu_prime = mu + torch.tensor(self.noise(), dtype=torch.float).to(
                self.actor.device)
            self.actor.train()
        return mu_prime.cpu().detach().numpy()

    def store_transitions(self, obs, action, reward, obs_, done):
        self.memory.store_transition(obs, action, reward, obs_, done)

    def sample_transitions(self):
        state_batch, action_batch, reward_batch, new_state_batch, done_batch = self.memory.sample_buffer(
            self.batch_size)
        # all networks share the same device, so the actor's device works for every tensor
        state_batch = torch.tensor(state_batch,
                                   dtype=torch.float).to(self.actor.device)
        action_batch = torch.tensor(action_batch,
                                    dtype=torch.float).to(self.actor.device)
        reward_batch = torch.tensor(reward_batch,
                                    dtype=torch.float).to(self.actor.device)
        new_state_batch = torch.tensor(new_state_batch,
                                       dtype=torch.float).to(self.actor.device)
        done_batch = torch.tensor(done_batch).to(self.actor.device)
        return state_batch, action_batch, reward_batch, new_state_batch, done_batch

    def save_models(self):
        self.actor.save_checkpoint()
        self.actor_target.save_checkpoint()
        self.critic.save_checkpoint()
        self.critic_target.save_checkpoint()

    def load_models(self):
        self.actor.load_checkpoint()
        self.actor_target.load_checkpoint()
        self.critic.load_checkpoint()
        self.critic_target.load_checkpoint()

    def learn(self):
        # skip learning until the replay memory holds at least one full batch
        if self.memory.mem_counter < self.batch_size:
            return
        state_batch, action_batch, reward_batch, new_state_batch, done_batch = \
            self.sample_transitions()
        ''' compute actor_target actions and critic_target values, then use obtained values to compute target y_i '''
        target_actions = self.actor_target(new_state_batch)
        # optionally add noise: + torch.tensor(self.noise(), dtype=torch.float).to(self.actor.device)
        target_critic_value_ = self.critic_target(new_state_batch,
                                                  target_actions)
        # target_critic_value_[done_batch == 1] = 0.0  # if done_batch is integer valued
        target_critic_value_[done_batch] = 0.0  # works when done_batch is a bool tensor
        target_critic_value_ = target_critic_value_.view(-1)  # flatten so shapes match below
        target = reward_batch + self.gamma * target_critic_value_
        target = target.view(self.batch_size, 1)
        ''' zero out gradients '''
        self.actor.optimizer.zero_grad()
        self.critic.optimizer.zero_grad()
        ''' compute critic loss '''
        critic_value = self.critic(state_batch, action_batch)
        critic_loss = F.mse_loss(target, critic_value)
        ''' compute actor loss'''
        # critic_value cannot be reused directly, since it evaluates the sampled (s, a)
        # pairs; the actor update in the paper evaluates the critic on the actions the
        # current actor produces for those states
        # actor_loss = torch.mean(critic_value)
        actor_loss = -self.critic(state_batch, self.actor(state_batch))
        actor_loss = torch.mean(actor_loss)

        critic_loss.backward()
        actor_loss.backward()

        self.actor.optimizer.step()
        self.critic.optimizer.step()

        self.update_network_parameters_phil()

    def __copy_params_phil(self, net_a, net_b, tau):
        net_a_params = net_a.named_parameters()
        net_b_params = net_b.named_parameters()
        net_a_state_dict = dict(net_a_params)
        net_b_state_dict = dict(net_b_params)
        for name in net_a_state_dict:
            net_a_state_dict[name] = tau * net_a_state_dict[name].clone() + (
                1 - tau) * net_b_state_dict[name].clone()
        return net_a_state_dict

    def update_network_parameters_phil(self, tau=None):
        if tau is None:
            tau = self.tau

        updated_actor_state_dict = self.__copy_params_phil(
            self.actor, self.actor_target, tau)
        updated_critic_state_dict = self.__copy_params_phil(
            self.critic, self.critic_target, tau)

        self.actor_target.load_state_dict(updated_actor_state_dict)
        self.critic_target.load_state_dict(updated_critic_state_dict)
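
A minimal episode loop for DDPGAgent, again assuming a classic Gym-style continuous-control environment; the environment id, episode count, and checkpoint path are placeholders. Note that this class names the storage method store_transitions (plural) and that the OU noise is reset at the start of each episode.

# DDPG usage sketch (assumed classic gym API; names below are illustrative)
import gym

env = gym.make('LunarLanderContinuous-v2')
ddpg = DDPGAgent(load_checkpoint=False,
                 n_states=env.observation_space.shape[0],
                 n_actions=env.action_space.shape[0],
                 checkpoint_file='tmp/ddpg_lunarlander')

for episode in range(500):
    obs = env.reset()
    ddpg.reset_noise()                  # new OU noise sequence each episode
    done, score = False, 0
    while not done:
        action = ddpg.choose_action(obs)
        obs_, reward, done, info = env.step(action)
        ddpg.store_transitions(obs, action, reward, obs_, done)
        ddpg.learn()
        score += reward
        obs = obs_
    if episode % 25 == 0:
        ddpg.save_models()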