class Agent():
    """Actor-critic agent with target networks, replay memory and OU noise.

    The agent owns an online actor/critic pair plus a target copy of each and
    trains the online pair from minibatches sampled out of a replay buffer.
    After every learning step the target networks are moved a fraction
    ``tau`` towards the online networks.
    """

    def __init__(self,
                 alpha,
                 beta,
                 input_dims,
                 tau,
                 n_actions,
                 gamma=0.99,
                 max_size=50000,
                 fc1_dims=400,
                 fc2_dims=300,
                 batch_size=32):
        """Build the replay buffer, the exploration noise and the four networks.

        Args:
            alpha: first argument of both actor networks (presumably the
                actor learning rate; confirm against ``ActorNetwork``).
            beta: first argument of both critic networks (presumably the
                critic learning rate; confirm against ``CriticNetwork``).
            input_dims: observation dimensions forwarded to the buffer and
                the networks.
            tau: interpolation factor of the soft target update.
            n_actions: number of action components.
            gamma: discount factor used in the TD target.
            max_size: capacity handed to the replay buffer.
            fc1_dims: size argument for the first hidden layer.
            fc2_dims: size argument for the second hidden layer.
            batch_size: number of transitions sampled per learning step.
        """
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.alpha = alpha
        self.beta = beta
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.noise = OUActionNoise(mu=np.zeros(n_actions))
        self.actor = ActorNetwork(alpha,
                                  input_dims,
                                  fc1_dims,
                                  fc2_dims,
                                  n_actions=n_actions,
                                  name='actor')
        self.critic = CriticNetwork(beta,
                                    input_dims,
                                    fc1_dims,
                                    fc2_dims,
                                    n_actions=n_actions,
                                    name='critic')

        self.target_actor = ActorNetwork(alpha,
                                         input_dims,
                                         fc1_dims,
                                         fc2_dims,
                                         n_actions=n_actions,
                                         name='target_actor')
        self.target_critic = CriticNetwork(beta,
                                           input_dims,
                                           fc1_dims,
                                           fc2_dims,
                                           n_actions=n_actions,
                                           name='target_critic')

        # tau=1 is a hard copy: the targets start identical to the online nets.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return the actor's action for ``observation`` plus exploration noise.

        Args:
            observation: a single (un-batched) observation.

        Returns:
            numpy array holding the noisy action for that observation.
        """
        # Evaluation mode so normalisation layers do not update their
        # statistics from this single-sample forward pass.
        self.actor.eval()
        # Stack into one ndarray first; building a tensor straight from a list
        # of ndarrays is a known slow path in PyTorch.
        state = T.tensor(np.asarray([observation]),
                         dtype=T.float).to(self.actor.device)
        # The result is detached before returning, so no graph is needed.
        with T.no_grad():
            mu = self.actor.forward(state)
        noise = T.tensor(self.noise(), dtype=T.float).to(self.actor.device)
        mu_prime = mu + noise
        self.actor.train()
        return mu_prime.cpu().detach().numpy()[0]

    def remember(self, state, action, reward, state_, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def save_models(self):
        """Checkpoint the online and target networks."""
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic.save_checkpoint()
        self.target_critic.save_checkpoint()

    def load_models(self):
        """Restore the online and target networks from their checkpoints."""
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic.load_checkpoint()
        self.target_critic.load_checkpoint()

    def learn(self):
        """Run one critic update, one actor update and one soft target update.

        Does nothing until the buffer holds at least ``batch_size``
        transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        states, actions, rewards, states_, done = self.memory.sample_buffer(
            self.batch_size)

        device = self.actor.device
        states = T.tensor(states, dtype=T.float).to(device)
        states_ = T.tensor(states_, dtype=T.float).to(device)
        actions = T.tensor(actions, dtype=T.float).to(device)
        rewards = T.tensor(rewards, dtype=T.float).to(device)
        # Force a boolean mask: integer flags would otherwise be treated as
        # row indices and float flags would be rejected by the indexing below.
        done = T.tensor(done, dtype=T.bool).to(device)

        # The TD target is a constant for the critic loss, so it is computed
        # without autograd; backward() then never visits the target networks.
        with T.no_grad():
            target_actions = self.target_actor.forward(states_)
            critic_value_ = self.target_critic.forward(states_,
                                                       target_actions)
            # Terminal transitions have no bootstrap value.
            critic_value_[done] = 0.0
            target = rewards + self.gamma * critic_value_.view(-1)
            target = target.view(self.batch_size, 1)

        critic_value = self.critic.forward(states, actions)

        self.critic.optimizer.zero_grad()
        critic_loss = F.mse_loss(target, critic_value)
        critic_loss.backward()
        self.critic.optimizer.step()

        # Actor ascends the critic's estimate of its own actions.
        self.actor.optimizer.zero_grad()
        actor_loss = -self.critic.forward(states, self.actor.forward(states))
        actor_loss = T.mean(actor_loss)
        actor_loss.backward()
        self.actor.optimizer.step()

        # No argument: the update falls back to the agent's own tau.
        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        """Move the target networks towards the online networks.

        Args:
            tau: interpolation factor; ``None`` selects ``self.tau`` and ``1``
                performs a hard copy.
        """
        if tau is None:
            tau = self.tau

        self._soft_update(self.critic, self.target_critic, tau)
        self._soft_update(self.actor, self.target_actor, tau)

    @staticmethod
    def _soft_update(online, target, tau):
        """Set ``target <- tau * online + (1 - tau) * target`` in place.

        Parameters are blended by name without recording autograd history.
        Buffers (e.g. running statistics of normalisation layers) are copied
        verbatim, so networks that carry buffers are supported as well.

        Raises:
            KeyError: if ``online`` has a parameter or buffer that ``target``
                lacks.
        """
        with T.no_grad():
            target_params = dict(target.named_parameters())
            for name, param in online.named_parameters():
                target_param = target_params[name]
                target_param.copy_(tau * param + (1 - tau) * target_param)

            target_buffers = dict(target.named_buffers())
            for name, buffer in online.named_buffers():
                target_buffers[name].copy_(buffer)
# Example #2
0
class Agent:
    """One agent of a multi-agent actor-critic setup.

    Holds an online actor/critic pair and a target copy of each; every
    network is named after the agent so checkpoints do not collide.
    """

    def __init__(self, actor_dims, critic_dims, n_actions, n_agents,
                 agent_idx, chkpt_dir, alpha=0.01, beta=0.01, fc1=64, fc2=64,
                 gamma=0.95, tau=0.01):
        """Create the four networks and synchronise the targets.

        Args:
            actor_dims: input dimensions handed to the actor networks.
            critic_dims: input dimensions handed to the critic networks.
            n_actions: number of actions.
            n_agents: number of agents, forwarded to the critic networks.
            agent_idx: agent index, used to build the network names.
            chkpt_dir: checkpoint directory.
            alpha: learning rate passed to the actor networks.
            beta: learning rate passed to the critic networks.
            fc1: size argument for the first hidden layer.
            fc2: size argument for the second hidden layer.
            gamma: discount factor.
            tau: soft update parameter.
        """
        self.gamma, self.tau = gamma, tau
        self.n_actions = n_actions
        self.agent_name = 'agent_%s' % agent_idx

        # Positional arguments shared by the online network and its target;
        # the resulting names look like agent_1_actor.
        prefix = self.agent_name
        actor_args = (alpha, actor_dims, fc1, fc2, n_actions)
        critic_args = (beta, critic_dims, fc1, fc2, n_agents, n_actions)

        self.actor = ActorNetwork(*actor_args, chkpt_dir=chkpt_dir,
                                  name=prefix + '_actor')
        self.critic = CriticNetwork(*critic_args, chkpt_dir=chkpt_dir,
                                    name=prefix + '_critic')
        self.target_actor = ActorNetwork(*actor_args, chkpt_dir=chkpt_dir,
                                         name=prefix + '_target_actor')
        self.target_critic = CriticNetwork(*critic_args, chkpt_dir=chkpt_dir,
                                           name=prefix + '_target_critic')

        # A full copy (tau=1) makes the targets start equal to the online nets.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return the current policy's action with exploration noise added.

        Args:
            observation: a single observation for this agent.

        Returns:
            numpy array with the noisy action.
        """
        device = self.actor.device
        state = T.tensor([observation], dtype=T.float).to(device)
        policy_actions = self.actor.forward(state)
        # Uniform samples from [0, 1) scaled by the exploration factor 0.1.
        exploration = 0.1 * T.rand(self.n_actions).to(device)
        noisy_action = policy_actions + exploration
        return noisy_action.detach().cpu().numpy()[0]

    def update_network_parameters(self, tau=None):
        """Blend the online parameters into the target networks.

        Args:
            tau: interpolation factor; ``None`` selects ``self.tau``.
        """
        tau = self.tau if tau is None else tau

        pairs = (
            (self.actor, self.target_actor),
            (self.critic, self.target_critic),
        )
        for online_net, target_net in pairs:
            target_weights = dict(target_net.named_parameters())
            # target <- tau * online + (1 - tau) * target, one entry per name
            blended = {
                name: tau * weight.clone() +
                (1 - tau) * target_weights[name].clone()
                for name, weight in online_net.named_parameters()
            }
            target_net.load_state_dict(blended)

    def save_models(self):
        """Checkpoint all four networks."""
        networks = (self.actor, self.target_actor, self.critic,
                    self.target_critic)
        for network in networks:
            network.save_checkpoint()

    def load_models(self):
        """Restore all four networks from their checkpoints."""
        networks = (self.actor, self.target_actor, self.critic,
                    self.target_critic)
        for network in networks:
            network.load_checkpoint()
# Example #3
0
class Agent:
    """Single agent of a multi-agent actor-critic setup.

    Owns an online actor and critic together with a target copy of each.
    """

    def __init__(self, actor_dims, critic_dims, n_actions, n_agents,
                 agent_idx, chkpt_dir, alpha=0.01, beta=0.01, fc1=64, fc2=64,
                 gamma=0.95, tau=0.01):
        """Build the networks and hard-copy the online weights to the targets.

        Args:
            actor_dims: input dimensions handed to the actor networks.
            critic_dims: input dimensions handed to the critic networks.
            n_actions: number of actions.
            n_agents: number of agents, forwarded to the critic networks.
            agent_idx: index used to derive ``agent_name``.
            chkpt_dir: directory forwarded to every network.
            alpha: first argument of the actor networks.
            beta: first argument of the critic networks.
            fc1: size argument for the first hidden layer.
            fc2: size argument for the second hidden layer.
            gamma: discount factor.
            tau: interpolation factor of the soft target update.
        """
        self.gamma = gamma
        self.tau = tau
        self.n_actions = n_actions
        self.agent_name = 'agent_%s' % agent_idx

        def make_actor(suffix):
            # Online and target actors differ only in their name suffix.
            return ActorNetwork(alpha, actor_dims, fc1, fc2, n_actions,
                                chkpt_dir=chkpt_dir,
                                name=self.agent_name + suffix)

        def make_critic(suffix):
            # Online and target critics differ only in their name suffix.
            return CriticNetwork(beta, critic_dims, fc1, fc2, n_agents,
                                 n_actions, chkpt_dir=chkpt_dir,
                                 name=self.agent_name + suffix)

        self.actor = make_actor('_actor')
        self.critic = make_critic('_critic')
        self.target_actor = make_actor('_target_actor')
        self.target_critic = make_critic('_target_critic')

        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return the actor's action plus uniform [0, 1) noise as a numpy array."""
        batch = T.tensor([observation], dtype=T.float).to(self.actor.device)
        policy_out = self.actor.forward(batch)
        jitter = T.rand(self.n_actions).to(self.actor.device)
        return (policy_out + jitter).detach().cpu().numpy()[0]

    def update_network_parameters(self, tau=None):
        """Move each target network towards its online counterpart.

        Args:
            tau: interpolation factor; ``None`` selects ``self.tau``.
        """
        if tau is None:
            tau = self.tau

        def blend(source_net, destination_net):
            # destination <- tau * source + (1 - tau) * destination
            destination = dict(destination_net.named_parameters())
            mixed = dict(source_net.named_parameters())
            for key in mixed:
                mixed[key] = (tau * mixed[key].clone() +
                              (1 - tau) * destination[key].clone())
            destination_net.load_state_dict(mixed)

        blend(self.actor, self.target_actor)
        blend(self.critic, self.target_critic)

    def save_models(self):
        """Write a checkpoint for every network."""
        for net in (self.actor, self.target_actor, self.critic,
                    self.target_critic):
            net.save_checkpoint()

    def load_models(self):
        """Load the checkpoint of every network."""
        for net in (self.actor, self.target_actor, self.critic,
                    self.target_critic):
            net.load_checkpoint()
# Example #4
0
class Agent(object):
    """Actor-critic agent with target networks, replay memory and OU noise.

    Actions returned by :meth:`choose_action` are scaled by ``action_bound``.
    """

    def __init__(self,
                 alpha,
                 beta,
                 input_dims,
                 action_bound,
                 tau,
                 env,
                 gamma=0.99,
                 n_actions=2,
                 max_size=1000000,
                 layer1_size=400,
                 layer2_size=300,
                 batch_size=64):
        """Build the replay buffer, the four networks and the noise process.

        Args:
            alpha: first argument of both actor networks (presumably the
                actor learning rate; confirm against ``ActorNetwork``).
            beta: first argument of both critic networks (presumably the
                critic learning rate; confirm against ``CriticNetwork``).
            input_dims: observation dimensions forwarded to the buffer and
                the networks.
            action_bound: factor the noisy action is multiplied by.
            tau: interpolation factor of the soft target update.
            env: kept for interface compatibility; not used by this class.
            gamma: discount factor used in the TD target.
            n_actions: number of action components.
            max_size: capacity handed to the replay buffer.
            layer1_size: size argument for the first hidden layer.
            layer2_size: size argument for the second hidden layer.
            batch_size: number of transitions sampled per learning step.
        """
        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.action_bound = action_bound
        self.actor = ActorNetwork(alpha,
                                  input_dims,
                                  layer1_size,
                                  layer2_size,
                                  n_actions=n_actions,
                                  name='Actor')
        self.critic = CriticNetwork(beta,
                                    input_dims,
                                    layer1_size,
                                    layer2_size,
                                    n_actions=n_actions,
                                    name='Critic')

        self.target_actor = ActorNetwork(alpha,
                                         input_dims,
                                         layer1_size,
                                         layer2_size,
                                         n_actions=n_actions,
                                         name='TargetActor')
        self.target_critic = CriticNetwork(beta,
                                           input_dims,
                                           layer1_size,
                                           layer2_size,
                                           n_actions=n_actions,
                                           name='TargetCritic')

        self.noise = OUActionNoise(mu=np.zeros(n_actions))

        # tau=1 is a hard copy: the targets start identical to the online nets.
        self.update_network_parameters(tau=1)

        # Frozen copies of the initial weights; check_actor_params compares
        # against these when no reference network was attached by the caller.
        self._initial_actor_params = self._snapshot_params(self.actor)
        self._initial_critic_params = self._snapshot_params(self.critic)

    def choose_action(self, observation):
        """Return the noisy, bound-scaled action for one observation.

        Args:
            observation: a single observation.

        Returns:
            numpy array with ``(actor output + noise) * action_bound``.
        """
        device = self.actor.device
        # Evaluation mode for the forward pass; restored to train mode below.
        self.actor.eval()
        observation = T.tensor(observation, dtype=T.float).to(device)
        # The result is detached before returning, so no graph is needed.
        with T.no_grad():
            mu = self.actor.forward(observation)
        mu_prime = mu + T.tensor(self.noise(), dtype=T.float).to(device)
        self.actor.train()
        # The bound must live on the same device as the action; a CPU tensor
        # with more than one element cannot be combined with a CUDA tensor.
        bound = T.tensor(self.action_bound).to(device)
        return (mu_prime * bound).cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)

    def learn(self):
        """Run one critic update, one actor update and one soft target update.

        Does nothing until the buffer holds at least ``batch_size``
        transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return
        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)

        device = self.critic.device
        reward = T.tensor(reward, dtype=T.float).to(device)
        # `done` multiplies the bootstrap value below, so it presumably holds
        # the continuation mask (1 - done) -- verify against ReplayBuffer.
        done = T.tensor(done, dtype=T.float).to(device)
        new_state = T.tensor(new_state, dtype=T.float).to(device)
        action = T.tensor(action, dtype=T.float).to(device)
        state = T.tensor(state, dtype=T.float).to(device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.eval()

        # Whole-batch TD target without autograd. This replaces a per-sample
        # Python loop whose result was rebuilt (and thereby detached) with
        # T.tensor(list).
        with T.no_grad():
            target_actions = self.target_actor.forward(new_state)
            critic_value_ = self.target_critic.forward(new_state,
                                                       target_actions)
            target = reward.reshape(-1) + \
                self.gamma * critic_value_.reshape(-1) * done.reshape(-1)
            target = target.view(self.batch_size, 1)

        critic_value = self.critic.forward(state, action)

        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_loss = F.mse_loss(target, critic_value)
        critic_loss.backward()
        self.critic.optimizer.step()

        # Critic in evaluation mode while it scores the actor's actions.
        self.critic.eval()
        self.actor.optimizer.zero_grad()
        mu = self.actor.forward(state)
        self.actor.train()
        actor_loss = -self.critic.forward(state, mu)
        actor_loss = T.mean(actor_loss)
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        """Move the target networks towards the online networks.

        Args:
            tau: interpolation factor; ``None`` selects ``self.tau`` and ``1``
                performs a hard copy.
        """
        if tau is None:
            tau = self.tau

        self._soft_update(self.critic, self.target_critic, tau)
        self._soft_update(self.actor, self.target_actor, tau)

    @staticmethod
    def _soft_update(online, target, tau):
        """Set ``target <- tau * online + (1 - tau) * target`` in place.

        Parameters are blended by name without recording autograd history.
        Buffers (e.g. running statistics of normalisation layers) are copied
        verbatim, so networks that carry buffers are supported as well.

        Raises:
            KeyError: if ``online`` has a parameter or buffer that ``target``
                lacks.
        """
        with T.no_grad():
            target_params = dict(target.named_parameters())
            for name, param in online.named_parameters():
                target_param = target_params[name]
                target_param.copy_(tau * param + (1 - tau) * target_param)

            target_buffers = dict(target.named_buffers())
            for name, buffer in online.named_buffers():
                target_buffers[name].copy_(buffer)

    @staticmethod
    def _snapshot_params(network):
        """Return detached copies of ``network``'s parameters keyed by name."""
        return {
            name: param.detach().clone()
            for name, param in network.named_parameters()
        }

    def _reference_params(self, attr, fallback):
        """Return the parameters of ``self.<attr>`` if set, else ``fallback``."""
        network = getattr(self, attr, None)
        if network is None:
            return fallback
        return dict(network.named_parameters())

    def save_models(self):
        """Checkpoint the online and target networks."""
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic.save_checkpoint()
        self.target_critic.save_checkpoint()

    def load_models(self):
        """Restore the online and target networks from their checkpoints."""
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic.load_checkpoint()
        self.target_critic.load_checkpoint()

    def check_actor_params(self):
        """Print, per parameter, whether it still equals its reference value.

        The reference is ``self.original_actor`` / ``self.original_critic``
        when a caller attached such networks; otherwise it is the snapshot of
        the initial weights taken in ``__init__``. Blocks on ``input()`` after
        printing so the output can be inspected.
        """
        current_actor_dict = dict(self.actor.named_parameters())
        current_critic_dict = dict(self.critic.named_parameters())
        original_actor_dict = self._reference_params(
            'original_actor', self._initial_actor_params)
        original_critic_dict = self._reference_params(
            'original_critic', self._initial_critic_params)

        print('Checking Actor parameters')
        for param in current_actor_dict:
            print(
                param,
                T.equal(original_actor_dict[param], current_actor_dict[param]))
        print('Checking critic parameters')
        for param in current_critic_dict:
            print(
                param,
                T.equal(original_critic_dict[param],
                        current_critic_dict[param]))
        input()
# Example #5
0
class Agent():
    """Twin-critic actor-critic agent with delayed actor updates (TD3 style).

    Two critics are trained against the minimum of their target copies; the
    actor and all target networks are only updated every
    ``update_actor_interval`` learning steps. During the first ``warmup``
    calls to :meth:`choose_action` the action is pure Gaussian noise.
    """

    def __init__(self,
                 alpha,
                 beta,
                 input_dims,
                 tau,
                 env,
                 action_bound,
                 gamma=0.99,
                 update_actor_interval=2,
                 warmup=1000,
                 n_actions=2,
                 max_size=1000000,
                 layer1_size=400,
                 layer2_size=300,
                 batch_size=100,
                 noise=0.1):
        """Build the replay buffer and the six networks.

        Args:
            alpha: first argument of both actor networks (presumably the
                actor learning rate; confirm against ``ActorNetwork``).
            beta: first argument of all critic networks (presumably the
                critic learning rate; confirm against ``CriticNetwork``).
            input_dims: observation dimensions forwarded to the buffer and
                the networks.
            tau: interpolation factor of the soft target update.
            env: environment whose ``action_space.high`` / ``.low`` give the
                clamping limits.
            action_bound: factor the clamped action is multiplied by.
            gamma: discount factor used in the TD target.
            update_actor_interval: the actor and the targets are updated on
                every ``update_actor_interval``-th learning step.
            warmup: number of initial ``choose_action`` calls answered with
                random actions.
            n_actions: number of action components.
            max_size: capacity handed to the replay buffer.
            layer1_size: size argument for the first hidden layer.
            layer2_size: size argument for the second hidden layer.
            batch_size: number of transitions sampled per learning step.
            noise: standard deviation of the exploration noise.
        """
        self.gamma = gamma
        self.tau = tau
        self.max_action = env.action_space.high
        self.min_action = env.action_space.low
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.learn_step_cntr = 0
        self.time_step = 0
        self.warmup = warmup
        self.n_actions = n_actions
        self.update_actor_iter = update_actor_interval
        self.action_bound = action_bound
        self.actor = ActorNetwork(alpha,
                                  input_dims,
                                  layer1_size,
                                  layer2_size,
                                  n_actions=n_actions,
                                  name='actor')

        self.critic_1 = CriticNetwork(beta,
                                      input_dims,
                                      layer1_size,
                                      layer2_size,
                                      n_actions=n_actions,
                                      name='critic_1')
        self.critic_2 = CriticNetwork(beta,
                                      input_dims,
                                      layer1_size,
                                      layer2_size,
                                      n_actions=n_actions,
                                      name='critic_2')

        self.target_actor = ActorNetwork(alpha,
                                         input_dims,
                                         layer1_size,
                                         layer2_size,
                                         n_actions=n_actions,
                                         name='target_actor')
        self.target_critic_1 = CriticNetwork(beta,
                                             input_dims,
                                             layer1_size,
                                             layer2_size,
                                             n_actions=n_actions,
                                             name='target_critic_1')
        self.target_critic_2 = CriticNetwork(beta,
                                             input_dims,
                                             layer1_size,
                                             layer2_size,
                                             n_actions=n_actions,
                                             name='target_critic_2')

        self.noise = noise
        # tau=1 is a hard copy: the targets start identical to the online nets.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Return a noisy, clamped and bound-scaled action.

        Args:
            observation: a single observation (ignored during warm-up).

        Returns:
            numpy array with the action; each call advances ``time_step``.
        """
        device = self.actor.device
        if self.time_step < self.warmup:
            # Same dtype and device as the actor path, so both branches feed
            # identical tensor types into the arithmetic below.
            mu = T.tensor(
                np.random.normal(scale=self.noise, size=(self.n_actions, )),
                dtype=T.float).to(device)
        else:
            state = T.tensor(observation, dtype=T.float).to(device)
            # The result is detached before returning, so no graph is needed.
            with T.no_grad():
                mu = self.actor.forward(state)

        # Independent noise per action component instead of one scalar shared
        # by every component.
        exploration = np.random.normal(scale=self.noise, size=tuple(mu.shape))
        mu_prime = mu + T.tensor(exploration, dtype=T.float).to(device)

        mu_prime = T.clamp(mu_prime, self.min_action[0], self.max_action[0])
        self.time_step += 1

        # The bound must live on the same device as the action.
        bound = T.tensor(self.action_bound).to(device)
        return (mu_prime * bound).cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)

    def learn(self):
        """Update both critics and, on delayed steps, the actor and targets.

        Does nothing until the buffer holds at least ``batch_size``
        transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return
        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)

        device = self.critic_1.device
        reward = T.tensor(reward, dtype=T.float).to(device)
        # Force a boolean mask: integer flags would otherwise be treated as
        # row indices and float flags would be rejected by the indexing below.
        done = T.tensor(done, dtype=T.bool).to(device)
        state_ = T.tensor(new_state, dtype=T.float).to(device)
        state = T.tensor(state, dtype=T.float).to(device)
        action = T.tensor(action, dtype=T.float).to(device)

        smoothing_std, smoothing_clip = 0.2, 0.5

        # The TD target is a constant for the critic losses, so it is computed
        # without autograd; backward() then never visits the target networks.
        with T.no_grad():
            target_actions = self.target_actor.forward(state_)
            # Target policy smoothing: one clipped noise sample per element,
            # not a single scalar shared by the whole batch.
            smoothing = np.random.normal(scale=smoothing_std,
                                         size=tuple(target_actions.shape))
            smoothing = T.tensor(smoothing, dtype=T.float).to(device)
            target_actions = target_actions + \
                T.clamp(smoothing, -smoothing_clip, smoothing_clip)
            target_actions = T.clamp(target_actions, self.min_action[0],
                                     self.max_action[0])

            q1_ = self.target_critic_1.forward(state_, target_actions)
            q2_ = self.target_critic_2.forward(state_, target_actions)

            # Terminal transitions have no bootstrap value.
            q1_[done] = 0.0
            q2_[done] = 0.0

            # Clipped double-Q: bootstrap from the smaller target estimate.
            critic_value_ = T.min(q1_.view(-1), q2_.view(-1))
            target = reward + self.gamma * critic_value_
            target = target.view(self.batch_size, 1)

        q1 = self.critic_1.forward(state, action)
        q2 = self.critic_2.forward(state, action)

        self.critic_1.optimizer.zero_grad()
        self.critic_2.optimizer.zero_grad()

        q1_loss = F.mse_loss(target, q1)
        q2_loss = F.mse_loss(target, q2)
        critic_loss = q1_loss + q2_loss
        critic_loss.backward()
        self.critic_1.optimizer.step()
        self.critic_2.optimizer.step()

        self.learn_step_cntr += 1

        # Delayed policy update: skip the actor and the targets on most steps.
        if self.learn_step_cntr % self.update_actor_iter != 0:
            return

        self.actor.optimizer.zero_grad()
        actor_q1_loss = self.critic_1.forward(state, self.actor.forward(state))
        actor_loss = -T.mean(actor_q1_loss)
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        """Move the three target networks towards their online networks.

        Args:
            tau: interpolation factor; ``None`` selects ``self.tau`` and ``1``
                performs a hard copy.
        """
        if tau is None:
            tau = self.tau

        self._soft_update(self.critic_1, self.target_critic_1, tau)
        self._soft_update(self.critic_2, self.target_critic_2, tau)
        self._soft_update(self.actor, self.target_actor, tau)

    @staticmethod
    def _soft_update(online, target, tau):
        """Set ``target <- tau * online + (1 - tau) * target`` in place.

        Parameters are blended by name without recording autograd history.
        Buffers (e.g. running statistics of normalisation layers) are copied
        verbatim, so networks that carry buffers are supported as well.

        Raises:
            KeyError: if ``online`` has a parameter or buffer that ``target``
                lacks.
        """
        with T.no_grad():
            target_params = dict(target.named_parameters())
            for name, param in online.named_parameters():
                target_param = target_params[name]
                target_param.copy_(tau * param + (1 - tau) * target_param)

            target_buffers = dict(target.named_buffers())
            for name, buffer in online.named_buffers():
                target_buffers[name].copy_(buffer)

    def save_models(self):
        """Checkpoint the online and target networks."""
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()
        self.target_critic_1.save_checkpoint()
        self.target_critic_2.save_checkpoint()

    def load_models(self):
        """Restore the online and target networks from their checkpoints."""
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()
        self.target_critic_1.load_checkpoint()
        self.target_critic_2.load_checkpoint()
# Example #6
0
# File: SAC.py  Project: bhargavCSSE/adv-RL
class Agent():
    def __init__(self, alpha=0.0003, beta=0.0003, input_dims=None, env=None,
                 gamma=0.99, n_actions=2, max_size=1000000, tau=0.005,
                 ent_alpha=0.0001, batch_size=256, reward_scale=2,
                 layer1_size=256, layer2_size=256, chkpt_dir='tmp/sac'):
        """Build the replay buffer, the actor and the four critic networks.

        Args:
            alpha: First positional argument handed to ``ActorNetwork``
                (presumably its learning rate; confirm against that class).
            beta: First positional argument handed to every
                ``CriticNetwork`` (presumably its learning rate).
            input_dims: Observation dimensions forwarded to the buffer and
                the networks. ``None`` (the default) means ``[8]``; a fresh
                list is created per instance so no default object is shared
                between agents.
            env: Accepted but not used by this constructor.
            gamma: Discount factor stored as ``self.gamma``.
            n_actions: Number of actions, forwarded to buffer and networks.
            max_size: Capacity passed to ``ReplayBuffer``.
            tau: Soft-update factor stored as ``self.tau``.
            ent_alpha: Entropy coefficient stored as ``self.ent_alpha``.
            batch_size: Mini-batch size stored as ``self.batch_size``.
            reward_scale: Stored as ``self.reward_scale``.
            layer1_size: ``fc1_dims`` of every network.
            layer2_size: ``fc2_dims`` of every network.
            chkpt_dir: Checkpoint directory handed to every network.
        """
        if input_dims is None:
            input_dims = [8]

        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.n_actions = n_actions
        self.ent_alpha = ent_alpha
        self.reward_scale = reward_scale

        self.actor = ActorNetwork(alpha, input_dims, n_actions=n_actions,
                                  fc1_dims=layer1_size, fc2_dims=layer2_size,
                                  name='actor', chkpt_dir=chkpt_dir)

        # Two online critics and their two target copies share one recipe.
        self.critic_1 = CriticNetwork(beta, input_dims, n_actions=n_actions,
                                      fc1_dims=layer1_size,
                                      fc2_dims=layer2_size,
                                      name='critic_1', chkpt_dir=chkpt_dir)
        self.critic_2 = CriticNetwork(beta, input_dims, n_actions=n_actions,
                                      fc1_dims=layer1_size,
                                      fc2_dims=layer2_size,
                                      name='critic_2', chkpt_dir=chkpt_dir)
        self.target_critic_1 = CriticNetwork(beta, input_dims,
                                             n_actions=n_actions,
                                             fc1_dims=layer1_size,
                                             fc2_dims=layer2_size,
                                             name='target_critic_1',
                                             chkpt_dir=chkpt_dir)
        self.target_critic_2 = CriticNetwork(beta, input_dims,
                                             n_actions=n_actions,
                                             fc1_dims=layer1_size,
                                             fc2_dims=layer2_size,
                                             name='target_critic_2',
                                             chkpt_dir=chkpt_dir)

        # tau=1 makes the target critics start as exact copies of the
        # online critics.
        self.update_network_parameters(tau=1)

    def choose_actions(self, observation, learn_mode=False):
        if not learn_mode:
            state = T.Tensor([observation]).to(self.actor.device)
        else:
            state = observation

        action_probs = self.actor.forward(state)
        max_probability_action = T.argmax(action_probs, dim=-1)
        action_distribution = Categorical(action_probs)
        action = action_distribution.sample().cpu().detach().numpy()[0]

        z = action_probs == 0.0
        z = z.float()*1e-8
        log_probs = T.log(action_probs + z)

        return action, action_probs, log_probs, max_probability_action

    def remember(self, state, action, reward, new_state, done):
        self.memory.store_transition(state, action, reward, new_state, done)
    
    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau
        
        target_critic_1_params = self.target_critic_1.named_parameters()
        target_critic_2_params = self.target_critic_2.named_parameters()
        critic_1_params = self.critic_1.named_parameters()
        critic_2_params = self.critic_2.named_parameters()

        target_critic_1_state_dict = dict(target_critic_1_params)
        critic_1_state_dict = dict(critic_1_params)

        target_critic_2_state_dict = dict(target_critic_2_params)
        critic_2_state_dict = dict(critic_2_params)

        for name in critic_1_state_dict:
            critic_1_state_dict[name] = tau*critic_1_state_dict[name].clone() + (1-tau)*target_critic_1_state_dict[name].clone()
        
        for name in critic_2_state_dict:
            critic_2_state_dict[name] = tau*critic_2_state_dict[name].clone() + (1-tau)*target_critic_2_state_dict[name].clone()
        
        self.target_critic_1.load_state_dict(critic_1_state_dict)
        self.target_critic_2.load_state_dict(critic_2_state_dict)
    
    def save_models(self):
        # print('.....saving models.....')
        self.actor.save_checkpoint()
        self.critic_1.save_checkpoint()
        self.critic_2.save_checkpoint()
        self.target_critic_1.save_checkpoint()
        self.target_critic_2.save_checkpoint()
    
    def load_models(self):
        print('.....loading models.....')
        self.actor.load_checkpoint()
        self.critic_1.load_checkpoint()
        self.critic_2.load_checkpoint()
        self.target_critic_1.load_checkpoint()
        self.target_critic_2.load_checkpoint()
    
    def learn(self):
        if self.memory.mem_cntr < self.batch_size:
            return 10, 10, 10, 10, 10
        
        state, action, reward, next_state, done = self.memory.sample_buffer(self.batch_size)

        reward = T.tensor(reward, dtype=T.float).to(self.actor.device)
        done = T.tensor(done, dtype=T.float).to(self.actor.device)
        state_ = T.tensor(next_state, dtype=T.float).to(self.actor.device)
        state = T.tensor(state, dtype=T.float).to(self.actor.device)
        action = T.tensor(action, dtype=T.float).to(self.actor.device)
        
        # Critics Learning

        with T.no_grad():
            action_, probs, log_probs, max_action = self.choose_actions(state_, learn_mode=True)
            qf1_target_ = self.target_critic_1(state_)
            qf2_target_ = self.target_critic_2(state_)
            min_qf_target_ = probs*(T.min(qf1_target_, qf2_target_) - self.ent_alpha*log_probs)
            min_qf_target_ = min_qf_target_.sum(dim=1).view(-1)
            next_q_value = reward + (1.0-done)*self.gamma*min_qf_target_
        
        action = action.view(self.batch_size, 1)
        qf1 = self.critic_1(state).gather(1, action.long())
        qf2 = self.critic_2(state).gather(1, action.long())

        self.critic_1.optimizer.zero_grad()
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf1_loss.backward(retain_graph=False)
        self.critic_1.optimizer.step()

        self.critic_2.optimizer.zero_grad()
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf2_loss.backward(retain_graph=False)
        self.critic_2.optimizer.step()

        self.update_network_parameters()

        # Actor loss
        qf1_pi = self.critic_1(state)
        qf2_pi = self.critic_2(state)
        min_qf_pi = T.min(qf1_pi, qf2_pi)
        inside_term = self.ent_alpha*log_probs - min_qf_pi
        actor_loss = (probs*inside_term).sum(dim=1).mean()
        self.actor.optimizer.zero_grad()
        actor_loss.backward(retain_graph=False)
        self.actor.optimizer.step()

        
    def compute_grads(self):
        if self.memory.mem_cntr < self.batch_size:
            return False
        
        state, action, reward, next_state, done = self.memory.sample_buffer(self.batch_size)

        reward = T.tensor(reward, dtype=T.float).to(self.actor.device)
        done = T.tensor(done, dtype=T.float).to(self.actor.device)
        state_ = T.tensor(next_state, dtype=T.float).to(self.actor.device)
        action = T.tensor(action, dtype=T.float).to(self.actor.device)
        state = T.tensor(state, dtype=T.float).to(self.actor.device)
        state.requires_grad = True

        with T.no_grad():
            action_, probs, log_probs, max_action = self.choose_actions(state_, learn_mode=True)
            qf1_target_ = self.target_critic_1(state_)
            qf2_target_ = self.target_critic_2(state_)
            min_qf_target_ = probs*(T.min(qf1_target_, qf2_target_) - self.ent_alpha*log_probs)
            min_qf_target_ = min_qf_target_.sum(dim=1).view(-1)
            next_q_value = reward + (1.0-done)*self.gamma*min_qf_target_

        self.actor.optimizer.zero_grad()
        qf1_pi = self.critic_1(state)
        qf2_pi = self.critic_2(state)
        min_qf_pi = T.min(qf1_pi, qf2_pi)
        inside_term = self.ent_alpha*log_probs - min_qf_pi
        actor_loss = (probs*inside_term).sum(dim=1).mean()

        actor_loss.backward()
        data_grad = state.grad.data

        return data_grad.mean(axis=0)